|
|
import torch |
|
|
|
|
|
from tests.test_select_block import create_block, Config, SparseConfig |
|
|
import csv |
|
|
import time |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from flash_attn.utils.generation import InferenceParams |
|
|
from HybridTensor.utils.utils import arg_parser, _get_device, sparse_index, generate_random_BH_index, get_gpu_name |
|
|
from HybridTensor.utils.profiling import cuda_profiler |
|
|
import math |
|
|
from tqdm import tqdm |
|
|
|
|
|
def run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype):
    """Benchmark one (batch_size, seq_len, index_size, attn_topk) configuration.

    Builds a sparse and a regular transformer block, runs a prefill pass over a
    random prompt of length ``seq_len``, then captures a single decode step of
    each block into a CUDA graph and times graph replays.

    Args:
        args: parsed CLI namespace; ``args.in_features`` supplies the hidden size.
        batch_size: batch size used for the prefill and decode inputs.
        seq_len: prompt length for the prefill pass (decode starts at this offset).
        index_size: number of active MLP neurons assigned to the sparse block.
        attn_topk: attention top-k fraction written into the sparse config.
        device: torch device the blocks are created on.
        dtype: parameter dtype for the blocks (inputs themselves are float16).

    Returns:
        Tuple ``(regular_graph_time_ms, sparse_graph_time_ms, speedup)`` where
        the times are mean milliseconds per graph replay and ``speedup`` is
        ``regular / sparse``.
    """
    # --- block construction -------------------------------------------------
    config = Config()
    sp_config = SparseConfig()
    sp_config.attn_topk = attn_topk

    config.hidden_size = args.in_features
    config.num_attention_heads = args.in_features // 128  # head_dim fixed at 128
    config.use_heuristic = False

    sparse_block = create_block(config, sp_config, layer_idx=0, process_group=None, device=device, dtype=dtype)
    sparse_block.eval()
    sparse_block.mlp_topk = index_size

    # BUGFIX: the original aliased `regular_config = config`, so flipping the
    # sparsity flags below also mutated the sparse block's config object.
    # Deep-copy so the two blocks have independent configs.
    regular_config = copy.deepcopy(config)
    regular_config.att_sparse = False
    regular_config.mlp_sparse = False
    regular_block = create_block(regular_config, None, layer_idx=0, process_group=None, device=device, dtype=dtype)
    regular_block.eval()

    # --- inference bookkeeping ----------------------------------------------
    max_seqlen = seq_len + 16  # headroom beyond the prompt for decode steps
    max_batch_size = batch_size
    in_features = args.in_features
    head_dim = 128

    inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size)
    process_group = None
    sequence_parallel = False

    heads = config.num_attention_heads
    selected_heads = heads // 2  # half the heads marked active for the BH index

    # --- sparse index tensors -----------------------------------------------
    # BUGFIX: these were previously built from args.index_size / args.batch_size,
    # ignoring the swept `index_size` / `batch_size` parameters of this call.
    total_neurons = args.in_features * 4
    test_index_vec = torch.empty((total_neurons,), device='cuda', dtype=torch.int32)
    active_indices = sparse_index(index_size, total_neurons)[0]
    test_index_vec[:index_size] = active_indices
    if index_size < total_neurons:
        test_index_vec[index_size:] = 0  # zero-pad the inactive tail

    test_bh_idx = generate_random_BH_index(batch_size, heads, selected_heads)
    test_index_size = index_size
    # NOTE(review): test_index_vec / test_bh_idx / test_index_size are not read
    # anywhere below — they look like leftovers from a variant that fed the
    # indices to the block explicitly. Confirm before deleting.

    mixer_kwargs = (
        {"seqlen": seq_len}
        if process_group is not None and sequence_parallel
        else {}
    )
    if inference_params is not None:
        mixer_kwargs["inference_params"] = inference_params

    with torch.no_grad():
        # --- prefill: run the full prompt through each block once -----------
        original_seq = torch.randn(batch_size, seq_len, in_features, device='cuda', dtype=torch.float16)

        output_sparse = sparse_block(original_seq, mixer_kwargs=mixer_kwargs)
        output_regular = regular_block(original_seq, mixer_kwargs=mixer_kwargs)

        # Decode single tokens starting at position seq_len.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len

        input_x = torch.randn(batch_size, 1, in_features, device='cuda', dtype=torch.float16)
        out_decode_sparse = sparse_block(input_x, mixer_kwargs=mixer_kwargs)

        # Rewind the offset so the regular block decodes the same position.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        out_decode_regular = regular_block(input_x, mixer_kwargs=mixer_kwargs)

        # --- CUDA-graph capture: regular block ------------------------------
        input_x_static = input_x.clone()  # graphs require stable input storage
        output_regular_static = torch.empty((batch_size, 1, in_features), device=device, dtype=dtype)

        # Warm-up call so lazy allocations happen outside graph capture.
        # NOTE(review): seqlen_offset is NOT rewound before this warm-up or
        # the capture below (the sparse capture does rewind); if the block
        # advances the offset internally, the two graphs replay at different
        # positions — confirm against the block implementation.
        _ = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)
        torch.cuda.synchronize()
        graph_regular = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph_regular):
            res = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)
            if isinstance(res, tuple):
                res = res[0]
            output_regular_static.copy_(res)

        # --- CUDA-graph capture: sparse block -------------------------------
        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        temp = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)  # warm-up
        if isinstance(temp, tuple):
            temp = temp[0]
        output_sparse_static = torch.empty_like(temp)

        torch.cuda.synchronize()

        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        graph_sparse = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph_sparse):
            res = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)
            if isinstance(res, tuple):
                res = res[0]
            output_sparse_static.copy_(res)

        # --- warm up replays, then time -------------------------------------
        for _ in range(5):
            graph_regular.replay()
            graph_sparse.replay()
        torch.cuda.synchronize()

        num_replays = 10

        start = time.time()
        for _ in range(num_replays):
            graph_regular.replay()
        torch.cuda.synchronize()
        regular_graph_time = (time.time() - start) * 1000 / num_replays  # ms per replay

        start = time.time()
        for _ in range(num_replays):
            graph_sparse.replay()
        torch.cuda.synchronize()
        sparse_graph_time = (time.time() - start) * 1000 / num_replays

        speedup = regular_graph_time / sparse_graph_time

    return regular_graph_time, sparse_graph_time, speedup
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
args = arg_parser() |
|
|
device = _get_device(0) |
|
|
dtype = torch.float16 |
|
|
gpu_name = get_gpu_name() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
batch_sizes = [1, 8, 16, 32] |
|
|
seq_lengths = [1024, 2048] |
|
|
|
|
|
index_size_p = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5] |
|
|
total_neurons = args.in_features * 4 |
|
|
|
|
|
|
|
|
index_sizes = [int(total_neurons * i) for i in index_size_p] |
|
|
|
|
|
|
|
|
index_sizes = [math.ceil(size / 128) * 128 if size % 128 != 0 else size for size in index_sizes] |
|
|
|
|
|
attn_topks = [0.3, 0.4, 0.5] |
|
|
|
|
|
|
|
|
total_runs = len(batch_sizes) * len(seq_lengths) * len(index_sizes) * len(attn_topks) |
|
|
output_file = f"results/simulations/{gpu_name}_select_block_{args.in_features}_inference_sim.csv" |
|
|
|
|
|
with open(output_file, mode='w', newline='') as csv_file: |
|
|
fieldnames = ["in_features", "batch_size", "seq_len", "index_size", "neuron_activation", "attn_topk", |
|
|
"regular_graph_time_ms", "sparse_graph_time_ms", "speedup"] |
|
|
writer = csv.DictWriter(csv_file, fieldnames=fieldnames) |
|
|
writer.writeheader() |
|
|
|
|
|
|
|
|
for batch_size in tqdm(batch_sizes, desc="Batch Sizes"): |
|
|
for seq_len in seq_lengths: |
|
|
for index_size in index_sizes: |
|
|
for attn_topk in attn_topks: |
|
|
reg_time, spa_time, speedup = run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype) |
|
|
writer.writerow({ |
|
|
"in_features": args.in_features, |
|
|
"batch_size": batch_size, |
|
|
"seq_len": seq_len, |
|
|
"index_size": index_size, |
|
|
"neuron_activation": index_size / total_neurons, |
|
|
"attn_topk": attn_topk, |
|
|
"regular_graph_time_ms": reg_time, |
|
|
"sparse_graph_time_ms": spa_time, |
|
|
"speedup": speedup |
|
|
}) |
|
|
csv_file.flush() |
|
|
print(f"Simulation complete. Results saved to {output_file}") |