File size: 9,247 Bytes
b3a3b15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import torch

from tests.test_select_block import create_block, Config, SparseConfig
import csv
import time
import torch
import torch.nn as nn
from flash_attn.utils.generation import InferenceParams
from HybridTensor.utils.utils import arg_parser, _get_device, sparse_index, generate_random_BH_index, get_gpu_name
from HybridTensor.utils.profiling import cuda_profiler
import math
from tqdm import tqdm

def _capture_decode_graph(block, static_input, mixer_kwargs, seqlen_offset):
    """Capture one decode-step forward pass of ``block`` in a CUDA graph.

    The block is first run once eagerly so the static output buffer can be
    sized from a real output, then the same call is recorded into a graph
    that copies its result into that buffer.

    Args:
        block: Callable invoked as ``block(static_input, mixer_kwargs=...)``.
        static_input: Input tensor whose storage the captured graph reads.
        mixer_kwargs: Dict holding an ``"inference_params"`` entry; its
            ``seqlen_offset`` is reset before the eager run and the capture.
        seqlen_offset: Value written to ``inference_params.seqlen_offset``.

    Returns:
        ``(graph, static_output)`` where ``static_output`` is the buffer the
        graph writes into on every replay.
    """
    inference_params = mixer_kwargs["inference_params"]

    # Eager run: warms the block up and reveals the output shape/dtype.
    inference_params.seqlen_offset = seqlen_offset
    eager_out = block(static_input, mixer_kwargs=mixer_kwargs)
    if isinstance(eager_out, tuple):
        eager_out = eager_out[0]
    static_output = torch.empty_like(eager_out)
    torch.cuda.synchronize()

    # Reset the offset so the captured call sees the same state as the eager one.
    inference_params.seqlen_offset = seqlen_offset
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        captured_out = block(static_input, mixer_kwargs=mixer_kwargs)
        if isinstance(captured_out, tuple):
            captured_out = captured_out[0]
        static_output.copy_(captured_out)
    return graph, static_output


def _mean_replay_ms(graph, num_replays):
    """Return the mean wall-clock time in milliseconds of one graph replay."""
    # Drain pending GPU work so it is not billed to this measurement.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_replays):
        graph.replay()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000 / num_replays


def run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype):
    """Benchmark one decode step of a sparse block against a regular block.

    Builds both blocks with ``create_block``, runs a prefill of ``seq_len``
    tokens followed by a single-token decode through each, captures the decode
    step of each block in a CUDA graph, and times graph replays.

    Args:
        args: Parsed arguments; only ``args.in_features`` is read here.
        batch_size: Batch dimension of the random inputs and the maximum
            batch size given to ``InferenceParams``.
        seq_len: Prefill length; also the ``seqlen_offset`` used for decode.
        index_size: Value assigned to ``sparse_block.mlp_topk``.
        attn_topk: Value assigned to ``SparseConfig.attn_topk``.
        device: Device the blocks and all tensors are created on.
        dtype: Dtype of the blocks and all tensors.

    Returns:
        ``(regular_graph_time, sparse_graph_time, speedup)``: the two mean
        replay times in milliseconds and their ratio (regular / sparse).
    """
    head_dim = 128
    in_features = args.in_features
    warmup_replays = 5
    num_replays = 10

    config = Config()
    sp_config = SparseConfig()
    sp_config.attn_topk = attn_topk

    config.hidden_size = in_features
    config.num_attention_heads = in_features // head_dim
    config.use_heuristic = False    # use pre-compiled heuristic or compile a new one during runtime

    sparse_block = create_block(config, sp_config, layer_idx=0, process_group=None, device=device, dtype=dtype)
    sparse_block.eval()
    sparse_block.mlp_topk = index_size

    # NOTE(review): `regular_config` is the same object as `config`, so these
    # flags are also flipped on the config the sparse block was built from.
    # Confirm create_block does not keep a live reference to it.
    regular_config = config
    regular_config.att_sparse = False
    regular_config.mlp_sparse = False
    regular_block = create_block(regular_config, None, layer_idx=0, process_group=None, device=device, dtype=dtype)
    regular_block.eval()

    # Leave headroom beyond the prefill length for the decode steps.
    inference_params = InferenceParams(max_seqlen=seq_len + 16, max_batch_size=batch_size)
    mixer_kwargs = {"inference_params": inference_params}

    with torch.no_grad():
        # Prefill stage: run the full prompt through both blocks.
        original_seq = torch.randn(batch_size, seq_len, in_features, device=device, dtype=dtype)
        sparse_block(original_seq, mixer_kwargs=mixer_kwargs)
        regular_block(original_seq, mixer_kwargs=mixer_kwargs)

        # Decode stage: one new token per sequence, positioned after the prompt.
        input_x = torch.randn(batch_size, 1, in_features, device=device, dtype=dtype)

        inference_params.seqlen_offset = seq_len
        sparse_block(input_x, mixer_kwargs=mixer_kwargs)

        inference_params.seqlen_offset = seq_len
        regular_block(input_x, mixer_kwargs=mixer_kwargs)

        # --- CUDA Graph Capture for Decode Stage ---
        input_x_static = input_x.clone()
        # The static output buffers are not read here, but they must stay
        # referenced because the captured graphs write into them on replay.
        graph_regular, _regular_out = _capture_decode_graph(
            regular_block, input_x_static, mixer_kwargs, seq_len
        )
        graph_sparse, _sparse_out = _capture_decode_graph(
            sparse_block, input_x_static, mixer_kwargs, seq_len
        )

        # Warmup CUDA Graph replays
        for _ in range(warmup_replays):
            graph_regular.replay()
            graph_sparse.replay()
        torch.cuda.synchronize()

        # --- Measure CUDA Graph Replay Latency ---
        regular_graph_time = _mean_replay_ms(graph_regular, num_replays)
        sparse_graph_time = _mean_replay_ms(graph_sparse, num_replays)
        speedup = regular_graph_time / sparse_graph_time

    return regular_graph_time, sparse_graph_time, speedup

if __name__ == "__main__":
    # Sweep batch size, sequence length, MLP index size and attention top-k,
    # benchmark each combination, and stream the results to a CSV file.
    import os

    args = arg_parser()
    device = _get_device(0)
    dtype = torch.float16
    gpu_name = get_gpu_name()

    # Parameter grids.
    batch_sizes = [1, 8, 16, 32]
    seq_lengths = [1024, 2048]
    index_size_p = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
    attn_topks = [0.3, 0.4, 0.5]
    total_neurons = args.in_features * 4

    # Index sizes are fractions of the neuron count, rounded up to the next
    # multiple of 128 (values already on a multiple are left unchanged).
    index_sizes = [
        math.ceil(int(total_neurons * fraction) / 128) * 128
        for fraction in index_size_p
    ]

    # Total number of simulations, used to size the progress bar.
    total_runs = len(batch_sizes) * len(seq_lengths) * len(index_sizes) * len(attn_topks)
    output_file = f"results/simulations/{gpu_name}_select_block_{args.in_features}_inference_sim.csv"

    # open() does not create missing parent directories.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(output_file, mode='w', newline='') as csv_file:
        fieldnames = ["in_features", "batch_size", "seq_len", "index_size", "neuron_activation", "attn_topk",
                      "regular_graph_time_ms", "sparse_graph_time_ms", "speedup"]
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()

        # One progress-bar tick per simulation rather than per batch size.
        with tqdm(total=total_runs, desc="Simulations") as progress:
            for batch_size in batch_sizes:
                for seq_len in seq_lengths:
                    for index_size in index_sizes:
                        for attn_topk in attn_topks:
                            reg_time, spa_time, speedup = run_simulation(
                                args, batch_size, seq_len, index_size, attn_topk, device, dtype
                            )
                            writer.writerow({
                                "in_features": args.in_features,
                                "batch_size": batch_size,
                                "seq_len": seq_len,
                                "index_size": index_size,
                                "neuron_activation": index_size / total_neurons,
                                "attn_topk": attn_topk,
                                "regular_graph_time_ms": reg_time,
                                "sparse_graph_time_ms": spa_time,
                                "speedup": speedup
                            })
                            # Flush after every row so partial results survive a crash.
                            csv_file.flush()
                            progress.update(1)
    print(f"Simulation complete. Results saved to {output_file}")