|
|
|
|
|
|
|
|
import torch |
|
|
import time |
|
|
from HybridTensor.utils.utils import arg_parser, generate_random_BH_index |
|
|
from HybridTensor.utils.profiling import cuda_profiler |
|
|
from HybridTensor.utils.generation import InferenceParams |
|
|
from HybridTensor.utils.utils import sparse_index |
|
|
from HybridTensor.utils.utils import _get_device |
|
|
from HybridTensor.models.create_sparse_model import create_block |
|
|
|
|
|
class Config:
    """Attribute-bag configuration for a GPT-style transformer block.

    Mirrors the attribute surface consumed by ``create_block``. Heads are
    sized so that each attention head is 128-dimensional by default.

    Args:
        in_features: model hidden size (must be a multiple of 128 for the
            derived head count/dimension to be consistent).
    """

    def __init__(self, in_features=8192):
        # --- model dimensions (one head per 128 hidden units) ---
        n_heads = in_features // 128
        self.hidden_size = in_features
        self.num_attention_heads = n_heads
        self.head_dim = in_features // n_heads

        # --- attention scaling / muP knobs ---
        self.scale_attn_weights = True
        self.mup_scale_qk_dot_by_d = False
        self.mup_attn_multiplier = 1.0
        self.scale_attn_by_inverse_layer_idx = False

        # --- attention projections & variants ---
        self.attn_dwconv = False
        self.qkv_proj_bias = True
        self.out_proj_bias = True

        # --- positional encoding (rotary disabled by default) ---
        self.rotary_emb_fraction = 0.0
        self.rotary_emb_base = 10000.0
        self.rotary_emb_scale_base = None
        self.rotary_emb_interleaved = False
        self.use_alibi = False
        self.window_size = (-1, -1)  # (-1, -1) means no sliding window

        # --- fused/flash kernels and sparsity switches ---
        self.use_flash_attn = True
        self.fused_bias_fc = True
        self.mlp_sparse = True
        self.att_sparse = True

        # --- MLP ---
        self.attn_pdrop = 0.1
        self.n_inner = None  # None -> implementation-defined inner size
        self.activation_function = "relu"
        self.fused_mlp = True
        self.mlp_checkpoint_lvl = 0

        # --- block-level layout / normalization ---
        self.sequence_parallel = False
        self.layer_norm_epsilon = 1e-5
        self.residual_in_fp32 = False
        self.fused_dropout_add_ln = True
        self.resid_pdrop = 0.1
        self.embd_pdrop = 0.1
        self.prenorm = True
        self.parallel_block = False
|
|
|
|
|
class SparseConfig:
    """Settings for the low-rank sparsity predictors of a sparse block.

    ``attn_topk`` is the fraction of attention heads kept active; the two
    ``*_low_rank_dim`` values size the MLP and attention predictors.
    """

    def __init__(self):
        # Low-rank predictor widths for MLP and attention respectively.
        self.mlp_low_rank_dim, self.attn_low_rank_dim = 1024, 128
        # Fraction of heads retained by sparse attention.
        self.attn_topk = 0.5
|
|
|
|
|
if __name__ =="__main__":
    # Benchmark: one single-token decode step through a dense ("regular")
    # transformer block vs. a sparse block — first eagerly, then under CUDA
    # graph replay. Requires a CUDA device (tensors are created on 'cuda').

    args = arg_parser()

    # Build configs; sparsity fraction and model width come from the CLI.
    config = Config()

    sp_config = SparseConfig()

    sp_config.attn_topk = args.attn_topk

    config.hidden_size = args.in_features

    config.num_attention_heads = args.in_features // 128
    # NOTE(review): config.head_dim is NOT recomputed after these overrides;
    # it keeps the value from Config() (128). That happens to stay consistent
    # only because heads are sized as in_features // 128 — confirm.
    config.use_heuristic = False

    device = _get_device(args.device)

    dtype = torch.float16

    # Sparse block under test.
    sparse_block = create_block(config, sp_config, layer_idx=0, process_group=None, device=device, dtype=dtype)

    sparse_block.eval()

    sparse_block.mlp_topk = args.index_size

    sparse_block.mlp.use_heuristic = False

    # NOTE(review): this aliases `config` rather than copying it, so the two
    # flag writes below also mutate the object that was passed to create_block
    # above. Harmless only if create_block read the flags eagerly — consider
    # copy.deepcopy(config) to be safe.
    regular_config = config

    regular_config.att_sparse = False

    regular_config.mlp_sparse = False

    regular_block = create_block(regular_config, None, layer_idx=0, process_group=None, device=device, dtype=dtype)

    regular_block.eval()

    # Benchmark shapes; +128 gives the KV cache head-room beyond the prompt.
    max_seqlen = args.seq_len + 128

    max_batch_size = args.batch_size

    in_features = args.in_features

    head_dim = 128  # fixed to match the Config head sizing above

    batch_size = args.batch_size

    seq_len = args.seq_len

    index_size = args.index_size

    # KV-cache / offset state shared by both blocks via mixer_kwargs.
    inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size)

    process_group = None

    sequence_parallel = False

    heads = config.num_attention_heads

    selected_heads = heads // 2  # sparse attention keeps half the heads

    # Neuron-index vector: first index_size slots hold the active MLP neuron
    # ids, the remainder is zero-padded.
    # NOTE(review): test_index_vec, test_bh_idx, test_index_size and
    # original_seq (below) are never consumed later in this script.
    total_neurons = args.in_features * 4

    test_index_vec = torch.empty((total_neurons,), device='cuda', dtype=torch.int32)

    active_indices = sparse_index(args.index_size, total_neurons)[0]

    test_index_vec[:args.index_size] = active_indices

    if args.index_size < total_neurons:

        test_index_vec[args.index_size:] = 0

    test_bh_idx = generate_random_BH_index(args.batch_size, heads, selected_heads)

    test_index_size = args.index_size

    # The seqlen kwarg is only required under sequence parallelism, which is
    # disabled here, so this always evaluates to {}.
    mixer_kwargs = (

        {"seqlen": seq_len}

        if process_group is not None and sequence_parallel

        else {}

    )

    if inference_params is not None:

        mixer_kwargs["inference_params"] = inference_params

    with torch.no_grad():

        original_seq = torch.randn(batch_size, seq_len, in_features, device='cuda', dtype=torch.float16)

        # Pre-fill both blocks' KV caches with the same random prompt KV so
        # the timed forward below is a pure single-token decode step.
        kv = torch.rand(batch_size, seq_len, 2, heads, head_dim, device='cuda', dtype=torch.float16)

        sparse_block.mixer._update_kv_cache(kv, inference_params)

        regular_block.mixer._update_kv_cache(kv, inference_params)

        mixer_kwargs["inference_params"].seqlen_offset = seq_len

        # One new decode token per batch element.
        input_x = torch.randn(batch_size, 1, in_features, device='cuda', dtype=torch.float16)

        # Warm-up decode for each block; the offset is reset in between
        # because each forward advances the shared inference_params state.
        out_decode_sparse = sparse_block(input_x, mixer_kwargs=mixer_kwargs)

        mixer_kwargs["inference_params"].seqlen_offset = seq_len

        out_decode_regular = regular_block(input_x, mixer_kwargs=mixer_kwargs)

        # ---- eager (no CUDA graph) timing ----
        print("Without CUDA Graphs")

        out_decode_regular, regular_time = cuda_profiler(regular_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2)

        print(f"Regular time: {regular_time} ms")

        out_decode_sparse, sparse_time = cuda_profiler(sparse_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2)

        print(f"Sparse time: {sparse_time} ms")

        speedup = regular_time / sparse_time

        print(f"Speedup: {speedup}")

        # ---- CUDA graph capture ----
        # Static input/output buffers: graph replays run on fixed addresses.
        input_x_static = input_x.clone()

        output_regular_static = torch.empty((batch_size, 1, in_features), device=device, dtype=dtype)

        # One eager pass before capture to materialize allocations/workspaces.
        _ = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)

        torch.cuda.synchronize()

        graph_regular = torch.cuda.CUDAGraph()

        with torch.cuda.graph(graph_regular):

            res = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)

            if isinstance(res, tuple):

                res = res[0]

            output_regular_static.copy_(res)

        # Same sequence for the sparse block; this pre-capture pass also
        # yields the output shape for the static buffer below.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len

        temp = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)

        if isinstance(temp, tuple):

            temp = temp[0]

        output_sparse_static = torch.empty_like(temp)

        torch.cuda.synchronize()

        mixer_kwargs["inference_params"].seqlen_offset = seq_len

        graph_sparse = torch.cuda.CUDAGraph()

        with torch.cuda.graph(graph_sparse):

            res = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)

            if isinstance(res, tuple):

                res = res[0]

            output_sparse_static.copy_(res)

        # Warm-up replays before timing.
        for _ in range(5):

            graph_regular.replay()

            graph_sparse.replay()

        torch.cuda.synchronize()

        num_replays = 10

        # Wall-clock ms per replay; synchronize bounds each timed region.
        start = time.time()

        for _ in range(num_replays):

            graph_regular.replay()

        torch.cuda.synchronize()

        regular_graph_time = (time.time() - start) * 1000 / num_replays

        start = time.time()

        for _ in range(num_replays):

            graph_sparse.replay()

        torch.cuda.synchronize()

        sparse_graph_time = (time.time() - start) * 1000 / num_replays

        print()

        print("With CUDA Graphs")

        print(f"Regular block time (CUDA Graphs): {regular_graph_time} ms")

        print(f"Sparse block time (CUDA Graphs): {sparse_graph_time} ms")

        print(f"Speedup (CUDA Graphs): {regular_graph_time/sparse_graph_time}")

        # ---- optional numerical check: eager output vs graph-captured ----
        if args.check_results:

            if isinstance(out_decode_regular, tuple):

                out_decode_regular = out_decode_regular[0]

            regular_match = torch.allclose(out_decode_regular, output_regular_static, rtol=1e-3, atol=1e-5)

            reg_diff = (out_decode_regular - output_regular_static).abs().max()

            print("\nComparison for Regular Block:")

            print(f"Outputs match: {regular_match}")

            print(f"Max difference: {reg_diff}")

            if isinstance(out_decode_sparse, tuple):

                out_decode_sparse = out_decode_sparse[0]

            sparse_match = torch.allclose(out_decode_sparse, output_sparse_static, rtol=1e-3, atol=1e-5)

            spa_diff = (out_decode_sparse - output_sparse_static).abs().max()

            print("\nComparison for Sparse Block:")

            print(f"Outputs match: {sparse_match}")

            print(f"Max difference: {spa_diff}")
|
|
|
|
|
|
|
|
|