File size: 9,247 Bytes
b3a3b15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
import torch
from tests.test_select_block import create_block, Config, SparseConfig
import csv
import time
import torch
import torch.nn as nn
from flash_attn.utils.generation import InferenceParams
from HybridTensor.utils.utils import arg_parser, _get_device, sparse_index, generate_random_BH_index, get_gpu_name
from HybridTensor.utils.profiling import cuda_profiler
import math
from tqdm import tqdm
def _first_output(res):
    """Return ``res[0]`` when a block returns a tuple, otherwise ``res`` unchanged."""
    return res[0] if isinstance(res, tuple) else res


def _capture_decode_graph(block, static_input, mixer_kwargs, seqlen_offset, warmup_calls=2):
    """Capture one single-token decode step of ``block`` into a CUDA graph.

    The block is first run eagerly ``warmup_calls`` times (must be >= 1) so any
    lazy initialisation happens outside capture, then one forward pass is captured.
    Inside the graph the block's output is copied into a buffer allocated outside
    the graph, so every replay rewrites that buffer.

    Args:
        block: Module called as ``block(x, mixer_kwargs=...)``.
        static_input: Input tensor read by every replay; the caller must keep it alive.
        mixer_kwargs: Dict holding the shared ``"inference_params"`` object.
        seqlen_offset: Decode position written to ``inference_params.seqlen_offset``.
        warmup_calls: Number of eager calls made before capture.

    Returns:
        ``(graph, static_output)``.
    """
    inference_params = mixer_kwargs["inference_params"]
    out = None
    for _ in range(warmup_calls):
        # Reset before every call so all warm-ups and the capture target the same
        # decode position, whatever a forward pass may have done to it.
        inference_params.seqlen_offset = seqlen_offset
        out = _first_output(block(static_input, mixer_kwargs=mixer_kwargs))
    # Size the output buffer from a real eager output rather than assuming its shape.
    static_output = torch.empty_like(out)
    torch.cuda.synchronize()

    inference_params.seqlen_offset = seqlen_offset
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output.copy_(_first_output(block(static_input, mixer_kwargs=mixer_kwargs)))
    return graph, static_output


def _time_graph_replays(graph, num_warmup, num_replays):
    """Return the mean wall-clock time in ms of one ``graph.replay()``.

    Runs ``num_warmup`` untimed replays, then times ``num_replays`` replays
    bracketed by ``torch.cuda.synchronize()`` so GPU work is included.
    """
    for _ in range(num_warmup):
        graph.replay()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_replays):
        graph.replay()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000 / num_replays


def run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype,
                   num_warmup=5, num_replays=10):
    """Benchmark one decode step of a sparse block against a dense block with CUDA Graphs.

    The sparse block is built from a ``SparseConfig`` with ``attn_topk`` set and
    gets ``mlp_topk = index_size``. The regular block is built from a copy of the
    same config with ``att_sparse`` and ``mlp_sparse`` disabled. Both blocks
    process a ``seq_len``-token prefill through a shared ``InferenceParams``.
    Then a single-token decode step of each is captured in a CUDA graph and its
    replays are timed.

    Args:
        args: Parsed CLI arguments; only ``args.in_features`` (hidden size) is read.
        batch_size: Batch size of the simulated inference.
        seq_len: Prompt length used for the prefill stage.
        index_size: Value assigned to ``sparse_block.mlp_topk``.
        attn_topk: Value assigned to ``sp_config.attn_topk``.
        device: Device for the blocks and the random inputs.
        dtype: Dtype for the blocks and the random inputs.
        num_warmup: Untimed graph replays per block before timing.
        num_replays: Timed graph replays averaged per block (must be >= 1).

    Returns:
        ``(regular_graph_time, sparse_graph_time, speedup)``. The times are mean
        milliseconds per replay, and ``speedup = regular / sparse``.

    Raises:
        ValueError: If ``num_replays`` is less than 1.
    """
    import copy

    if num_replays < 1:
        raise ValueError(f"num_replays must be >= 1, got {num_replays}")

    in_features = args.in_features
    head_dim = 128  # fixed per-head width; the head count is derived from it

    config = Config()
    config.hidden_size = in_features
    config.num_attention_heads = in_features // head_dim
    config.use_heuristic = False  # use pre-compiled heuristic or compile a new one during runtime

    sp_config = SparseConfig()
    sp_config.attn_topk = attn_topk

    sparse_block = create_block(config, sp_config, layer_idx=0, process_group=None,
                                device=device, dtype=dtype)
    sparse_block.eval()
    sparse_block.mlp_topk = index_size

    # Deep-copy so disabling sparsity for the dense block cannot leak into the
    # config object the sparse block was built from.
    regular_config = copy.deepcopy(config)
    regular_config.att_sparse = False
    regular_config.mlp_sparse = False
    regular_block = create_block(regular_config, None, layer_idx=0, process_group=None,
                                 device=device, dtype=dtype)
    regular_block.eval()

    # 16 positions of headroom beyond the prompt for the decode step(s).
    inference_params = InferenceParams(max_seqlen=seq_len + 16, max_batch_size=batch_size)
    mixer_kwargs = {"inference_params": inference_params}

    with torch.no_grad():
        # Prefill: run the prompt through both blocks before decoding
        # (presumably fills the per-layer KV cache held by inference_params).
        prompt = torch.randn(batch_size, seq_len, in_features, device=device, dtype=dtype)
        sparse_block(prompt, mixer_kwargs=mixer_kwargs)
        regular_block(prompt, mixer_kwargs=mixer_kwargs)

        # Decode stage: one new token per sequence, positioned right after the prompt.
        # decode_input stays referenced until return because the graphs read it.
        decode_input = torch.randn(batch_size, 1, in_features, device=device, dtype=dtype)
        graph_regular, _ = _capture_decode_graph(regular_block, decode_input, mixer_kwargs, seq_len)
        graph_sparse, _ = _capture_decode_graph(sparse_block, decode_input, mixer_kwargs, seq_len)

        regular_graph_time = _time_graph_replays(graph_regular, num_warmup, num_replays)
        sparse_graph_time = _time_graph_replays(graph_sparse, num_warmup, num_replays)

    speedup = regular_graph_time / sparse_graph_time
    return regular_graph_time, sparse_graph_time, speedup
def main():
    """Sweep the select-block decode benchmark over a parameter grid and save a CSV.

    For every combination of batch size, prompt length, MLP index size and
    attention top-k, this calls ``run_simulation`` and writes one row to
    ``results/simulations/<gpu_name>_select_block_<in_features>_inference_sim.csv``.
    The output directory is created if it is missing.
    """
    import itertools
    from pathlib import Path

    args = arg_parser()
    device = _get_device(0)
    dtype = torch.float16
    gpu_name = get_gpu_name()

    # Parameter grids.
    batch_sizes = [1, 8, 16, 32]
    seq_lengths = [1024, 2048]
    attn_topks = [0.3, 0.4, 0.5]
    # Fractions of MLP neurons kept active; converted to neuron counts below.
    index_size_p = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
    total_neurons = args.in_features * 4  # presumably the MLP hidden width (4x in_features)

    # Round each count up to a multiple of 128 with integer ceil-division.
    # Rounding can map several fractions to the same count on small models, so
    # drop duplicates while keeping order.
    index_sizes = list(dict.fromkeys(
        -(-int(total_neurons * p) // 128) * 128 for p in index_size_p
    ))

    # Same iteration order as nested loops: attn_topk varies fastest.
    grid = list(itertools.product(batch_sizes, seq_lengths, index_sizes, attn_topks))

    output_file = f"results/simulations/{gpu_name}_select_block_{args.in_features}_inference_sim.csv"
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    fieldnames = ["in_features", "batch_size", "seq_len", "index_size", "neuron_activation", "attn_topk",
                  "regular_graph_time_ms", "sparse_graph_time_ms", "speedup"]
    with open(output_file, mode='w', newline='', encoding='utf-8') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        for batch_size, seq_len, index_size, attn_topk in tqdm(grid, desc="Simulations"):
            reg_time, spa_time, speedup = run_simulation(
                args, batch_size, seq_len, index_size, attn_topk, device, dtype
            )
            writer.writerow({
                "in_features": args.in_features,
                "batch_size": batch_size,
                "seq_len": seq_len,
                "index_size": index_size,
                "neuron_activation": index_size / total_neurons,
                "attn_topk": attn_topk,
                "regular_graph_time_ms": reg_time,
                "sparse_graph_time_ms": spa_time,
                "speedup": speedup,
            })
            # Flush per row so completed results survive a crash in a later configuration.
            csv_file.flush()
    print(f"Simulation complete. Results saved to {output_file}")


if __name__ == "__main__":
    main()