import copy
import time

import torch

from HybridTensor.models.create_sparse_model import create_block
from HybridTensor.utils.generation import InferenceParams
from HybridTensor.utils.profiling import cuda_profiler
from HybridTensor.utils.utils import (
    _get_device,
    arg_parser,
    generate_random_BH_index,
    sparse_index,
)
|
class Config:
    """Model/block configuration for building a GPT-style transformer block.

    Defaults assume 128-dim attention heads, flash attention, fused kernels,
    and sparsity enabled for both the MLP and attention paths (the benchmark
    below flips the sparsity flags off to build the dense baseline).
    """

    def __init__(self, in_features=8192):
        # Width / head layout: heads are fixed at 128 dims each.
        self.hidden_size = in_features
        self.num_attention_heads = in_features // 128
        self.head_dim = self.hidden_size // self.num_attention_heads

        # Attention scaling knobs.
        self.scale_attn_weights = True
        self.mup_scale_qk_dot_by_d = False
        self.mup_attn_multiplier = 1.0
        self.scale_attn_by_inverse_layer_idx = False

        # Attention projection / positional-embedding options.
        self.attn_dwconv = False
        self.qkv_proj_bias = True
        self.out_proj_bias = True
        self.rotary_emb_fraction = 0.0
        self.rotary_emb_base = 10000.0
        self.rotary_emb_scale_base = None
        self.rotary_emb_interleaved = False
        self.use_alibi = False
        self.window_size = (-1, -1)  # (-1, -1) == no sliding-window limit
        self.use_flash_attn = True
        self.fused_bias_fc = True

        # Sparsity switches consumed by create_block.
        self.mlp_sparse = True
        self.att_sparse = True

        # MLP / dropout / norm settings.
        self.attn_pdrop = 0.1
        self.n_inner = None  # None -> implementation picks the default (typically 4x)
        self.activation_function = "relu"
        self.fused_mlp = True
        self.mlp_checkpoint_lvl = 0
        self.sequence_parallel = False
        self.layer_norm_epsilon = 1e-5
        self.residual_in_fp32 = False
        self.fused_dropout_add_ln = True
        self.resid_pdrop = 0.1
        self.embd_pdrop = 0.1
        self.prenorm = True
        self.parallel_block = False
|
class SparseConfig:
    """Sparsity hyper-parameters for the sparse block.

    Generalized: the values are now constructor keyword arguments with the
    original values as defaults, so existing `SparseConfig()` callers are
    unaffected while new callers can configure without post-hoc mutation.
    """

    def __init__(self, mlp_low_rank_dim=1024, attn_low_rank_dim=128, attn_topk=0.5):
        self.mlp_low_rank_dim = mlp_low_rank_dim    # rank of the MLP predictor
        self.attn_low_rank_dim = attn_low_rank_dim  # rank of the attention predictor
        self.attn_topk = attn_topk                  # fraction of attention heads kept
if __name__ == "__main__":
    # Benchmark one sparse decoder block against its dense counterpart for
    # single-token decoding, first with plain eager calls, then with CUDA
    # graph capture/replay.
    args = arg_parser()

    config = Config()
    sp_config = SparseConfig()
    sp_config.attn_topk = args.attn_topk

    config.hidden_size = args.in_features
    config.num_attention_heads = args.in_features // 128
    config.use_heuristic = False

    device = _get_device(args.device)
    dtype = torch.float16

    # Build the sparse block first, while the config still has sparsity on.
    sparse_block = create_block(config, sp_config, layer_idx=0, process_group=None, device=device, dtype=dtype)
    sparse_block.eval()
    sparse_block.mlp_topk = args.index_size
    sparse_block.mlp.use_heuristic = False

    # BUG FIX: the original did `regular_config = config`, aliasing the same
    # object — flipping the sparsity flags below would also mutate the config
    # that sparse_block may hold a reference to. A shallow copy suffices
    # because every attribute is an immutable scalar/tuple/None.
    regular_config = copy.copy(config)
    regular_config.att_sparse = False
    regular_config.mlp_sparse = False
    regular_block = create_block(regular_config, None, layer_idx=0, process_group=None, device=device, dtype=dtype)
    regular_block.eval()

    max_seqlen = args.seq_len + 128  # headroom beyond the prompt for decoding
    max_batch_size = args.batch_size
    in_features = args.in_features
    head_dim = 128
    batch_size = args.batch_size
    seq_len = args.seq_len
    index_size = args.index_size

    inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size)
    process_group = None
    sequence_parallel = False

    heads = config.num_attention_heads
    selected_heads = heads // 2  # benchmark with half the heads selected

    # Neuron-index vector: first `index_size` slots hold the active neuron
    # ids, the remainder is zero-padded.
    # BUG FIX: tensors were created on hard-coded 'cuda' while the blocks
    # live on `device`; use `device` consistently so a non-default device
    # (e.g. cuda:1) does not cause a device mismatch.
    total_neurons = args.in_features * 4
    test_index_vec = torch.empty((total_neurons,), device=device, dtype=torch.int32)
    active_indices = sparse_index(args.index_size, total_neurons)[0]
    test_index_vec[:args.index_size] = active_indices
    if args.index_size < total_neurons:
        test_index_vec[args.index_size:] = 0

    test_bh_idx = generate_random_BH_index(args.batch_size, heads, selected_heads)
    test_index_size = args.index_size

    mixer_kwargs = (
        {"seqlen": seq_len}
        if process_group is not None and sequence_parallel
        else {}
    )
    if inference_params is not None:
        mixer_kwargs["inference_params"] = inference_params

    with torch.no_grad():
        original_seq = torch.randn(batch_size, seq_len, in_features, device=device, dtype=torch.float16)

        # Pre-fill both KV caches with the same prompt-length cache so the
        # decode step below starts from identical state for both blocks.
        kv = torch.rand(batch_size, seq_len, 2, heads, head_dim, device=device, dtype=torch.float16)
        sparse_block.mixer._update_kv_cache(kv, inference_params)
        regular_block.mixer._update_kv_cache(kv, inference_params)
        mixer_kwargs["inference_params"].seqlen_offset = seq_len

        # Single new token to decode.
        input_x = torch.randn(batch_size, 1, in_features, device=device, dtype=torch.float16)

        out_decode_sparse = sparse_block(input_x, mixer_kwargs=mixer_kwargs)

        # Reset the offset so the regular block decodes at the same position.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        out_decode_regular = regular_block(input_x, mixer_kwargs=mixer_kwargs)

        print("Without CUDA Graphs")
        out_decode_regular, regular_time = cuda_profiler(regular_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2)
        print(f"Regular time: {regular_time} ms")

        out_decode_sparse, sparse_time = cuda_profiler(sparse_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2)
        print(f"Sparse time: {sparse_time} ms")

        speedup = regular_time / sparse_time
        print(f"Speedup: {speedup}")

        # ---------------- CUDA graph capture ----------------
        input_x_static = input_x.clone()
        output_regular_static = torch.empty((batch_size, 1, in_features), device=device, dtype=dtype)

        # One eager warm-up pass before capture so any lazy initialization
        # (workspace allocs, autotuning) happens outside the graph.
        _ = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)
        torch.cuda.synchronize()
        graph_regular = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph_regular):
            res = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)
            if isinstance(res, tuple):
                res = res[0]
            output_regular_static.copy_(res)

        # Warm-up pass for the sparse block; also discovers its output shape.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        temp = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)
        if isinstance(temp, tuple):
            temp = temp[0]
        output_sparse_static = torch.empty_like(temp)
        torch.cuda.synchronize()

        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        graph_sparse = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph_sparse):
            res = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)
            if isinstance(res, tuple):
                res = res[0]
            output_sparse_static.copy_(res)

        # Warm up the replays themselves before timing.
        for _ in range(5):
            graph_regular.replay()
            graph_sparse.replay()
        torch.cuda.synchronize()

        num_replays = 10

        start = time.time()
        for _ in range(num_replays):
            graph_regular.replay()
        torch.cuda.synchronize()
        regular_graph_time = (time.time() - start) * 1000 / num_replays

        start = time.time()
        for _ in range(num_replays):
            graph_sparse.replay()
        torch.cuda.synchronize()
        sparse_graph_time = (time.time() - start) * 1000 / num_replays

        print()
        print("With CUDA Graphs")
        print(f"Regular block time (CUDA Graphs): {regular_graph_time} ms")
        print(f"Sparse block time (CUDA Graphs): {sparse_graph_time} ms")
        print(f"Speedup (CUDA Graphs): {regular_graph_time/sparse_graph_time}")

        # Optional numerical check: eager decode output vs. graph-replay output.
        if args.check_results:
            if isinstance(out_decode_regular, tuple):
                out_decode_regular = out_decode_regular[0]
            regular_match = torch.allclose(out_decode_regular, output_regular_static, rtol=1e-3, atol=1e-5)
            reg_diff = (out_decode_regular - output_regular_static).abs().max()
            print("\nComparison for Regular Block:")
            print(f"Outputs match: {regular_match}")
            print(f"Max difference: {reg_diff}")

            if isinstance(out_decode_sparse, tuple):
                out_decode_sparse = out_decode_sparse[0]
            sparse_match = torch.allclose(out_decode_sparse, output_sparse_static, rtol=1e-3, atol=1e-5)
            spa_diff = (out_decode_sparse - output_sparse_static).abs().max()
            print("\nComparison for Sparse Block:")
            print(f"Outputs match: {sparse_match}")
            print(f"Max difference: {spa_diff}")