Susav committed on
Commit
b3a3b15
·
verified ·
1 Parent(s): 2453355

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. HybridTensor/__init__.py +0 -0
  2. HybridTensor/__pycache__/__init__.cpython-310.pyc +0 -0
  3. HybridTensor/__pycache__/__init__.cpython-39.pyc +0 -0
  4. HybridTensor/benchmarks/generation/__pycache__/gen_util.cpython-39.pyc +0 -0
  5. HybridTensor/benchmarks/generation/__pycache__/llama_sparse_generation.cpython-39.pyc +0 -0
  6. HybridTensor/benchmarks/generation/__pycache__/model_sparse_generation.cpython-39.pyc +0 -0
  7. HybridTensor/benchmarks/generation/__pycache__/opt_gen_tp.cpython-39.pyc +0 -0
  8. HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-310.pyc +0 -0
  9. HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-39.pyc +0 -0
  10. HybridTensor/benchmarks/generation/__pycache__/opt_sparse_gen_tp.cpython-39.pyc +0 -0
  11. HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-310.pyc +0 -0
  12. HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-39.pyc +0 -0
  13. HybridTensor/benchmarks/generation/gen_util.py +38 -0
  14. HybridTensor/benchmarks/generation/model_sparse_generation.py +71 -0
  15. HybridTensor/benchmarks/generation/opt_gen_tp.py +133 -0
  16. HybridTensor/benchmarks/generation/opt_generation.py +289 -0
  17. HybridTensor/benchmarks/generation/opt_sparse_gen_tp.py +112 -0
  18. HybridTensor/benchmarks/generation/opt_sparse_generation.py +182 -0
  19. HybridTensor/benchmarks/model_eval.py +313 -0
  20. HybridTensor/benchmarks/model_perplexity.py +165 -0
  21. HybridTensor/benchmarks/opt_attn_sparse_topk_perplexity.py +264 -0
  22. HybridTensor/benchmarks/select_block_decode.py +218 -0
  23. HybridTensor/models/__pycache__/create_sparse_model.cpython-310.pyc +0 -0
  24. HybridTensor/models/__pycache__/create_sparse_model.cpython-39.pyc +0 -0
  25. HybridTensor/models/__pycache__/helper.cpython-310.pyc +0 -0
  26. HybridTensor/models/__pycache__/helper.cpython-39.pyc +0 -0
  27. HybridTensor/models/__pycache__/llama.cpython-39.pyc +0 -0
  28. HybridTensor/models/__pycache__/opt.cpython-310.pyc +0 -0
  29. HybridTensor/models/__pycache__/opt.cpython-39.pyc +0 -0
  30. HybridTensor/models/create_sparse_model.py +854 -0
  31. HybridTensor/models/helper.py +125 -0
  32. HybridTensor/models/llama.py +74 -0
  33. HybridTensor/models/opt.py +229 -0
  34. HybridTensor/modules/SelectiveBlock.py +960 -0
  35. HybridTensor/modules/SelectiveMHA.py +1579 -0
  36. HybridTensor/modules/SelectiveMLP.py +580 -0
  37. HybridTensor/modules/SelectiveRouters.py +136 -0
  38. HybridTensor/modules/__init__.py +0 -0
  39. HybridTensor/modules/__pycache__/MLP.cpython-39.pyc +0 -0
  40. HybridTensor/modules/__pycache__/ParallelMLP.cpython-39.pyc +0 -0
  41. HybridTensor/modules/__pycache__/SelectiveBlock.cpython-39.pyc +0 -0
  42. HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-310.pyc +0 -0
  43. HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-39.pyc +0 -0
  44. HybridTensor/modules/__pycache__/SelectiveMHA.cpython-310.pyc +0 -0
  45. HybridTensor/modules/__pycache__/SelectiveMHA.cpython-39.pyc +0 -0
  46. HybridTensor/modules/__pycache__/SelectiveMLP.cpython-310.pyc +0 -0
  47. HybridTensor/modules/__pycache__/SelectiveMLP.cpython-39.pyc +0 -0
  48. HybridTensor/modules/__pycache__/SelectiveRouters.cpython-310.pyc +0 -0
  49. HybridTensor/modules/__pycache__/SelectiveRouters.cpython-39.pyc +0 -0
  50. HybridTensor/modules/__pycache__/__init__.cpython-310.pyc +0 -0
HybridTensor/__init__.py ADDED
File without changes
HybridTensor/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (190 Bytes). View file
 
HybridTensor/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (171 Bytes). View file
 
HybridTensor/benchmarks/generation/__pycache__/gen_util.cpython-39.pyc ADDED
Binary file (869 Bytes). View file
 
HybridTensor/benchmarks/generation/__pycache__/llama_sparse_generation.cpython-39.pyc ADDED
Binary file (2.34 kB). View file
 
HybridTensor/benchmarks/generation/__pycache__/model_sparse_generation.cpython-39.pyc ADDED
Binary file (2.68 kB). View file
 
HybridTensor/benchmarks/generation/__pycache__/opt_gen_tp.cpython-39.pyc ADDED
Binary file (3.75 kB). View file
 
HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-310.pyc ADDED
Binary file (5.22 kB). View file
 
HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-39.pyc ADDED
Binary file (6.12 kB). View file
 
HybridTensor/benchmarks/generation/__pycache__/opt_sparse_gen_tp.cpython-39.pyc ADDED
Binary file (3.44 kB). View file
 
HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-310.pyc ADDED
Binary file (3.36 kB). View file
 
HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-39.pyc ADDED
Binary file (4.56 kB). View file
 
HybridTensor/benchmarks/generation/gen_util.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ from datasets import load_dataset
4
+ from transformers import AutoTokenizer
5
+
6
def tokenize_dataset(dataset, tokenizer):
    """Tokenize every example's "text" field and flatten into one token stream.

    No special tokens are inserted, so the result is a single contiguous
    corpus suitable for slicing fixed-length windows out of.

    Args:
        dataset: iterable of examples, each a mapping with a "text" key.
        tokenizer: HF-style callable returning a dict with an "input_ids" list.

    Returns:
        Flat list of token ids, in dataset order.
    """
    return [
        token_id
        for example in dataset
        for token_id in tokenizer(example["text"], add_special_tokens=False)["input_ids"]
    ]
13
+
14
def get_random_batch(tokens, batch_size, seq_length):
    """Sample a batch of random contiguous token windows from a flat corpus.

    Args:
        tokens: flat sequence of token ids (e.g. output of tokenize_dataset).
        batch_size: number of windows to sample (windows may overlap).
        seq_length: length of each window.

    Returns:
        torch.LongTensor of shape (batch_size, seq_length).

    Raises:
        ValueError: if the corpus is shorter than seq_length.
    """
    total = len(tokens)
    if total < seq_length:
        # Previously this fell through to random.randint(0, <negative>) and
        # raised an opaque "empty range" ValueError from the random module.
        raise ValueError(
            f"Corpus has {total} tokens but seq_length={seq_length}; "
            "need at least seq_length tokens to sample a window."
        )
    batch = []
    for _ in range(batch_size):
        # randint is inclusive on both ends, so start + seq_length <= total.
        start = random.randint(0, total - seq_length)
        batch.append(tokens[start : start + seq_length])
    return torch.tensor(batch)
21
+
22
+ '''
23
+ # Load dataset and tokenizer
24
+ dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
25
+ model_name = "facebook/opt-6.7b"
26
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
27
+
28
+ tokens = tokenize_dataset(dataset, tokenizer)
29
+
30
+ # Define parameters
31
+ batch_size = 8
32
+ seq_length = 2000
33
+
34
+ random_batch = get_random_batch(tokens, batch_size, seq_length)
35
+ print("Batch shape:", random_batch.shape) # Expected: (8, 128)
36
+
37
+ '''
38
+
HybridTensor/benchmarks/generation/model_sparse_generation.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+
4
+ from HybridTensor.utils.utils import _get_device
5
+ from HybridTensor.utils.activations import MODELS
6
+ from HybridTensor.models.opt import build_sparse_opt
7
+ from HybridTensor.models.llama import build_sparse_llama
8
+ from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv
9
+ from transformers import AutoTokenizer
10
+
11
+
12
def update_router_config(model, num_layers, mlp_topk_lookup, attn_topk):
    """Apply router settings to the first `num_layers` transformer layers.

    Every layer gets the same attention top-k; when `mlp_topk_lookup` is
    provided it supplies a per-layer MLP top-k. Layer 0 is forced back to
    dense attention (topk = 1.0) at the end.
    """
    layers = model.transformer.layers
    for layer_idx in range(num_layers):
        layer = layers[layer_idx]
        if mlp_topk_lookup is not None:
            layer.mlp_topk = mlp_topk_lookup[layer_idx]
        layer.mha_router.topk = attn_topk
    # dense attention in layer 0
    layers[0].mha_router.topk = 1.0
21
+
22
def _str2bool(value):
    """Parse textual boolean spellings for argparse flags.

    `type=bool` is an argparse pitfall: bool("False") is True, so any
    non-empty string turns the flag on. This helper parses the text properly.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def arg_parser(argv=None):
    """Parse inference-benchmarking CLI arguments.

    Args:
        argv: optional list of argument strings; defaults to sys.argv[1:]
            (passing a list makes the parser unit-testable).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--model_index', type=int, default=5)
    # FIX: was `type=bool`, which treats any non-empty string (even "False") as True.
    parser.add_argument('--print_results', type=_str2bool, default=True)
    parser.add_argument('--iterations', type=int, default=1)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model')
    parser.add_argument('--mlp_ckpt_dir', type=str, default='/home/grads/s/<name>/nvme/HybridTensor/checkpoint/opt-6.7b-routers/mlp')
    parser.add_argument('--attn_ckpt_dir', type=str, default='/home/grads/s/<name>/nvme/HybridTensor/checkpoint/opt-6.7b-routers/mha_linear')
    parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router/opt-6.7b')
    parser.add_argument('--delta', type=int, default=256, help='Delta value for MLP topk calculation')
    parser.add_argument('--use_cuda_graph', type=_str2bool, default=False, help='Use CUDA graph for inference')

    return parser.parse_args(argv)
37
+
38
if __name__ == "__main__":
    args = arg_parser()
    # MODELS is 1-indexed from the CLI's point of view.
    model_name = MODELS[args.model_index-1]
    print(f"Model name: {model_name}")
    dtype = torch.float16
    device= _get_device(args.gpu)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if "llama" in model_name:
        # LLaMA build takes only the attention-router checkpoints; no MLP
        # top-k lookup is used (None below).
        model = build_sparse_llama(args, model_name,
                                args.attn_ckpt_dir,
                                device = device, dtype=dtype)
        update_router_config(model, model.config.n_layer, None, args.attn_topk)   # this sets the router config for all layers using a single config
    else:
        # OPT build: per-layer MLP top-k values come from batch-statistics CSVs
        # keyed by batch size.
        mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, args.batch_size)
        model = build_sparse_opt(args, model_name,
                                args.mlp_ckpt_dir,
                                args.attn_ckpt_dir,
                                device = device, dtype=dtype)
        update_router_config(model, model.config.n_layer, mlp_topk_lookup, args.attn_topk)   # this sets the router config for all layers using a single config

    model.eval()
    print(model)

    # test input
    input_text = "Once upon a time in a land far, far away, there lived a"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    # Generate output (greedy/default sampling; smoke test of the sparse model)
    with torch.no_grad():
        output = model.generate(input_ids, max_length=50)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
HybridTensor/benchmarks/generation/opt_gen_tp.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.models.opt import OPTConfig
2
+ from transformers import AutoTokenizer
3
+ from flash_attn.models.opt import opt_config_to_gpt2_config
4
+
5
+ import os
6
+ import torch
7
+ import argparse
8
+ from apex.transformer import parallel_state
9
+
10
+ from HybridTensor.utils.utils import arg_parser, _get_device
11
+ from HybridTensor.utils.activations import OPT_MODELS
12
+ from HybridTensor.models.opt import SparseConfig, build_sparse_opt, build_dense_opt
13
+
14
+
15
def initialize_distributed_environment():
    """Initialize torch.distributed with the NCCL backend for tensor parallelism.

    Uses the env:// init method, so RANK / WORLD_SIZE / MASTER_ADDR /
    MASTER_PORT are expected to be provided by the launcher (e.g. torchrun).

    Returns:
        (device, world_size): the CUDA device string for this rank and the
        total number of participating processes.
    """
    # Set environment variables for NCCL
    # (presumably disabled to keep NCCL compatible with CUDA-graph capture — TODO confirm)
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
    os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0"

    # Initialize the distributed process group
    torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # Set the device based on the rank of the current process
    device = f"cuda:{torch.distributed.get_rank()}"
    world_size = torch.distributed.get_world_size()

    # Set the current CUDA device to avoid operations being executed on the wrong GPU
    torch.cuda.set_device(device)

    # You can return device, world_size, and any other relevant information
    return device, world_size
32
+
33
+ def _turn_bias_off(model, num_layers):
34
+ for i in range(num_layers):
35
+ model.transformer.layers[i].mlp.fc1.bias = None
36
+ model.transformer.layers[i].mlp.fc2.bias = None
37
+
38
def _str2bool(value):
    """Parse textual boolean spellings for argparse flags.

    `type=bool` is an argparse pitfall: bool("False") is True, so any
    non-empty string turns the flag on. This helper parses the text properly.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# NOTE(review): this redefinition intentionally shadows the `arg_parser`
# imported from HybridTensor.utils.utils above — confirm that is desired.
def arg_parser(argv=None):
    """Parse inference-benchmarking CLI arguments for the tensor-parallel run.

    Args:
        argv: optional list of argument strings; defaults to sys.argv[1:]
            (passing a list makes the parser unit-testable).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--seq_len', type=int, default=25)
    parser.add_argument('--index_size', type=int, default=8192)
    parser.add_argument('--head_density', type=float, default=0.25)
    # FIX: these were `type=bool`, which treats any non-empty string (even "False") as True.
    parser.add_argument('--print_results', type=_str2bool, default=True)
    parser.add_argument('--iterations', type=int, default=2)
    parser.add_argument('--check_results', type=_str2bool, default=False)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--bias', type=_str2bool, default=False)
    parser.add_argument('--mlp_ckpt_dir', type=str, default='/home/grads/s/<name>/nvme/HybridTensor/checkpoint/opt-6.7b-routers/mlp')
    parser.add_argument('--attn_ckpt_dir', type=str, default='/home/grads/s/<name>/nvme/HybridTensor/checkpoint/opt-6.7b-routers/mha_linear')

    return parser.parse_args(argv)
55
+
56
if __name__ == "__main__":
    # Tensor-parallel dense OPT generation benchmark. One process per GPU;
    # launch with torchrun so env:// initialization finds RANK/WORLD_SIZE.
    args = arg_parser()
    model_name = OPT_MODELS[args.model_index-1]

    device, world_size = initialize_distributed_environment()
    dtype = torch.float16

    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # model = build_sparse_opt(model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device = device, dtype=dtype, process_group = process_group, world_size = world_size, rank = rank)
    model = build_dense_opt(model_name, process_group = process_group, world_size = world_size, rank = rank, device = device, dtype=dtype)
    model.eval()
    # if rank == 0:
    #     print(model)

    # input_texts = ["Hello, my dog is cute and", "The future of AI is", "In a distant galaxy, a spaceship", "The cat is sleeping on the "]
    input_texts = ["In a distant galaxy, a spaceship"]
    tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(device)
    input_ids=tokenized_inputs["input_ids"]
    # input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to(device=device)

    max_length = args.seq_len
    position_ids = None
    eos_token_id = tokenizer.eos_token_id
    num_layers = model.config.n_layer

    # turn bias off for mlp layers
    if not args.bias:
        _turn_bias_off(model, num_layers)

    # Warm-up generation (excluded from timing).
    _ = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
    )

    # Time `iterations` full generations with CUDA events.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()

    for _ in range(args.iterations):
        out = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            eos_token_id=eos_token_id,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
        )

    end_event.record()

    torch.cuda.synchronize()

    # print(tokenizer.batch_decode(out.sequences.tolist()))

    if rank == 0:
        elapsed_time = start_event.elapsed_time(end_event) / args.iterations
        # FIX: message typo ("genearation") corrected.
        print(f"Average time per generation: {elapsed_time} ms")

        # Compute throughput and latency per token
        num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
        throughput = num_tokens_generated / (elapsed_time / 1000)  # tokens per second
        latency_per_token = elapsed_time / num_tokens_generated  # ms per token

        print(f"Number of tokens generated: {num_tokens_generated}")
        print(f"Throughput: {throughput} tokens/second")
        print(f"Latency per token: {latency_per_token} ms")
        print(tokenizer.batch_decode(out.sequences.tolist()))
133
+
HybridTensor/benchmarks/generation/opt_generation.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import time
3
+
4
+ import pytest
5
+ import torch
6
+ import argparse
7
+
8
+ from einops import rearrange
9
+
10
+ from HybridTensor.benchmarks.generation.gen_util import tokenize_dataset, get_random_batch
11
+ from HybridTensor.utils.activations import OPT_MODELS
12
+ from datasets import load_dataset
13
+
14
+ from flash_attn.models.gpt import GPTLMHeadModel
15
+ from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
16
+ from flash_attn.utils.generation import update_graph_cache
17
+ from flash_attn.utils.pretrained import state_dict_from_pretrained
18
+ from transformers import AutoTokenizer, OPTConfig
19
+ from transformers.models.opt.modeling_opt import OPTForCausalLM
20
+
21
def test_opt_generation(model_name):
    """Check that our implementation of OPT generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.

    Compares four decoders on the same prompt: a token-by-token "slow" loop as
    ground truth, the flash-attn fast path (with and without CUDA graph), and
    the HuggingFace model in fp16 and fp32.
    NOTE(review): requires a CUDA device; `device` is hard-coded to "cuda".
    """
    print(f"\nMODEL: {model_name}")
    verbose = False
    dtype = torch.float16
    device = "cuda"
    # Tolerances for comparing fp16 scores against the slow-path reference.
    rtol, atol = 3e-3, 3e-1
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    # OPT tokenizer requires use_fast=False
    # https://huggingface.co/docs/transformers/model_doc/opt
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id

    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference: greedy decode one token at a time by
    # re-running the full forward pass on the growing sequence.
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
            # Stop early once every sequence in the batch has emitted EOS.
            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
                break
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    if verbose:
        print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
    if getattr(config, "use_flash_attn", False):
        # Capture graph outside the timing loop
        batch_size, seqlen_og = input_ids.shape
        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
        print("With CUDA graph")
        torch.cuda.synchronize()
        start = time.time()
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=True,
        )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        if verbose:
            print(out_cg.sequences)
        print(tokenizer.batch_decode(out_cg.sequences.tolist()))

    # Free GPU memory before loading the HF reference models.
    del model

    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_ref.eval()
    print("HF fp32")
    torch.cuda.synchronize()
    start = time.time()
    out_ref = model_ref.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_ref
    print(tokenizer.batch_decode(out_ref.sequences.tolist()))

    if verbose:
        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )

    # Fast path must reproduce the slow-path tokens exactly and its scores
    # within tolerance; all implementations must agree on the token sequence.
    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    assert torch.all(out.sequences == out_ref.sequences)
    assert torch.all(out.sequences == out_hf.sequences)

    # Our fp16 scores must be within 3x of HF fp16's deviation from the fp32 reference.
    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
    ).abs().max().item()
159
+
160
+
161
def _str2bool(value):
    """Parse textual boolean spellings for argparse flags.

    `type=bool` is an argparse pitfall: bool("False") is True, so any
    non-empty string turns the flag on. This helper parses the text properly.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def arg_parser(argv=None):
    """Parse inference-benchmarking CLI arguments.

    Args:
        argv: optional list of argument strings; defaults to sys.argv[1:]
            (passing a list makes the parser unit-testable).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--seq_len', type=int, default=1024)
    parser.add_argument('--index_size', type=int, default=8192)
    parser.add_argument('--head_density', type=float, default=0.25)
    # FIX: these were `type=bool`, which treats any non-empty string (even "False") as True.
    parser.add_argument('--print_results', type=_str2bool, default=False)
    parser.add_argument('--iterations', type=int, default=1)
    parser.add_argument('--check_results', type=_str2bool, default=False)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--gpu', type=int, default=0)

    return parser.parse_args(argv)
175
+
176
if __name__ == "__main__":
    # Dense OPT generation benchmark on random wikitext windows, timed with
    # and without CUDA-graph decoding.
    args = arg_parser()
    model_name = OPT_MODELS[args.model_index-1]
    # test_opt_generation(model_name)

    print(f"\nMODEL: {model_name}\n")
    verbose = False
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    # OPT tokenizer requires use_fast=False
    # https://huggingface.co/docs/transformers/model_doc/opt
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id

    # input_ids = tokenizer("In a distant galaxy, a spaceship", return_tensors="pt").input_ids.to(
    #     device=device
    # )
    # Build a random batch of prompts from the wikitext test split.
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    tokens = tokenize_dataset(dataset, tokenizer)
    input_ids = get_random_batch(tokens, args.batch_size, args.seq_len)
    input_ids = input_ids.to(device=device)
    max_length = args.seq_len + 20

    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # warm up
    _ = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
    )

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
    )
    torch.cuda.synchronize()
    elapsed_time = (time.time() - start) * 1000
    print(f"Prompt processing + decoding time: {elapsed_time:.0f} ms")

    # Compute throughput and latency per token
    num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
    throughput = (args.batch_size * num_tokens_generated) / (elapsed_time / 1000)
    latency_per_token = elapsed_time / num_tokens_generated  # ms per token

    # print(f"Number of tokens generated: {num_tokens_generated}")
    print(f"Throughput: {throughput:.1f} tokens/second")
    print(f"Latency per token: {latency_per_token:.1f} ms")


    if args.print_results:
        # print(out.sequences)
        print(tokenizer.batch_decode(out.sequences.tolist()))

    # ============================================================================= #

    print("\n")
    if getattr(config, "use_flash_attn", False):
        # Capture graph outside the timing loop
        batch_size, seqlen_og = input_ids.shape
        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
        print("With CUDA graph")
        torch.cuda.synchronize()
        start = time.time()
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
        )
        torch.cuda.synchronize()
        elapsed_time = (time.time() - start) * 1000
        print(f"Prompt processing + decoding time: {elapsed_time:.0f} ms")

        # Compute throughput and latency per token
        # FIX: was measuring `out` (the non-CUDA-graph run) here; the CG run
        # does not pass eos_token_id, so the two token counts can differ.
        num_tokens_generated = out_cg.sequences.shape[1] - input_ids.shape[1]
        latency_per_token = elapsed_time / num_tokens_generated  # ms per token
        throughput = (args.batch_size * num_tokens_generated) / (elapsed_time / 1000)

        # print(f"Number of tokens generated: {num_tokens_generated}")
        print(f"Throughput: {throughput:.1f} tokens/second")
        print(f"Latency per token: {latency_per_token:.1f} ms")

        if args.print_results:
            # print(out_cg.sequences)
            print(tokenizer.batch_decode(out_cg.sequences.tolist()))
HybridTensor/benchmarks/generation/opt_sparse_gen_tp.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.models.opt import OPTConfig
2
+ from transformers import AutoTokenizer
3
+ from flash_attn.models.opt import opt_config_to_gpt2_config
4
+
5
+ import os
6
+ import torch
7
+ import argparse
8
+ from apex.transformer import parallel_state
9
+
10
+ from HybridTensor.utils.utils import arg_parser, _get_device
11
+ from HybridTensor.utils.activations import OPT_MODELS
12
+ from HybridTensor.models.opt import SparseConfig, build_sparse_opt
13
+
14
def update_router_config(model, num_layers, mlp_act_th, attn_topk, layer_config = None):
    """Apply one MLP activation threshold and attention top-k to the first
    `num_layers` transformer layers. (`layer_config` is accepted but unused.)"""
    for layer in model.transformer.layers[:num_layers]:
        layer.mlp_router.act_th = mlp_act_th
        layer.mha_router.topk = attn_topk
18
+
19
def initialize_distributed_environment():
    """Initialize torch.distributed with the NCCL backend for tensor parallelism.

    Uses the env:// init method, so RANK / WORLD_SIZE / MASTER_ADDR /
    MASTER_PORT are expected to be provided by the launcher (e.g. torchrun).

    Returns:
        (device, world_size): the CUDA device string for this rank and the
        total number of participating processes.
    """
    # Set environment variables for NCCL
    # (presumably disabled to keep NCCL compatible with CUDA-graph capture — TODO confirm)
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
    os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0"

    # Initialize the distributed process group
    torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # Set the device based on the rank of the current process
    device = f"cuda:{torch.distributed.get_rank()}"
    world_size = torch.distributed.get_world_size()

    # Set the current CUDA device to avoid operations being executed on the wrong GPU
    torch.cuda.set_device(device)

    # You can return device, world_size, and any other relevant information
    return device, world_size
36
+
37
+
38
def _str2bool(value):
    """Parse textual boolean spellings for argparse flags.

    `type=bool` is an argparse pitfall: bool("False") is True, so any
    non-empty string turns the flag on. This helper parses the text properly.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def arg_parser(argv=None):
    """Parse inference-benchmarking CLI arguments for the sparse TP run.

    Args:
        argv: optional list of argument strings; defaults to sys.argv[1:]
            (passing a list makes the parser unit-testable).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--seq_len', type=int, default=28)
    parser.add_argument('--index_size', type=int, default=8192)
    parser.add_argument('--head_density', type=float, default=0.25)
    # FIX: these were `type=bool`, which treats any non-empty string (even "False") as True.
    parser.add_argument('--print_results', type=_str2bool, default=True)
    parser.add_argument('--iterations', type=int, default=100)
    parser.add_argument('--check_results', type=_str2bool, default=False)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--mlp_ckpt_dir', type=str, default='<PATH_TO_MLP_ROUTER_CHECKPOINTS>')
    parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model')
    parser.add_argument('--attn_ckpt_dir', type=str, default='<PATH_TO_ATTENTION_CHECKPOINTS>')

    return parser.parse_args(argv)
55
+
56
if __name__ == "__main__":
    # Sparse OPT generation smoke test under tensor parallelism. One process
    # per GPU; launch with torchrun so env:// initialization finds RANK/WORLD_SIZE.
    args = arg_parser()
    model_name = OPT_MODELS[args.model_index-1]

    device, world_size = initialize_distributed_environment()
    dtype = torch.float16

    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = build_sparse_opt(model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device = device, dtype=dtype, process_group = process_group, world_size = world_size, rank = rank)
    model.eval()
    print("Model loaded with sparse routers")

    # Uniform router settings for all layers (hard-coded for this smoke test;
    # note args.attn_topk is NOT used here).
    mlp_act_th = 0.5
    attn_topk = 0.5

    update_router_config(model, model.config.n_layer, mlp_act_th, attn_topk)
    print("Router config updated")

    # print router configs from all layers
    # for i in range(model.config.n_layer):
    #     print(f"Layer {i}: mlp_act_th = {model.transformer.layers[i].mlp_router.act_th}, attn_topk = {model.transformer.layers[i].mha_router.topk}")

    input_texts = ["Hello, my dog is cute and", "The future of AI is", "In a distant galaxy, a spaceship", "The cat is sleeping on the "]
    # input_texts = ["Hello, my dog is cute and", "Hello, my rat is cute and"]

    tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(device)
    input_ids=tokenized_inputs["input_ids"]

    # input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to(device=device)
    max_length = args.seq_len
    position_ids = None
    eos_token_id = tokenizer.eos_token_id
    num_layers = model.config.n_layer

    # print all the model weights and check the accuracy
    # if rank == 0:
    #     print(model.state_dict())

    # out = model(input_ids)
    # print(out)

    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    # Only rank 0 prints, to avoid duplicated output across TP ranks.
    if rank == 0:
        print(tokenizer.batch_decode(out.sequences.tolist()))
112
+
HybridTensor/benchmarks/generation/opt_sparse_generation.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+
4
+ from HybridTensor.utils.utils import _get_device
5
+ from HybridTensor.utils.activations import OPT_MODELS
6
+ from HybridTensor.models.opt import SparseConfig, build_sparse_opt
7
+ from HybridTensor.benchmarks.generation.gen_util import tokenize_dataset, get_random_batch
8
+ from HybridTensor.utils.activations import build_mlp_topk_lookup
9
+ from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv
10
+
11
+ from datasets import load_dataset
12
+
13
+ from transformers.models.opt import OPTConfig
14
+ from transformers import AutoTokenizer
15
+ from flash_attn.models.opt import opt_config_to_gpt2_config
16
+ from flash_attn.utils.generation import update_graph_cache
17
+
18
+ def arg_parser():
19
+ parser = argparse.ArgumentParser(description='Inference benchmarking')
20
+ parser.add_argument('--batch_size', type=int, default=16)
21
+ parser.add_argument('--model_index', type=int, default=5)
22
+ parser.add_argument('--seq_len', type=int, default=1024)
23
+ parser.add_argument('--index_size', type=int, default=8192)
24
+ parser.add_argument('--head_density', type=float, default=0.5)
25
+ parser.add_argument('--print_results', type=bool, default=True)
26
+ parser.add_argument('--iterations', type=int, default=1)
27
+ parser.add_argument('--check_results', type=bool, default=False)
28
+ parser.add_argument('--results_dir', type=str, default='results')
29
+ parser.add_argument('--gpu', type=int, default=0)
30
+ parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model')
31
+ parser.add_argument('--mlp_ckpt_dir', type=str, default='<PATH_TO_MLP_ROUTER_CHECKPOINTS>')
32
+ parser.add_argument('--attn_ckpt_dir', type=str, default='<PATH_TO_ATTENTION_CHECKPOINTS>')
33
+ parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router/opt-6.7b')
34
+ parser.add_argument('--delta', type=int, default=256, help='Delta value for MLP topk calculation')
35
+ parser.add_argument('--use_cuda_graph', type=bool, default=False, help='Use CUDA graph for inference')
36
+
37
+ return parser.parse_args()
38
+
39
+ def update_router_config(model, num_layers, mlp_topk_lookup, attn_topk):
40
+ for i in range(num_layers):
41
+ model.transformer.layers[i].mlp_topk = mlp_topk_lookup[i]
42
+ # model.transformer.layers[i].mlp_topk = 512
43
+ model.transformer.layers[i].mha_router.topk = attn_topk
44
+
45
+ # model.transformer.layers[i].skip_mlp_router = True
46
+ model.transformer.layers[0].mha_router.topk = 1.0 # dense attention in layer 0
47
+
48
+ if __name__ == "__main__":
49
+ args = arg_parser()
50
+ model_name = OPT_MODELS[args.model_index-1]
51
+ dtype = torch.float16
52
+ device= _get_device(args.gpu)
53
+
54
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
55
+
56
+ # args.mlp_ckpt_dir = None
57
+ # args.attn_ckpt_dir = None
58
+
59
+ model = build_sparse_opt(args, model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device = device, dtype=dtype)
60
+ model.eval()
61
+ print(model)
62
+ print("Model loaded with sparse routers")
63
+
64
+ # mlp_topk_lookup = build_mlp_topk_lookup("results/mlp_results/batch_activations/opt-6.7b", args.batch_size, args.delta)
65
+ mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, args.batch_size)
66
+ print("MLP topk values updated: ", mlp_topk_lookup)
67
+ update_router_config(model, model.config.n_layer, mlp_topk_lookup, args.attn_topk) # this sets the router config for all layers using a single config
68
+ # update_router_config(model, model.config.n_layer, 2048, args.attn_topk)
69
+ print("Router config updated \n")
70
+
71
+
72
+ max_length = args.seq_len + 20
73
+ batch_size = args.batch_size
74
+
75
+ # input_texts = ["Hello, my dog is cute and", "The future of AI is", "In a distant galaxy, a spaceship", "The cat is sleeping on the "]
76
+ # input_texts = ["In a distant galaxy, a spaceship"]
77
+ # tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=False).to(device)
78
+ # input_ids=tokenized_inputs["input_ids"]
79
+
80
+ dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
81
+ tokens = tokenize_dataset(dataset, tokenizer)
82
+ input_ids = get_random_batch(tokens, args.batch_size, args.seq_len).to(device)
83
+
84
+ print("Input ids generated, starting inference")
85
+
86
+ # input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(device)
87
+ position_ids = None
88
+ eos_token_id = tokenizer.eos_token_id
89
+
90
+ start_event = torch.cuda.Event(enable_timing=True)
91
+ end_event = torch.cuda.Event(enable_timing=True)
92
+
93
+ with torch.no_grad():
94
+ # warm up
95
+ _ = model.generate(
96
+ input_ids=input_ids,
97
+ max_length=max_length,
98
+ eos_token_id=eos_token_id,
99
+ return_dict_in_generate=True,
100
+ output_scores=True,
101
+ enable_timing=False,
102
+ cg=False,
103
+ )
104
+
105
+ print("Warm up done")
106
+
107
+ start_event.record()
108
+ for i in range(args.iterations):
109
+ out = model.generate(
110
+ input_ids=input_ids,
111
+ max_length=max_length,
112
+ eos_token_id=eos_token_id,
113
+ return_dict_in_generate=True,
114
+ output_scores=True,
115
+ enable_timing=False,
116
+ cg=False,
117
+ )
118
+
119
+ end_event.record()
120
+
121
+ torch.cuda.synchronize()
122
+ print("Without CUDA graph")
123
+ elapsed_time = start_event.elapsed_time(end_event) / args.iterations
124
+ print(f"Average time per genearation : {elapsed_time:.1f} ms")
125
+
126
+ # Compute throughput and latency per token
127
+ num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
128
+ throughput = batch_size * num_tokens_generated / (elapsed_time / 1000) # tokens per second
129
+ latency_per_token = elapsed_time / num_tokens_generated # ms per token
130
+
131
+ print(f"Number of tokens generated: {num_tokens_generated}")
132
+ print(f"Throughput: {throughput:.1f} tokens/second")
133
+ print(f"Latency per token: {latency_per_token:.1f} ms")
134
+
135
+ # print(tokenizer.batch_decode(out.sequences.tolist()))
136
+ print("\n")
137
+
138
+ # print only the new tokens generated
139
+ print("New tokens generated:")
140
+ print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist()))
141
+
142
+ # ====================== With CUDA graph ======================
143
+ if args.use_cuda_graph:
144
+ batch_size, seqlen_og = input_ids.shape
145
+ model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
146
+ print("With CUDA graph")
147
+ torch.cuda.synchronize()
148
+
149
+ start_event.record()
150
+
151
+ for i in range(args.iterations):
152
+ out = model.generate(
153
+ input_ids=input_ids,
154
+ max_length=max_length,
155
+ cg=True,
156
+ eos_token_id=eos_token_id,
157
+ return_dict_in_generate=True,
158
+ output_scores=True,
159
+ enable_timing=False,
160
+ )
161
+
162
+ end_event.record()
163
+
164
+ torch.cuda.synchronize()
165
+
166
+
167
+ elapsed_time = start_event.elapsed_time(end_event) / args.iterations
168
+ print(f"Average time per genearation : {elapsed_time:.1f} ms")
169
+
170
+ # Compute throughput and latency per token
171
+ num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
172
+ throughput = batch_size * num_tokens_generated / (elapsed_time / 1000) # tokens per second
173
+ latency_per_token = elapsed_time / num_tokens_generated # ms per token
174
+
175
+ print(f"Number of tokens generated: {num_tokens_generated}")
176
+ print(f"Throughput: {throughput:.1f} tokens/second")
177
+ print(f"Latency per token: {latency_per_token:.1f} ms")
178
+
179
+ # print(tokenizer.batch_decode(out.sequences.tolist()))
180
+ print("New tokens generated:")
181
+ print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist()))
182
+
HybridTensor/benchmarks/model_eval.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ import os
4
+ import json
5
+ import logging
6
+ import numpy as np
7
+ import csv
8
+
9
+ # from hf_models.opt.modeling_opt_routers import (
10
+ # SparseOPTForCausalLM,
11
+ # create_hf_mha_router_state_dict,
12
+ # create_hf_mlp_router_state_dict
13
+ # )
14
+
15
+ from hf_models.opt.modeling_opt_routers_topk import (
16
+ SparseOPTForCausalLM,
17
+ create_hf_mha_router_state_dict,
18
+ create_hf_mlp_router_state_dict
19
+ )
20
+
21
+ from hf_models.llama.modeling_sparse_llama_routers import (
22
+ SparseLlamaForCausalLM,
23
+ create_hf_attn_router_state_dict
24
+ )
25
+
26
+ from hf_models.opt.modeling_sparse_opt_topk import SparseOPTForCausalLM as SparseOPTTopKAttn
27
+ from hf_models.llama.modeling_sparse_llama_mha_topk import SparseLlamaForCausalLM as SparseLlamaTopKAttn
28
+ from HybridTensor.benchmarks.opt_attn_sparse_topk_perplexity import _update_model_attn_thresholds
29
+ from HybridTensor.benchmarks.model_perplexity import compute_attn_layer_sparsity, compute_average_activation
30
+ from HybridTensor.utils.activations import ActivationThresholds, build_mlp_topk_lookup, _update_hf_mlp_topk, CONFIGS, MODELS
31
+ from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv
32
+ from HybridTensor.utils.utils import extract_model_name
33
+
34
+ from transformers import AutoTokenizer, AutoModelForCausalLM
35
+
36
+ from lm_eval.models.huggingface import HFLM
37
+ from lm_eval.tasks import TaskManager
38
+ import lm_eval
39
+
40
+ import pandas as pd
41
+ from tabulate import tabulate
42
+
43
+
44
+ import logging
45
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
46
+
47
+ import warnings
48
+ warnings.simplefilter(action='ignore', category=FutureWarning)
49
+
50
+
51
+ from huggingface_hub import login
52
+
53
+ def read_and_print_results(filepath='results.csv'):
54
+ """
55
+ Reads the CSV file containing evaluation results and prints them in a formatted table.
56
+ """
57
+ if not os.path.exists(filepath):
58
+ print(f"File '{filepath}' not found.")
59
+ return
60
+
61
+ df = pd.read_csv(filepath)
62
+ print(tabulate(df, headers='keys', tablefmt='psql', showindex=False))
63
+
64
+ def save_results_to_csv(results, attn_topk, filepath='eval_results.csv'):
65
+ """
66
+ Extracts benchmark accuracies from results and saves them along with the attn_topk config.
67
+
68
+ Parameters:
69
+ results: dict, evaluation results with structure results['results'][<benchmark>]['acc,none']
70
+ attn_topk: float, the attention top-k value used for this run
71
+ filepath: str, CSV file to write to (appends if it exists)
72
+ """
73
+ # Build a dictionary row with attn_topk and each benchmark's accuracy
74
+ row = {'attn_topk': attn_topk}
75
+ for benchmark, data in results['results'].items():
76
+ # Default to None if the key is missing
77
+ row[benchmark] = data.get('acc,none', None)
78
+
79
+ # Check if file exists to decide on writing header
80
+ file_exists = os.path.isfile(filepath)
81
+ with open(filepath, 'a', newline='') as csvfile:
82
+ writer = csv.DictWriter(csvfile, fieldnames=row.keys())
83
+ if not file_exists:
84
+ writer.writeheader()
85
+ writer.writerow(row)
86
+
87
+ def _update_model_attn_sparsity(model, attn_th):
88
+ num_layers = model.config.num_hidden_layers
89
+
90
+ # Use the 'decoder' attribute if it exists; otherwise use model.model.layers
91
+ layers = model.model.decoder.layers if hasattr(model.model, 'decoder') else model.model.layers
92
+ attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=attn_th)
93
+
94
+ for i in range(num_layers):
95
+ layers[i].self_attn.sp_threshold = attn_sparsity_map[i]
96
+
97
+ average_act = compute_average_activation(attn_sparsity_map)
98
+ print(f"Attention sparsity {attn_th}: {attn_sparsity_map}")
99
+ print(f"Average activation: {average_act:.2f}")
100
+
101
+ return model
102
+
103
+ def _evaluate_model(model, tokenizer, benchmarks: list, device: str, batch_size: int = 8):
104
+ logging.info("Evaluating on benchmarks: %s", benchmarks)
105
+ lm_obj = HFLM(
106
+ pretrained=model,
107
+ tokenizer=tokenizer,
108
+ device=device,
109
+ batch_size=batch_size
110
+ )
111
+ task_manager = TaskManager()
112
+ num_fewshot = 5
113
+ print(f"Number of fewshot examples: {num_fewshot}")
114
+ results = lm_eval.simple_evaluate(
115
+ model=lm_obj,
116
+ tasks=benchmarks,
117
+ num_fewshot=num_fewshot, # change this
118
+ task_manager=task_manager
119
+ )
120
+ logging.info("Evaluation complete.")
121
+ for benchmark, benchmark_results in results['results'].items():
122
+ logging.info("Results for %s: %s", benchmark.upper(), benchmark_results)
123
+ return results
124
+
125
+ def _load_model(model_name, num_layers, device, args):
126
+ if args.mode == 'sparse':
127
+ logging.info("Loading sparse model...")
128
+ sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th= args.attn_topk, mlp_th=args.mlp_topk)
129
+
130
+ if args.model_index <=8:
131
+ # OPT models
132
+ model = SparseOPTForCausalLM.from_pretrained(
133
+ model_name,
134
+ device_map=device,
135
+ torch_dtype=torch.float16,
136
+ sp_thresholds=sp_thresholds.activation_threshold,
137
+ mlp_thresholds=sp_thresholds.mlp_threshold,
138
+ attn_implementation="flash_attention_2"
139
+ )
140
+ logging.info("Loading router states...")
141
+ mlp_router_state = create_hf_mlp_router_state_dict(args.mlp_ckpt_dir)
142
+ mha_router_state = create_hf_mha_router_state_dict(args.attn_ckpt_dir)
143
+ model_state = model.state_dict()
144
+ model_state.update(mlp_router_state)
145
+ model_state.update(mha_router_state)
146
+ model.load_state_dict(model_state)
147
+ logging.info("Sparse model loaded with routers!")
148
+
149
+ # load topk values for mlp and attn here
150
+ # mlp_topk_lookup = build_mlp_topk_lookup(args.batch_stats_dir, args.batch_size, args.delta)
151
+ # mlp_topk_lookup = build_mlp_topk_lookup(args.batch_stats_dir, 1, args.delta)
152
+ mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, 1)
153
+
154
+ _update_hf_mlp_topk(model, mlp_topk_lookup)
155
+ # print("MLP topk values updated.")
156
+ # print("MLP topk values: ", mlp_topk_lookup)
157
+ logging.info("Using MLP topk values: %s", mlp_topk_lookup)
158
+
159
+ # print("Using delta value: ", args.delta)
160
+
161
+ # the first layer should use dense attention
162
+ model.model.decoder.layers[0].self_attn.sp_threshold = 1.0
163
+ else:
164
+ # Llama models
165
+
166
+ if not args.static_thresholds:
167
+ attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=args.attn_topk)
168
+ sp_thresholds.load_thresholds(attn_sparsity_map)
169
+ average_act = compute_average_activation(attn_sparsity_map)
170
+ print(f"Layer imporatance weights attention activations {sp_thresholds.activation_threshold}")
171
+ print(f"Average activation: {average_act:.2f}")
172
+
173
+ model = SparseLlamaForCausalLM.from_pretrained(model_name,
174
+ device_map = device,
175
+ torch_dtype=torch.float16,
176
+ sp_thresholds = sp_thresholds.activation_threshold,
177
+ attn_implementation="flash_attention_2")
178
+ logging.info("Loading router states...")
179
+ model_state = model.state_dict()
180
+ attn_router_states = create_hf_attn_router_state_dict(args.attn_ckpt_dir)
181
+ model_state.update(attn_router_states)
182
+ model.load_state_dict(model_state)
183
+ logging.info("Sparse model loaded with routers!")
184
+
185
+ # the first layer should use dense attetnion
186
+ _update_model_attn_thresholds(model, args.attn_topk)
187
+
188
+ # load topk values for mha here
189
+ # TODO: create a function to update the topk values for mha
190
+
191
+ elif args.mode == 'sparse_attn':
192
+ logging.info("Loading model with sparse attention")
193
+ sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=args.attn_topk)
194
+
195
+ if not args.static_thresholds:
196
+ attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=args.attn_topk)
197
+ sp_thresholds.load_thresholds(attn_sparsity_map)
198
+ average_act = compute_average_activation(attn_sparsity_map)
199
+ print(f"Layer imporatance weights attention activations {sp_thresholds.activation_threshold}")
200
+ print(f"Average activation: {average_act:.2f}")
201
+
202
+ if args.model_index <= 8:
203
+ # opt models
204
+ model = SparseOPTTopKAttn.from_pretrained(model_name, device_map = device, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
205
+ else:
206
+ # llama models
207
+ model = SparseLlamaTopKAttn.from_pretrained(model_name, device_map = device, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
208
+ else:
209
+ logging.info("Loading dense model...")
210
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.float16)
211
+ return model
212
+
213
+ def arg_parser():
214
+ parser = argparse.ArgumentParser(description='Inference benchmarking')
215
+ parser.add_argument('--batch_size', type=int, default=8)
216
+ parser.add_argument('--model_index', type=int, default=5)
217
+ parser.add_argument('--print_results', type=bool, default=True)
218
+ parser.add_argument('--results_dir', type=str, default='results/eval')
219
+ parser.add_argument('--device', type=int, default=100)
220
+ parser.add_argument('--mode', type=str, default='sparse', choices=['sparse', 'dense', 'sparse_attn'])
221
+ parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model')
222
+ parser.add_argument('--mlp_topk', type=int, default=2048, help='MLP topk for sparse model')
223
+ parser.add_argument('--delta', type=int, default=128, help='Delta value for MLP topk calculation')
224
+ parser.add_argument('--mlp_ckpt_dir', type=str, default='<PATH_TO_MLP_ROUTER_CHECKPOINTS>')
225
+ parser.add_argument('--attn_ckpt_dir', type=str, default='<PATH_TO_ATTENTION_CHECKPOINTS>')
226
+ parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router')
227
+ parser.add_argument('--data_collection', type=bool, default=False, help='Collect data for different activation thresholds')
228
+ parser.add_argument('--benchmark', type=str, default='all', help='Options: all, or a single benchmark name')
229
+ parser.add_argument('--note', type=str, default='', help='Note to add to the results filename')
230
+ parser.add_argument('--static_thresholds', type=bool, default=True, help='Use static thresholds for attention layers')
231
+ return parser.parse_args()
232
+
233
+ if __name__ == "__main__":
234
+ args = arg_parser()
235
+
236
+ login_token = None # insert your token here
237
+ assert login_token is not None, "Please provide a valid Hugging Face token."
238
+ login(token=login_token)
239
+
240
+ # Setup logging
241
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
242
+
243
+ model_name = MODELS[args.model_index - 1]
244
+ # print(f"Evaluating Model: {model_name}")
245
+ logging.info("Evaluating Model: %s", model_name)
246
+ logging.info("Mode: %s", args.mode)
247
+
248
+ num_layers = CONFIGS[model_name]['num_layer']
249
+ device = 'auto' if args.device == 100 else f'cuda:{args.device}'
250
+
251
+ # Load tokenizer and model
252
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
253
+ model = _load_model(model_name, num_layers, device, args)
254
+ model.eval()
255
+
256
+ # Determine benchmarks to evaluate
257
+ if args.benchmark == 'all':
258
+ benchmarks = ["piqa", "winogrande", "copa", "rte", "openbookqa", "arc_easy", "arc_challenge", "mmlu", "hellaswag"]
259
+ else:
260
+ benchmarks = [args.benchmark]
261
+
262
+ model_name_clean = extract_model_name(model_name)
263
+
264
+ if args.data_collection:
265
+ # make sure the model is not dense
266
+ assert args.mode != 'dense', "Data collection is only available for sparse models"
267
+ logging.info("Data collection mode enabled.")
268
+ if args.mode == 'sparse':
269
+ filepath = f"{args.results_dir}/eval_results_{model_name_clean}_sparse_sweep_dpsd.csv"
270
+ else: # sparse_attn
271
+ filepath = f"{args.results_dir}/eval_results_{model_name_clean}_attn_sweep_dpsd.csv"
272
+
273
+ if args.note != '':
274
+ filepath = filepath.replace('.csv', f"_{args.note}.csv")
275
+ # attn_topk_values = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] # MHA
276
+ attn_topk_values = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1]
277
+ # attn_topk_values = [7/8, 6/8, 5/8, 4/8, 3/8, 2/8, 1/8] # GQA
278
+ for attn_topk in attn_topk_values:
279
+ logging.info("Evaluating with attention top-k value: %s", attn_topk)
280
+ if args.static_thresholds:
281
+ _update_model_attn_thresholds(model, attn_topk, mode=args.mode)
282
+ else:
283
+ _update_model_attn_sparsity(model, attn_topk)
284
+
285
+ results = _evaluate_model(
286
+ model=model,
287
+ tokenizer=tokenizer,
288
+ benchmarks=benchmarks,
289
+ device=device,
290
+ batch_size=args.batch_size
291
+ )
292
+ save_results_to_csv(results, attn_topk, filepath = filepath)
293
+ else:
294
+ logging.info("Evaluating with attention top-k value: %s", args.attn_topk)
295
+ if args.mode == 'dense':
296
+ filepath = f"{args.results_dir}/eval_results_{model_name_clean}_dense.csv"
297
+ elif args.mode == 'sparse_attn':
298
+ filepath = f"{args.results_dir}/eval_results_{model_name_clean}_sparse_attn_{args.attn_topk}_dpsd.csv"
299
+ else:
300
+ filepath = f"{args.results_dir}/eval_results_{model_name_clean}_test_attn_{args.attn_topk}_dpsd.csv"
301
+ if args.note != '':
302
+ filepath = filepath.replace('.csv', f"_{args.note}.csv")
303
+ results = _evaluate_model(
304
+ model=model,
305
+ tokenizer=tokenizer,
306
+ benchmarks=benchmarks,
307
+ device=device,
308
+ batch_size=args.batch_size
309
+ )
310
+ save_results_to_csv(results, args.attn_topk, filepath = filepath)
311
+
312
+ if args.print_results:
313
+ read_and_print_results(filepath=filepath)
HybridTensor/benchmarks/model_perplexity.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python -m HybridTensor.benchmarks.model_perplexity --model_index 14 --batch_size 4 --max_length 512 --attn_th 1 --static_thresholds True
2
+
3
+ import sys
4
+ import math
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
7
+
8
+ from hf_models.opt.modeling_sparse_opt_topk import SparseOPTForCausalLM as SparseOPTTopkAttn
9
+ from hf_models.llama.modeling_sparse_llama_mha_topk import SparseLlamaForCausalLM as SparseLlamaTopKAttn
10
+ from HybridTensor.utils.activations import ActivationThresholds, identify_model_type, MODELS, CONFIGS
11
+ from HybridTensor.utils.utils import extract_model_name, compute_perplexity
12
+ import argparse
13
+ from datasets import load_dataset
14
+ import json
15
+ from torch.utils.data import DataLoader
16
+ from tqdm import tqdm
17
+ import pandas as pd
18
+
19
+
20
+ from HybridTensor.benchmarks.opt_attn_sparse_topk_perplexity import (_update_model_attn_thresholds,
21
+ build_data_loader,
22
+ compute_sparse_perplexity,
23
+ compute_perplexity_data_collection,
24
+ display_model_menu,
25
+ _interactive_mode,
26
+ arg_parser,
27
+ )
28
+
29
+
30
+ results_dir = "results/activations"
31
+
32
+ def compute_attn_layer_sparsity(model_name, min_th, critical_th, attn_sparsity):
33
+ # Get model configuration
34
+ # model_name = MODELS[model_index - 1]
35
+ model_config = CONFIGS[model_name]
36
+ num_layers = model_config['num_layer']
37
+
38
+ # Load the importance scores from the file specified in the configuration
39
+ file_path = model_config['layer_imp']
40
+ with open(file_path, 'r') as f:
41
+ attn_layer_imp = json.load(f)
42
+ layer_importance = attn_layer_imp['importance_scores']
43
+
44
+ # Classify layers as critical or sparse
45
+ critical_layers = [i for i, imp in enumerate(layer_importance) if imp >= critical_th]
46
+ sparse_layers = [i for i, imp in enumerate(layer_importance) if imp < critical_th]
47
+
48
+ # Calculate total sparse importance and the attention value
49
+ sum_sparse_importance = sum(layer_importance[i] for i in sparse_layers)
50
+ attn_val = attn_sparsity * len(sparse_layers)
51
+
52
+ # Compute the sparsity map per layer
53
+ layer_sparsity_map = {}
54
+ for layer_idx in range(num_layers):
55
+ if layer_idx in critical_layers:
56
+ layer_sparsity_map[layer_idx] = 1.0 # Fully dense for critical layers
57
+ else:
58
+ if sum_sparse_importance > 0:
59
+ raw_fraction = (layer_importance[layer_idx] / sum_sparse_importance) * attn_val
60
+ else:
61
+ raw_fraction = attn_sparsity
62
+ # Clamp the fraction between min_th and 1.0
63
+ fraction = max(raw_fraction, min_th)
64
+ fraction = min(fraction, 1.0)
65
+ layer_sparsity_map[layer_idx] = fraction
66
+
67
+ return layer_sparsity_map
68
+
69
+ def compute_average_activation(layer_sparsity_map):
70
+ """
71
+ Computes the average activation for each layer based on the sparsity map.
72
+ """
73
+ total_activation = 0.0
74
+ for layer_idx, fraction in layer_sparsity_map.items():
75
+ total_activation += fraction
76
+
77
+ average_activation = total_activation / len(layer_sparsity_map)
78
+ return average_activation
79
+
80
+ def compute_sparse_perplexity(model_name='facebook/opt-125m',
81
+ dataset_name='wikitext',
82
+ dataset_config='wikitext-2-raw-v1',
83
+ batch_size=8,
84
+ max_length=512,
85
+ attn_th=0.0,
86
+ static_thresholds=True,
87
+ device_map="cuda:0"):
88
+ # Set device
89
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
90
+ print(f'Using device: {device}')
91
+
92
+ # load the activation thresholds
93
+ num_layers = CONFIGS[model_name]['num_layer']
94
+ sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=attn_th)
95
+
96
+ print(f"Static attention activations: {sp_thresholds.activation_threshold}")
97
+ if not static_thresholds:
98
+ # act_threshold_filepath = CONFIGS[model_name]['sp_config']
99
+ attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=attn_th)
100
+ sp_thresholds.load_thresholds(attn_sparsity_map)
101
+ average_act = compute_average_activation(attn_sparsity_map)
102
+ print(f"Layer imporatance weights attention activations {sp_thresholds.activation_threshold}")
103
+ print(f"Average activation: {average_act:.2f}")
104
+
105
+ # Load tokenizer and model
106
+ # tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
107
+ model_type = identify_model_type(model_name)
108
+ if model_type == 'OPT':
109
+ print(f"Loading OPT model: {model_name}")
110
+ model = SparseOPTTopkAttn.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
111
+ elif model_type == 'Llama':
112
+ print(f"Loading Llama model: {model_name}")
113
+ model = SparseLlamaTopKAttn.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
114
+ model.eval()
115
+
116
+ # # Load dataset
117
+ dataloader = build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length)
118
+ perplexity = compute_perplexity(model, dataloader, device)
119
+ return perplexity
120
+
121
+
122
+ def arg_parser():
123
+ parser = argparse.ArgumentParser(description='Sparse Perplexity Evaluation')
124
+ parser.add_argument('--model_index', type=int, default=5, help='Index of the model to evaluate')
125
+ parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation')
126
+ parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length')
127
+ parser.add_argument('--attn_th', type=float, default=0.0, help='Activation threshold for attention layers')
128
+ parser.add_argument('--data_collection', type=bool, default=False, help='Collect data for different activation thresholds')
129
+ parser.add_argument('--device_map', type=str, default='auto', help='Device to use for evaluation')
130
+ parser.add_argument('--interactive', type=bool, default=False, help='Interactive mode for model selection')
131
+ parser.add_argument('--static_thresholds', type=bool, default=False, help='Use static thresholds for attention layers')
132
+
133
+ return parser.parse_args()
134
+
135
+ def main():
136
+ """
137
+ Main function to execute the perplexity computation with user-selected OPT model.
138
+ """
139
+ print("=== OPT Models Perplexity Evaluation ===\n")
140
+ args = arg_parser()
141
+
142
+ if args.interactive:
143
+ selected_model, batch_size, max_length, data_collection, device_map, attn_th = _interactive_mode()
144
+
145
+ else:
146
+ selected_model, batch_size, max_length, data_collection, device_map, attn_th = MODELS[args.model_index-1], args.batch_size, args.max_length, args.data_collection, args.device_map, args.attn_th
147
+ print(f"Selected model: {selected_model}, batch size: {batch_size}, max length: {max_length}, attn_th: {attn_th}, data_collection: {data_collection}, device: {device_map}")
148
+
149
+ if data_collection:
150
+ print("\nStarting data collection...\n")
151
+ compute_perplexity_data_collection(model_name=selected_model, batch_size=batch_size, max_length=max_length, device_map=device_map)
152
+ print("\nData collection complete.\n")
153
+
154
+ else:
155
+ print("\nStarting perplexity computation...\n")
156
+ perplexity = compute_sparse_perplexity(model_name=selected_model, batch_size=batch_size, max_length=max_length,
157
+ attn_th=attn_th,
158
+ device_map=device_map,
159
+ static_thresholds=args.static_thresholds)
160
+ print(f"\n=== Perplexity Results ===")
161
+ print(f"Model: {selected_model}")
162
+ print(f"Perplexity: {perplexity:.2f}\n")
163
+
164
+ if __name__ == "__main__":
165
+ main()
HybridTensor/benchmarks/opt_attn_sparse_topk_perplexity.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import math
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
6
+ # from hf_models.opt.modeling_sparse_opt import SparseOPTForCausalLM
7
+ from hf_models.opt.modeling_sparse_opt_topk import SparseOPTForCausalLM
8
+ from HybridTensor.utils.activations import ActivationThresholds, MODELS, CONFIGS
9
+ from HybridTensor.utils.utils import extract_model_name, compute_perplexity
10
+
11
+ import argparse
12
+ from datasets import load_dataset
13
+
14
+ from torch.utils.data import DataLoader
15
+ from tqdm import tqdm
16
+ import pandas as pd
17
+
18
+ results_dir = "results/activations"
19
+
20
+ def _update_model_attn_thresholds(model, attn_th, mode='sparse'):
21
+ num_layers = model.config.num_hidden_layers
22
+
23
+ # Use the 'decoder' attribute if it exists; otherwise use model.model.layers
24
+ layers = model.model.decoder.layers if hasattr(model.model, 'decoder') else model.model.layers
25
+
26
+ for i in range(num_layers):
27
+ layers[i].self_attn.sp_threshold = attn_th
28
+
29
+ # For non-sparse attention, layer 0 should use a threshold of 1.0
30
+ # if mode != 'sparse_attn':
31
+ # layers[0].self_attn.sp_threshold = 1.0
32
+ layers[0].self_attn.sp_threshold = 1.0
33
+
34
+ return model
35
+
36
+
37
def build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length, split='test'):
    """Construct a padded-batch DataLoader over a tokenized text dataset.

    Args:
        model_name: Hugging Face model id (used only for its tokenizer).
        dataset_name: Dataset name on the Hub.
        dataset_config: Dataset configuration / subset name.
        batch_size: Number of sequences per batch.
        max_length: Tokenization truncation length; examples shorter than
            this (in characters) are filtered out beforehand.
        split: Dataset split to load ('train', 'validation' or 'test').

    Returns:
        A torch DataLoader yielding dicts with 'input_ids' and
        'attention_mask' tensors, right-padded to the batch maximum.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # Some OPT tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    raw_dataset = load_dataset(dataset_name, dataset_config, split=split)
    raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) >= max_length)

    def _tokenize(batch):
        return tokenizer(batch['text'], return_special_tokens_mask=True, truncation=True, max_length=max_length)

    tokenized = raw_dataset.map(_tokenize, batched=True, remove_columns=['text'])

    def _collate(examples):
        ids = [torch.tensor(ex['input_ids']) for ex in examples]
        masks = [torch.tensor(ex['attention_mask']) for ex in examples]
        # Right-pad every sequence in the batch to the longest one.
        ids = torch.nn.utils.rnn.pad_sequence(ids, batch_first=True, padding_value=tokenizer.pad_token_id)
        masks = torch.nn.utils.rnn.pad_sequence(masks, batch_first=True, padding_value=0)
        return {'input_ids': ids, 'attention_mask': masks}

    return DataLoader(tokenized, batch_size=batch_size, shuffle=False, collate_fn=_collate)
80
+
81
def compute_sparse_perplexity(model_name='facebook/opt-125m', dataset_name='wikitext', dataset_config='wikitext-2-raw-v1', batch_size=8, max_length=512, attn_th=0.0, static_thresholds=True, device_map="cuda:0"):
    """Compute perplexity for a sparse OPT model at one attention threshold.

    Args:
        model_name: Hugging Face id of the sparse OPT checkpoint.
        dataset_name: Dataset name on the Hub.
        dataset_config: Dataset configuration / subset name.
        batch_size: Evaluation batch size.
        max_length: Maximum tokenized sequence length.
        attn_th: Uniform attention activation threshold.
        static_thresholds: When False, per-layer thresholds are loaded from
            the profile file referenced by the model's CONFIGS entry.
        device_map: Device placement passed to ``from_pretrained``.

    Returns:
        The computed perplexity (float).
    """
    # NOTE(review): `device` (cuda/cpu) is used for evaluation while
    # `device_map` controls weight placement — confirm they agree.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Per-layer activation thresholds; uniform unless a profile is loaded.
    layer_count = CONFIGS[model_name]['num_layer']
    sp_thresholds = ActivationThresholds(num_layers=layer_count, attn_th=attn_th)

    if not static_thresholds:
        act_threshold_filepath = CONFIGS[model_name]['sp_config']
        sp_thresholds.load_thresholds(act_threshold_filepath)
        print(f'Activation thresholds loaded from {act_threshold_filepath}')

    print(sp_thresholds.activation_threshold)

    model = SparseOPTForCausalLM.from_pretrained(model_name, device_map=device_map, torch_dtype=torch.float16, sp_thresholds=sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
    model.eval()

    loader = build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length)

    return compute_perplexity(model, loader, device)
109
+
110
def compute_perplexity_data_collection(model_name='facebook/opt-125m', dataset_name='wikitext', dataset_config='wikitext-2-raw-v1', batch_size=8, max_length=512, device_map="cuda:0"):
    """Sweep attention top-k thresholds and record perplexity for each.

    Loads the sparse model once, then for each threshold in the sweep
    updates every layer's attention threshold in place (layer 0 stays
    dense) and re-evaluates perplexity on the test split. Results are
    written to a CSV under ``results/activations``.

    Args:
        model_name: Hugging Face id of the sparse OPT checkpoint.
        dataset_name: Dataset name on the Hub.
        dataset_config: Dataset configuration / subset name.
        batch_size: Evaluation batch size.
        max_length: Maximum tokenized sequence length.
        device_map: Device placement passed to ``from_pretrained``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Build the evaluation DataLoader once; it is reused for every threshold.
    # (Bug fix: the dataset was previously loaded and filtered a second time
    # here and immediately discarded — build_data_loader already does both.)
    dataloader = build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length)

    attn_thresholds = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    print(f"Computing perplexity for the following attention thresholds: {attn_thresholds}")

    # Load the model once with a placeholder threshold; the per-layer values
    # are overwritten inside the sweep loop below.
    num_layers = CONFIGS[model_name]['num_layer']
    sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=0.1)
    model = SparseOPTForCausalLM.from_pretrained(model_name, device_map=device_map, torch_dtype=torch.float16, sp_thresholds=sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
    model.eval()

    results = []
    for attn_th in attn_thresholds:
        print(f'Computing perplexity for attn top k: {attn_th}')

        # Update every layer's threshold in place, then re-evaluate.
        model = _update_model_attn_thresholds(model, attn_th)
        perplexity = compute_perplexity(model, dataloader, device)
        print(f'Perplexity: {perplexity:.2f}\n')
        results.append({
            'model': model_name,
            'top_k': attn_th,
            'perplexity': perplexity
        })

    # Persist the sweep results as CSV in the results directory.
    results_df = pd.DataFrame(results)
    model_name_str = extract_model_name(model_name)
    results_df.to_csv(f'{results_dir}/sparse_perplexity_results_{model_name_str}_topk.csv', index=False)
156
+
157
+
158
def display_model_menu():
    """Print the numbered OPT model menu and return the user's selection.

    Loops until a valid index is entered; entering 'q', 'quit' or 'exit'
    terminates the program.

    Returns:
        The Hugging Face identifier of the selected model.
    """
    print("Available OPT Models:")
    for idx, model in enumerate(MODELS, 1):
        print(f"{idx}. {model}")

    while True:
        raw = input("\nEnter the number corresponding to the model you want to evaluate (e.g., 1): ")
        if raw.lower() in ['q', 'quit', 'exit']:
            print("Exiting the program.")
            sys.exit(0)
        try:
            choice = int(raw)
        except ValueError:
            print("Invalid input. Please enter a valid number.")
            continue
        if 1 <= choice <= len(MODELS):
            selected_model = MODELS[choice - 1]
            print(f"\nYou have selected: {selected_model}\n")
            return selected_model
        print(f"Please enter a number between 1 and {len(MODELS)}.")
184
+
185
+
186
def _interactive_mode():
    """Prompt the user for every evaluation setting.

    Returns:
        Tuple ``(selected_model, batch_size, max_length, data_collection,
        device_map, attn_th)``.
    """
    selected_model = display_model_menu()

    # Batch size (default 8 on empty or invalid input).
    try:
        raw_bs = input("Enter batch size (default=8): ").strip()
        batch_size = int(raw_bs) if raw_bs else 8
    except ValueError:
        print("Invalid input for batch size. Using default value of 8.")
        batch_size = 8

    max_length = 512

    # Whether to run the threshold sweep instead of a single evaluation.
    raw_dc = input("Do you want to collect data for different activation thresholds? (y/n): ").strip()
    data_collection = raw_dc.lower() == 'y'

    # Device placement (default cuda:0).
    device_map = input("Enter device (cuda:0/auto) [default=cuda:0]: ").strip()
    if not device_map:
        device_map = "cuda:0"

    # Attention threshold is only relevant for a single run.
    attn_th = 0.0
    if not data_collection:
        try:
            raw_th = input("Enter activation threshold for attention layers: ").strip()
            attn_th = float(raw_th) if raw_th else 0.0
        except ValueError:
            print("Invalid input for attention threshold. Using default value of 0.")
            attn_th = 0.0

    return selected_model, batch_size, max_length, data_collection, device_map, attn_th
223
+
224
+
225
def arg_parser(argv=None):
    """Parse command-line arguments for sparse perplexity evaluation.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv``
            (backward-compatible — existing ``arg_parser()`` calls work).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Sparse Perplexity Evaluation')
    parser.add_argument('--model_index', type=int, default=5, help='Index of the model to evaluate')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation')
    parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length')
    parser.add_argument('--attn_th', type=float, default=0.0, help='Activation threshold for attention layers')
    # Bug fix: `type=bool` makes any non-empty string truthy (even "False"),
    # so boolean switches must use store_true actions. Defaults stay False.
    parser.add_argument('--data_collection', action='store_true', help='Collect data for different activation thresholds')
    parser.add_argument('--device_map', type=str, default='cuda:0', help='Device to use for evaluation')
    parser.add_argument('--interactive', action='store_true', help='Interactive mode for model selection')

    return parser.parse_args(argv)
236
+
237
def main():
    """Run the OPT sparse-attention (top-k) perplexity evaluation CLI.

    Settings come either from an interactive prompt session or from the
    parsed command-line arguments; dispatches to either the threshold
    sweep (data collection) or a single perplexity computation.
    """
    print("=== OPT Models Perplexity Evaluation ===\n")
    args = arg_parser()

    if args.interactive:
        (selected_model, batch_size, max_length,
         data_collection, device_map, attn_th) = _interactive_mode()
    else:
        selected_model = MODELS[args.model_index - 1]
        batch_size = args.batch_size
        max_length = args.max_length
        data_collection = args.data_collection
        device_map = args.device_map
        attn_th = args.attn_th
        print(f"Selected model: {selected_model}, batch size: {batch_size}, max length: {max_length}, attn_th: {attn_th}, data_collection: {data_collection}, device: {device_map}")

    if data_collection:
        print("\nStarting data collection...\n")
        compute_perplexity_data_collection(model_name=selected_model, batch_size=batch_size, max_length=max_length, device_map=device_map)
        print("\nData collection complete.\n")
    else:
        print("\nStarting perplexity computation...\n")
        perplexity = compute_sparse_perplexity(model_name=selected_model, batch_size=batch_size, max_length=max_length, attn_th=attn_th, device_map=device_map)
        print("\n=== Perplexity Results ===")
        print(f"Model: {selected_model}")
        print(f"Perplexity: {perplexity:.2f}\n")
262
+
263
+ if __name__ == "__main__":
264
+ main()
HybridTensor/benchmarks/select_block_decode.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from tests.test_select_block import create_block, Config, SparseConfig
4
+ import csv
5
+ import time
6
+ import torch
7
+ import torch.nn as nn
8
+ from flash_attn.utils.generation import InferenceParams
9
+ from HybridTensor.utils.utils import arg_parser, _get_device, sparse_index, generate_random_BH_index, get_gpu_name
10
+ from HybridTensor.utils.profiling import cuda_profiler
11
+ import math
12
+ from tqdm import tqdm
13
+
14
def run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype):
    """Benchmark one decode step of a sparse block vs. a regular block.

    Builds both transformer-block variants, runs a prefill plus one eager
    decode step to warm them up, then captures each block's decode step in
    a CUDA graph and times graph replays.

    Args:
        args: Parsed CLI namespace; uses ``args.in_features``,
            ``args.index_size`` and ``args.batch_size``.
        batch_size: Batch size for the simulated inference.
        seq_len: Prefill sequence length.
        index_size: Number of active MLP neurons (mlp top-k) for the sparse block.
        attn_topk: Attention top-k fraction for the sparse block.
        device: Target CUDA device.
        dtype: Compute dtype for the static buffers.

    Returns:
        Tuple ``(regular_graph_time_ms, sparse_graph_time_ms, speedup)``.
    """
    config = Config()
    sp_config = SparseConfig()
    sp_config.attn_topk = attn_topk

    config.hidden_size = args.in_features
    config.num_attention_heads = args.in_features // 128  # assumes head_dim == 128
    config.use_heuristic = False  # use pre-compiled heuristic or compile a new one during runtime

    # Sparse block under test.
    sparse_block = create_block(config, sp_config, layer_idx=0, process_group=None, device=device, dtype=dtype)
    sparse_block.eval()
    sparse_block.mlp_topk = index_size

    # Dense baseline block. NOTE(review): `regular_config = config` aliases
    # the same object, so these flags are also cleared on `config` itself.
    regular_config = config
    regular_config.att_sparse = False
    regular_config.mlp_sparse = False
    regular_block = create_block(regular_config, None, layer_idx=0, process_group=None, device=device, dtype=dtype)
    regular_block.eval()

    # Inference simulation with select block.
    max_seqlen = seq_len + 16  # headroom for decode steps beyond the prefill
    max_batch_size = batch_size
    in_features = args.in_features
    head_dim = 128

    inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size)
    process_group = None
    sequence_parallel = False

    # For testing and debugging: route half of the attention heads.
    heads = config.num_attention_heads
    selected_heads = heads // 2

    # Create a static index vector (length equals total columns in B).
    total_neurons = args.in_features * 4
    test_index_vec = torch.empty((total_neurons,), device='cuda', dtype=torch.int32)
    active_indices = sparse_index(args.index_size, total_neurons)[0]
    test_index_vec[:args.index_size] = active_indices
    if args.index_size < total_neurons:
        test_index_vec[args.index_size:] = 0  # Fill the rest with dummy values.

    test_bh_idx = generate_random_BH_index(args.batch_size, heads, selected_heads)
    test_index_size = args.index_size

    mixer_kwargs = (
        {"seqlen": seq_len}
        if process_group is not None and sequence_parallel
        else {}
    )
    if inference_params is not None:
        mixer_kwargs["inference_params"] = inference_params

    with torch.no_grad():
        # --- Prefill stage ---
        original_seq = torch.randn(batch_size, seq_len, in_features, device='cuda', dtype=torch.float16)

        output_sparse = sparse_block(original_seq, mixer_kwargs=mixer_kwargs)
        output_regular = regular_block(original_seq, mixer_kwargs=mixer_kwargs)

        # Advance the KV-cache offset so the next call is treated as decode.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len

        # --- Decode stage (single token) ---
        input_x = torch.randn(batch_size, 1, in_features, device='cuda', dtype=torch.float16)

        out_decode_sparse = sparse_block(input_x, mixer_kwargs=mixer_kwargs)

        # Reset the offset so the regular block decodes at the same position.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len

        out_decode_regular = regular_block(input_x, mixer_kwargs=mixer_kwargs)

        # --- CUDA Graph Capture for Decode Stage ---
        # Graph capture/replay requires static input and output buffers.
        input_x_static = input_x.clone()
        output_regular_static = torch.empty((batch_size, 1, in_features), device=device, dtype=dtype)

        # One eager warmup call before capturing the regular block's graph.
        _ = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)
        torch.cuda.synchronize()
        graph_regular = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph_regular):
            res = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)
            if isinstance(res, tuple):
                res = res[0]
            output_regular_static.copy_(res)

        # For the sparse block, run a dummy call to determine its output shape.
        # Also reset the inference parameter to ensure consistent behavior.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        temp = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)
        if isinstance(temp, tuple):
            temp = temp[0]
        # Allocate a static buffer matching the dummy run's shape.
        output_sparse_static = torch.empty_like(temp)
        torch.cuda.synchronize()

        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        graph_sparse = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph_sparse):
            res = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)
            if isinstance(res, tuple):
                res = res[0]
            output_sparse_static.copy_(res)

        # Warm up CUDA Graph replays before timing.
        for _ in range(5):
            graph_regular.replay()
            graph_sparse.replay()
        torch.cuda.synchronize()

        # --- Measure CUDA Graph replay latency (ms per replay) ---
        num_replays = 10

        start = time.time()
        for _ in range(num_replays):
            graph_regular.replay()
        torch.cuda.synchronize()
        regular_graph_time = (time.time() - start) * 1000 / num_replays

        start = time.time()
        for _ in range(num_replays):
            graph_sparse.replay()
        torch.cuda.synchronize()
        sparse_graph_time = (time.time() - start) * 1000 / num_replays
        speedup = regular_graph_time / sparse_graph_time

    return regular_graph_time, sparse_graph_time, speedup
162
+
163
if __name__ == "__main__":
    # Sweep batch size, sequence length, MLP sparsity level and attention
    # top-k, benchmarking sparse vs. regular decode blocks, and stream each
    # result row to a CSV file.
    args = arg_parser()
    device = _get_device(0)
    dtype = torch.float16
    gpu_name = get_gpu_name()

    # Parameter grids.
    batch_sizes = [1, 8, 16, 32]
    seq_lengths = [1024, 2048]
    # MLP activation fractions; converted to absolute neuron counts below.
    index_size_p = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
    total_neurons = args.in_features * 4

    # Calculate initial index_size values.
    index_sizes = [int(total_neurons * i) for i in index_size_p]

    # Round up to the nearest multiple of 128 if necessary (kernel alignment).
    index_sizes = [math.ceil(size / 128) * 128 if size % 128 != 0 else size for size in index_sizes]

    attn_topks = [0.3, 0.4, 0.5]

    # Calculate total number of simulations.
    total_runs = len(batch_sizes) * len(seq_lengths) * len(index_sizes) * len(attn_topks)
    output_file = f"results/simulations/{gpu_name}_select_block_{args.in_features}_inference_sim.csv"

    with open(output_file, mode='w', newline='') as csv_file:
        fieldnames = ["in_features", "batch_size", "seq_len", "index_size", "neuron_activation", "attn_topk",
                      "regular_graph_time_ms", "sparse_graph_time_ms", "speedup"]
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()

        # Iterate over all combinations with a tqdm progress bar.
        for batch_size in tqdm(batch_sizes, desc="Batch Sizes"):
            for seq_len in seq_lengths:
                for index_size in index_sizes:
                    for attn_topk in attn_topks:
                        reg_time, spa_time, speedup = run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype)
                        writer.writerow({
                            "in_features": args.in_features,
                            "batch_size": batch_size,
                            "seq_len": seq_len,
                            "index_size": index_size,
                            "neuron_activation": index_size / total_neurons,
                            "attn_topk": attn_topk,
                            "regular_graph_time_ms": reg_time,
                            "sparse_graph_time_ms": spa_time,
                            "speedup": speedup
                        })
                        # Flush after each row so partial results survive crashes.
                        csv_file.flush()
    print(f"Simulation complete. Results saved to {output_file}")
HybridTensor/models/__pycache__/create_sparse_model.cpython-310.pyc ADDED
Binary file (15.3 kB). View file
 
HybridTensor/models/__pycache__/create_sparse_model.cpython-39.pyc ADDED
Binary file (18.4 kB). View file
 
HybridTensor/models/__pycache__/helper.cpython-310.pyc ADDED
Binary file (4.52 kB). View file
 
HybridTensor/models/__pycache__/helper.cpython-39.pyc ADDED
Binary file (4.65 kB). View file
 
HybridTensor/models/__pycache__/llama.cpython-39.pyc ADDED
Binary file (2.51 kB). View file
 
HybridTensor/models/__pycache__/opt.cpython-310.pyc ADDED
Binary file (4.9 kB). View file
 
HybridTensor/models/__pycache__/opt.cpython-39.pyc ADDED
Binary file (5.19 kB). View file
 
HybridTensor/models/create_sparse_model.py ADDED
@@ -0,0 +1,854 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from functools import partial
5
+
6
+ from einops import rearrange
7
+
8
+ from transformers import GPT2Config
9
+ from collections import namedtuple
10
+ from HybridTensor.modules.SelectiveMHA import SMHA, SelectMHA, ParallelSelectMHA, MHARouter, ParallelMHARouter
11
+ from HybridTensor.modules.SelectiveMLP import SelectiveMLP, ParallelSelectiveMLP, MLPRouter, ParallelMLPRouter
12
+ from HybridTensor.modules.SelectiveBlock import SelectBlock
13
+ # from HybridTensor.modules.SelectiveBlock_v1 import SelectBlock
14
+ import torch.nn.functional as F
15
+ from flash_attn.utils.distributed import (
16
+ all_gather,
17
+ all_gather_raw,
18
+ get_dim_for_local_rank,
19
+ sync_shared_params,
20
+ )
21
+
22
+ from collections.abc import Sequence
23
+ from flash_attn.modules.mha import MHA, ParallelMHA
24
+ from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP, GatedMlp, ParallelGatedMlp, Mlp, ParallelMLP
25
+ from flash_attn.ops.activations import sqrelu_fwd
26
+ from flash_attn.modules.block import Block
27
+
28
+ try:
29
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
30
+ except ImportError:
31
+ layer_norm_fn, RMSNorm = None, None
32
+
33
+ from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
34
+ from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
35
+ from flash_attn.utils.pretrained import state_dict_from_pretrained
36
+ from flash_attn.utils.generation import GenerationMixin
37
+ from flash_attn.models.opt import remap_state_dict_hf_opt
38
+
39
+ try:
40
+ from flash_attn.ops.fused_dense import ColumnParallelLinear
41
+ except ImportError:
42
+ ColumnParallelLinear = None
43
+
44
+ try:
45
+ from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
46
+ except ImportError:
47
+ FusedDenseSqreluDense = None
48
+
49
+ try:
50
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
51
+ except ImportError:
52
+ layer_norm_fn, RMSNorm = None, None
53
+
54
+ from HybridTensor.models.helper import remap_state_dict_gpt2, shard_state_dict_tp
55
+
56
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Build a partially-applied attention (mixer) class from the config.

    Selects between serial and tensor-parallel, sparse and dense MHA
    variants, and binds all attention hyperparameters into a
    ``functools.partial`` that the block constructor invokes later.

    Args:
        config: Model config (GPT2-style attribute names).
        layer_idx: Layer index; required when ``scale_attn_by_inverse_layer_idx``.
        process_group: Distributed process group (None for single device).
        device: Factory device forwarded to the mixer.
        dtype: Factory dtype forwarded to the mixer.

    Returns:
        A ``functools.partial`` wrapping the chosen MHA class.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    # muP-style scaling: exponent 1 (i.e. 1/d) instead of 1/sqrt(d) when configured.
    attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0
    softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power))
    softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, "attn_dwconv", False)
    if dwconv:
        assert process_group is None, "TensorParallel MHA does not support dwconv yet"
    qkv_proj_bias = getattr(config, "qkv_proj_bias", True)
    out_proj_bias = getattr(config, "out_proj_bias", True)
    rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim)
    rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
    rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
    rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
    use_alibi = getattr(config, "use_alibi", False)
    use_triton = getattr(config, "use_triton", True)  # toggle cuda or triton decode kernels
    window_size = getattr(config, "window_size", (-1, -1))
    use_flash_attn = getattr(config, "use_flash_attn", False)
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    if not fused_bias_fc:
        assert process_group is None, "TensorParallel MHA requires fused_bias_fc"

    mlp_sparse = getattr(config, "mlp_sparse", False)
    att_sparse = getattr(config, "att_sparse", False)
    num_heads = getattr(config, "num_attention_heads", None)
    n_head_kv = getattr(config, "n_head_kv", num_heads)

    # Sparse attention is disabled for grouped/multi-query attention
    # (num_heads != num KV heads).
    if num_heads != n_head_kv:
        att_sparse = False

    if process_group is None:
        mha_cls = SMHA  # SelectMHA if att_sparse else MHA
    else:
        mha_cls = ParallelSelectMHA if att_sparse else ParallelMHA

    serial_kwargs = (
        {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {}
    )
    parallel_kwargs = (
        {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", False),
        }
        if process_group is not None
        else {}
    )
    num_heads_kv = getattr(config, "n_head_kv", None)
    mixer_cls = partial(
        mha_cls,
        num_heads=config.num_attention_heads,
        num_heads_kv=num_heads_kv,
        qkv_proj_bias=qkv_proj_bias,
        out_proj_bias=out_proj_bias,
        dropout=config.attn_pdrop,
        softmax_scale=softmax_scale,
        causal=True,
        layer_idx=layer_idx,
        rotary_emb_dim=rotary_emb_dim,
        rotary_emb_base=rotary_emb_base,
        rotary_emb_scale_base=rotary_emb_scale_base,
        rotary_emb_interleaved=rotary_emb_interleaved,
        use_alibi=use_alibi,
        window_size=window_size,
        use_flash_attn=use_flash_attn,
        **serial_kwargs,
        **parallel_kwargs,
        **factory_kwargs,
    )
    return mixer_cls
130
+
131
def create_mlp_cls_old(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Legacy MLP class factory (fused-MLP path only).

    Chooses between sparse (Selective) and dense (Fused) MLP variants, for
    serial or tensor-parallel execution, and returns a ``functools.partial``
    with all MLP hyperparameters bound.

    Args:
        config: Model config; ``fused_mlp`` must be enabled.
        layer_idx: Layer index; needed when ``mlp_checkpoint_lvl`` is a list.
        process_group: Distributed process group (None for single device).
        device: Factory device forwarded to the MLP.
        dtype: Factory dtype forwarded to the MLP.

    Returns:
        A ``functools.partial`` wrapping the chosen MLP class.

    Raises:
        AssertionError: If ``config.fused_mlp`` is not set or the activation
            function is unsupported.
        ImportError: If the fused_dense extension is unavailable.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        assert config.activation_function in [
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
        ]
    # Only the fused-MLP path is implemented in this legacy factory.
    assert fused_mlp == True, "Not supported not fused mlp for now"

    mlp_sparse = getattr(config, "mlp_sparse", False)
    use_heuristic = getattr(config, "use_heuristic", True)

    mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
    # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
    if isinstance(mlp_checkpoint_lvl, Sequence):
        assert layer_idx is not None
        mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]

    if fused_mlp:
        if FusedMLP is None:
            raise ImportError("fused_dense is not installed")

        # All GELU variants collapse to the fused approximate GELU kernel;
        # everything else (relu/sqrelu) runs as relu.
        if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]:
            activation = "gelu_approx"
        else:
            activation = "relu"  # config.activation_function

        if process_group is None:
            mlp_cls = SelectiveMLP if mlp_sparse else FusedMLP
        else:
            mlp_cls = ParallelSelectiveMLP if mlp_sparse else ParallelFusedMLP

        parallel_kwargs = (
            {
                "process_group": process_group,
                "sequence_parallel": getattr(config, "sequence_parallel", True),
            }
            if process_group is not None
            else {}
        )

        sparsity_kwargs = (
            {
                "use_heuristic": use_heuristic,
            }
            if mlp_sparse
            else {}
        )

        mlp_cls = partial(
            mlp_cls,
            hidden_features=inner_dim,
            activation=activation,
            checkpoint_lvl=mlp_checkpoint_lvl,
            # layer_idx=layer_idx,
            **parallel_kwargs,
            **factory_kwargs,
            **sparsity_kwargs,
        )

    else:
        raise RuntimeError("MLP type not supported")
    return mlp_cls
205
+
206
+ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
207
+ """
208
+ Create an MLP class that supports both sparse MLPs (via fused mlp) and GatedMLPs.
209
+ If the activation function is one of "glu", "swiglu", or "geglu", then GatedMlp is used
210
+ (and mlp_sparse is ignored). Otherwise, fused_mlp is used to decide between sparse and
211
+ dense implementations.
212
+ """
213
+ from functools import partial
214
+ factory_kwargs = {"device": device, "dtype": dtype}
215
+ mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
216
+ mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
217
+
218
+
219
+ # Check for gated activations
220
+ if config.activation_function in ["glu", "swiglu", "geglu"]:
221
+ # For gated activations we do not support sparsity yet.
222
+ activation = (
223
+ F.sigmoid if config.activation_function == "glu"
224
+ else (F.silu if config.activation_function == "swiglu" else F.gelu)
225
+ )
226
+ mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
227
+ parallel_kwargs = (
228
+ {"process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True)}
229
+ if process_group is not None else {}
230
+ )
231
+ mlp_multiple_of = getattr(config, "mlp_multiple_of", 128)
232
+ mlp_cls = partial(
233
+ mlp_cls,
234
+ hidden_features=config.n_inner,
235
+ activation=activation,
236
+ bias1=mlp_fc1_bias,
237
+ bias2=mlp_fc2_bias,
238
+ multiple_of=mlp_multiple_of,
239
+ **parallel_kwargs,
240
+ **factory_kwargs,
241
+ )
242
+ return mlp_cls
243
+
244
+ # For non-gated activations:
245
+ fused_mlp = getattr(config, "fused_mlp", False)
246
+ fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
247
+ if fused_dense_sqrelu_dense:
248
+ assert config.activation_function == "sqrelu", (
249
+ "fused_dense_sqrelu_dense only supports approximate activation_function sqrelu"
250
+ )
251
+ assert not (fused_dense_sqrelu_dense and fused_mlp)
252
+
253
+ if fused_mlp:
254
+ # Ensure valid activation function.
255
+ assert config.activation_function in [
256
+ "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh", "relu", "sqrelu"
257
+ ]
258
+ # Support checkpoint level (possibly a list)
259
+ mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
260
+ if isinstance(mlp_checkpoint_lvl, (list, tuple)):
261
+ assert layer_idx is not None
262
+ mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
263
+ # Choose activation string.
264
+ if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]:
265
+ activation = "gelu_approx"
266
+ else:
267
+ activation = "relu"
268
+ # Determine inner dim.
269
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
270
+ mlp_sparse = getattr(config, "mlp_sparse", False)
271
+ use_heuristic = getattr(config, "use_heuristic", True)
272
+ if process_group is None:
273
+ mlp_cls = SelectiveMLP if mlp_sparse else FusedMLP
274
+ else:
275
+ mlp_cls = ParallelSelectiveMLP if mlp_sparse else ParallelFusedMLP
276
+ parallel_kwargs = (
277
+ {"process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True)}
278
+ if process_group is not None else {}
279
+ )
280
+ sparsity_kwargs = {"use_heuristic": use_heuristic} if mlp_sparse else {}
281
+ mlp_cls = partial(
282
+ mlp_cls,
283
+ hidden_features=inner_dim,
284
+ activation=activation,
285
+ checkpoint_lvl=mlp_checkpoint_lvl,
286
+ bias1=mlp_fc1_bias,
287
+ bias2=mlp_fc2_bias,
288
+ **parallel_kwargs,
289
+ **factory_kwargs,
290
+ **sparsity_kwargs,
291
+ )
292
+ return mlp_cls
293
+
294
+ elif fused_dense_sqrelu_dense:
295
+ if process_group is not None:
296
+ assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense"
297
+ assert FusedDenseSqreluDense is not None
298
+ mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
299
+ if isinstance(mlp_checkpoint_lvl, (list, tuple)):
300
+ assert layer_idx is not None
301
+ mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
302
+ mlp_cls = partial(
303
+ FusedDenseSqreluDense,
304
+ hidden_features=config.n_inner,
305
+ checkpoint_lvl=mlp_checkpoint_lvl,
306
+ **factory_kwargs,
307
+ )
308
+ return mlp_cls
309
+
310
+ else:
311
+ # Non-fused, non-sparse branch.
312
+ assert config.activation_function in [
313
+ "gelu", "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh", "relu", "sqrelu"
314
+ ]
315
+ if config.activation_function == "relu":
316
+ activation = partial(F.relu, inplace=True)
317
+ elif config.activation_function == "sqrelu":
318
+ activation = sqrelu_fwd
319
+ else:
320
+ approximate = "tanh" if config.activation_function in [
321
+ "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"
322
+ ] else "none"
323
+ activation = partial(F.gelu, approximate=approximate)
324
+ mlp_sparse = getattr(config, "mlp_sparse", False)
325
+ mlp_cls = Mlp if process_group is None else ParallelMLP
326
+ parallel_kwargs = (
327
+ {"process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True)}
328
+ if process_group is not None else {}
329
+ )
330
+ mlp_cls = partial(
331
+ mlp_cls,
332
+ hidden_features=config.n_inner,
333
+ activation=activation,
334
+ bias1=mlp_fc1_bias,
335
+ bias2=mlp_fc2_bias,
336
+ **parallel_kwargs,
337
+ **factory_kwargs,
338
+ )
339
+ return mlp_cls
340
+
341
+ def create_mlp_router_cls(config, sp_config = None, layer_idx=None, process_group=None, device=None, dtype=None):
342
+ factory_kwargs = {"device": device, "dtype": dtype}
343
+ num_neurons = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
344
+
345
+ # this can be made different per layer by adding mlp_low_rank_dim_{layer_idx} in the sp_config
346
+ low_rank_dim = getattr(sp_config, "mlp_low_rank_dim", 1024)
347
+
348
+ # per layer activation threshold
349
+ act_th = getattr(config, "mlp_act_th", 0.5)
350
+
351
+ if process_group is None:
352
+ mlp_router_cls = MLPRouter
353
+ else:
354
+ mlp_router_cls = ParallelMLPRouter
355
+
356
+ parallel_kwargs = (
357
+ {
358
+ "process_group": process_group,
359
+ "sequence_parallel": getattr(config, "sequence_parallel", True),
360
+ }
361
+ if process_group is not None
362
+ else {}
363
+ )
364
+
365
+ mlp_router_cls = partial(mlp_router_cls,
366
+ low_rank_dim = low_rank_dim,
367
+ out_dim = num_neurons,
368
+ act_th = act_th,
369
+ **parallel_kwargs,
370
+ **factory_kwargs)
371
+
372
+ return mlp_router_cls
373
+
374
+ def create_mha_router_cls(config, sp_config = None, layer_idx=None, process_group=None, device=None, dtype=None):
375
+ factory_kwargs = {"device": device, "dtype": dtype}
376
+ num_heads = config.num_attention_heads
377
+ n_head_kv = getattr(config, "n_head_kv", num_heads)
378
+ if num_heads != n_head_kv:
379
+ out_dim = n_head_kv
380
+ else:
381
+ out_dim = num_heads
382
+
383
+ low_rank_dim = getattr(sp_config, "attn_low_rank_dim", 128) # optional, default to 128
384
+
385
+ # per layer activation topk, to make this different per layer, add a different attn_topk_{layer_idx} in the sp_config
386
+ attn_topk = getattr(sp_config, "attn_topk", 0.5)
387
+ if process_group is None:
388
+ mha_router_cls = MHARouter
389
+ else:
390
+ mha_router_cls = ParallelMHARouter
391
+
392
+ parallel_kwargs = (
393
+ {
394
+ "process_group": process_group,
395
+ "sequence_parallel": getattr(config, "sequence_parallel", True),
396
+ }
397
+ if process_group is not None
398
+ else {}
399
+ )
400
+
401
+
402
+ mha_router_cls = partial(mha_router_cls,
403
+ low_rank_dim = low_rank_dim,
404
+ out_dim = out_dim,
405
+ top_k = attn_topk,
406
+ **parallel_kwargs,
407
+ **factory_kwargs)
408
+
409
+ return mha_router_cls
410
+
411
+ def create_block(config, sp_config, layer_idx=None, process_group=None, device=None, dtype=None):
412
+ factory_kwargs = {"device": device, "dtype": dtype}
413
+ sequence_parallel = getattr(config, "sequence_parallel", True)
414
+ mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
415
+ mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
416
+
417
+ use_rms_norm = getattr(config, "rms_norm", False)
418
+ norm_cls = partial(
419
+ nn.LayerNorm if not use_rms_norm else RMSNorm,
420
+ eps=config.layer_norm_epsilon,
421
+ **factory_kwargs,
422
+ )
423
+
424
+ # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
425
+ residual_in_fp32 = getattr(config, "residual_in_fp32", False)
426
+ resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
427
+ prenorm = getattr(config, "prenorm", True)
428
+ parallel_block = getattr(config, "parallel_block", False)
429
+ mlp_sparse = getattr(config, "mlp_sparse", False)
430
+ att_sparse = getattr(config, "att_sparse", False)
431
+ block_sparse = mlp_sparse or att_sparse
432
+
433
+ if not parallel_block:
434
+ if block_sparse:
435
+ mha_router_cls = create_mha_router_cls(config, sp_config, layer_idx, process_group=process_group, **factory_kwargs) if att_sparse else None
436
+ mlp_router_cls = create_mlp_router_cls(config, sp_config, layer_idx, process_group=process_group, **factory_kwargs) if mlp_sparse else None
437
+
438
+ block = SelectBlock(
439
+ config.hidden_size,
440
+ mixer_cls,
441
+ mlp_cls,
442
+ mlp_router = mlp_router_cls,
443
+ mha_router = mha_router_cls,
444
+ norm_cls=norm_cls,
445
+ prenorm=prenorm,
446
+ resid_dropout1=resid_dropout1,
447
+ resid_dropout2=config.resid_pdrop,
448
+ fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
449
+ residual_in_fp32=residual_in_fp32,
450
+ sequence_parallel=sequence_parallel and process_group is not None,
451
+ mark_shared_params=process_group is not None,
452
+ )
453
+ else:
454
+ block = Block(
455
+ config.hidden_size,
456
+ mixer_cls,
457
+ mlp_cls,
458
+ norm_cls=norm_cls,
459
+ prenorm=prenorm,
460
+ resid_dropout1=resid_dropout1,
461
+ resid_dropout2=config.resid_pdrop,
462
+ fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
463
+ residual_in_fp32=residual_in_fp32,
464
+ sequence_parallel=sequence_parallel and process_group is not None,
465
+ mark_shared_params=process_group is not None,
466
+ )
467
+
468
+ else:
469
+ # not implemented
470
+ raise RuntimeError("ParallelBlock not implemented")
471
+ block.layer_idx = layer_idx
472
+ return block
473
+
474
+
475
+ class GPTPreTrainedModel(nn.Module):
476
+ """An abstract class to handle weights initialization and
477
+ a simple interface for dowloading and loading pretrained models.
478
+ """
479
+
480
+ def __init__(self, config, *inputs, **kwargs):
481
+ super().__init__()
482
+ if not isinstance(config, GPT2Config):
483
+ raise ValueError(
484
+ "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
485
+ "To create a model from a Google pretrained model use "
486
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
487
+ self.__class__.__name__, self.__class__.__name__
488
+ )
489
+ )
490
+ self.config = config
491
+
492
+ @classmethod
493
+ def from_pretrained(
494
+ cls,
495
+ model_name,
496
+ config,
497
+ sp_config,
498
+ *args,
499
+ strict=True,
500
+ device=None,
501
+ dtype=None,
502
+ world_size=1,
503
+ rank=0,
504
+ **kwargs,
505
+ ):
506
+ """
507
+ Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
508
+ Download and cache the pre-trained model file if needed.
509
+ """
510
+ # Instantiate model.
511
+ model = cls(config, sp_config, *args, device=device, dtype=dtype, **kwargs)
512
+ # Load state_dict in cpu because we already initialized the model in GPU, and we don't
513
+ # want extra stuff taking up more GPU memory
514
+ state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
515
+ if model_name.startswith("gpt2"):
516
+ state_dict = remap_state_dict_gpt2(state_dict, config)
517
+ elif model_name.startswith("facebook/opt"):
518
+ state_dict = remap_state_dict_hf_opt(state_dict, config)
519
+ else:
520
+ raise NotImplementedError(f"Model {model_name} not supported")
521
+ if world_size > 1:
522
+ state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
523
+ load_return = model.load_state_dict(state_dict, strict=strict)
524
+ # logger.info(load_return)
525
+ return model
526
+
527
+
528
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
529
+ def _init_weights(
530
+ module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True
531
+ ):
532
+ if isinstance(module, nn.Linear):
533
+ nn.init.normal_(module.weight, std=initializer_range)
534
+ if module.bias is not None:
535
+ nn.init.zeros_(module.bias)
536
+ elif isinstance(module, nn.Embedding):
537
+ nn.init.normal_(module.weight, std=initializer_range)
538
+
539
+ if rescale_prenorm_residual:
540
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
541
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
542
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
543
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
544
+ #
545
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
546
+ for name, p in module.named_parameters():
547
+ if name in ["out_proj.weight", "fc2.weight"]:
548
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
549
+ nn.init.normal_(
550
+ p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)
551
+ )
552
+
553
+
554
+ class GPTModel(GPTPreTrainedModel):
555
+ def __init__(self, config: GPT2Config, sp_config=None, process_group=None, device=None, dtype=None):
556
+ super().__init__(config)
557
+ factory_kwargs = {"device": device, "dtype": dtype}
558
+ self.process_group = process_group
559
+ self.sequence_parallel = getattr(config, "sequence_parallel", True)
560
+ assert config.activation_function in [
561
+ "gelu",
562
+ "gelu_new",
563
+ "gelu_fast",
564
+ "gelu_approx",
565
+ "relu",
566
+ "sqrelu",
567
+ "glu",
568
+ "swiglu",
569
+ "geglu",
570
+ ]
571
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
572
+ vocab_size = (
573
+ math.ceil(config.vocab_size / pad_vocab_size_multiple)
574
+ * pad_vocab_size_multiple
575
+ )
576
+ # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
577
+ self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
578
+ # These 2 options are for OPT-350m
579
+ self.prenorm = getattr(config, "prenorm", True)
580
+ use_rms_norm = getattr(config, "rms_norm", False)
581
+ word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
582
+
583
+ if process_group is None:
584
+ self.embeddings = GPT2Embeddings(
585
+ config.hidden_size,
586
+ vocab_size,
587
+ config.max_position_embeddings,
588
+ word_embed_proj_dim=word_embed_proj_dim,
589
+ **factory_kwargs,
590
+ )
591
+ else:
592
+ self.embeddings = ParallelGPT2Embeddings(
593
+ config.hidden_size,
594
+ vocab_size,
595
+ config.max_position_embeddings,
596
+ process_group=process_group,
597
+ sequence_parallel=self.sequence_parallel,
598
+ **factory_kwargs,
599
+ )
600
+
601
+
602
+ # We change the order of dropout, residual and layer norm:
603
+ # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
604
+ # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
605
+ # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
606
+ # nn.Dropout probabilities are changed.
607
+ # This is for performance reason: we can fuse dropout + add + layer_norm.
608
+ self.layers = nn.ModuleList(
609
+ [
610
+ create_block(
611
+ config, sp_config, layer_idx=i, process_group=process_group, **factory_kwargs
612
+ )
613
+ for i in range(config.num_hidden_layers)
614
+ ]
615
+ )
616
+
617
+
618
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
619
+ if self.fused_dropout_add_ln:
620
+ if layer_norm_fn is None:
621
+ raise ImportError("Triton is not installed")
622
+ if self.prenorm:
623
+ self.drop_f = nn.Dropout(config.resid_pdrop)
624
+ norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
625
+ # self.ln_f = nn.LayerNorm(
626
+ # config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
627
+ # )
628
+ self.ln_f = norm_cls(
629
+ config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
630
+ )
631
+
632
+
633
+ if process_group is not None:
634
+ for p in self.ln_f.parameters():
635
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
636
+ p._shared_params = True
637
+ # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
638
+ if self.sequence_parallel:
639
+ p._sequence_parallel = True
640
+
641
+ self.apply(
642
+ partial(
643
+ _init_weights,
644
+ n_layer=config.num_hidden_layers,
645
+ initializer_range=config.initializer_range,
646
+ )
647
+ )
648
+ self.tie_weights()
649
+
650
+ self.sparse = False
651
+ if config.mlp_sparse or config.att_sparse:
652
+ self.sparse = True
653
+
654
+ def tie_weights(self):
655
+ if self.process_group is not None:
656
+ sync_shared_params(self, self.process_group)
657
+
658
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
659
+ return {
660
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
661
+ for i, layer in enumerate(self.layers)
662
+ }
663
+
664
+ def forward(self, input_ids, position_ids=None, inference_params=None):
665
+ # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
666
+ # dimensions so that we can split on it easily, in case of small batch size.
667
+ # Only the attention layers need to know the seqlen.
668
+ embedding_kwargs = (
669
+ {"combine_batch_seqlen_dim": True}
670
+ if self.process_group is not None and self.sequence_parallel
671
+ else {}
672
+ )
673
+ hidden_states = self.embeddings(
674
+ input_ids, position_ids=position_ids, **embedding_kwargs
675
+ )
676
+ residual = None
677
+ mixer_kwargs = (
678
+ {"seqlen": input_ids.shape[1]}
679
+ if self.process_group is not None and self.sequence_parallel
680
+ else {}
681
+ )
682
+ if inference_params is not None:
683
+ mixer_kwargs["inference_params"] = inference_params
684
+ else:
685
+ mixer_kwargs["inference_params"] = None
686
+
687
+ # else:
688
+ for layer in self.layers:
689
+ if self.prenorm:
690
+ hidden_states, residual = layer(
691
+ hidden_states,
692
+ residual,
693
+ mixer_kwargs=mixer_kwargs,
694
+ )
695
+ else:
696
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
697
+
698
+ if self.prenorm:
699
+ if not self.fused_dropout_add_ln:
700
+ dropped = self.drop_f(hidden_states)
701
+ residual = (dropped + residual) if residual is not None else dropped
702
+ hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
703
+ else:
704
+ # Set prenorm=False here since we don't need the residual
705
+ if hidden_states.shape != residual.shape:
706
+ hidden_states = hidden_states.view(residual.shape)
707
+
708
+ hidden_states = layer_norm_fn(
709
+ hidden_states,
710
+ self.ln_f.weight,
711
+ self.ln_f.bias,
712
+ residual=residual,
713
+ x1=None,
714
+ eps=self.ln_f.eps,
715
+ dropout_p=self.drop_f.p if self.training else 0.0,
716
+ prenorm=False,
717
+ is_rms_norm=isinstance(self.ln_f, RMSNorm)
718
+ )
719
+ return hidden_states
720
+
721
+
722
+ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
723
+ def __init__(self, config: GPT2Config, sp_config = None, process_group=None, device=None, dtype=None):
724
+ factory_kwargs = {"device": device, "dtype": dtype}
725
+ super().__init__(config)
726
+ self.process_group = process_group
727
+
728
+ self.transformer = GPTModel(
729
+ config, sp_config, process_group=process_group, **factory_kwargs
730
+ )
731
+ self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
732
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
733
+ vocab_size = (
734
+ math.ceil(config.vocab_size / pad_vocab_size_multiple)
735
+ * pad_vocab_size_multiple
736
+ )
737
+ # This option is for OPT-350m
738
+ word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
739
+ embed_dim = (
740
+ config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
741
+ )
742
+ if word_embed_proj_dim is not None:
743
+ self.project_out = nn.Linear(
744
+ config.n_embd, embed_dim, bias=False, **factory_kwargs
745
+ )
746
+ else:
747
+ self.project_out = None
748
+ mup_width_scale = getattr(config, "mup_width_scale", 1.0)
749
+ mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
750
+ self.output_scale = mup_output_multiplier * mup_width_scale
751
+
752
+ if process_group is None:
753
+ self.lm_head = nn.Linear(
754
+ embed_dim, vocab_size, bias=False, **factory_kwargs
755
+ )
756
+ else:
757
+ if ColumnParallelLinear is None:
758
+ raise ImportError("fused_dense_lib is not installed")
759
+ self.lm_head = ColumnParallelLinear(
760
+ embed_dim,
761
+ vocab_size,
762
+ process_group,
763
+ bias=False,
764
+ sequence_parallel=getattr(config, "sequence_parallel", True),
765
+ **factory_kwargs,
766
+ )
767
+
768
+ self.norm_head = getattr(config, "norm_head", False)
769
+ # Initialize weights and apply final processing
770
+ self.apply(
771
+ partial(
772
+ _init_weights,
773
+ n_layer=config.num_hidden_layers,
774
+ initializer_range=config.initializer_range,
775
+ )
776
+ )
777
+ self.tie_weights()
778
+
779
+ def tie_weights(self):
780
+ if self.tie_word_embeddings:
781
+ self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight # llama does not use tied weights
782
+ if self.process_group is not None:
783
+ sync_shared_params(self, self.process_group)
784
+
785
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
786
+ return self.transformer.allocate_inference_cache(
787
+ batch_size, max_seqlen, dtype=dtype, **kwargs
788
+ )
789
+
790
+ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
791
+ """
792
+ input_ids: (batch, seqlen) int tensor
793
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
794
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
795
+ num_last_tokens: if > 0, only return the logits for the last n tokens
796
+ """
797
+ assert (
798
+ input_ids.ndim == 2
799
+ ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
800
+ b, slen = input_ids.shape
801
+ hidden_states = self.transformer(
802
+ input_ids, position_ids=position_ids, inference_params=inference_params
803
+ )
804
+ if inference_params is not None:
805
+ assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
806
+ if num_last_tokens > 0:
807
+ hidden_states = hidden_states[:, -num_last_tokens:]
808
+ if self.project_out is not None:
809
+ hidden_states = self.project_out(hidden_states)
810
+ if self.output_scale != 1.0:
811
+ hidden_states = hidden_states * self.output_scale
812
+ if not self.norm_head:
813
+ lm_logits = self.lm_head(hidden_states)
814
+ else:
815
+ lm_head_weight = F.normalize(self.lm_head.weight)
816
+ if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
817
+ hidden_states = all_gather(hidden_states, self.lm_head.process_group)
818
+ lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
819
+ # During inference, we want the full logit for sampling
820
+ if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
821
+ lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
822
+ lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
823
+ CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
824
+ return CausalLMOutput(logits=lm_logits)
825
+
826
+ def load_state_dict(self, state_dict, strict=True):
827
+ # Remapping from our checkpoints that used a different ordering of layers in the block
828
+ # Previous: Attn / MLP -> Dropout -> Add -> LN
829
+ # Current: Dropout -> Add -> LN -> Attn / MLP
830
+ if "transformer.ln_0.weight" in state_dict:
831
+ n_layers = len(self.transformer.layers)
832
+ ln_weight = state_dict.pop(
833
+ f"transformer.layers.{n_layers - 1}.norm2.weight"
834
+ )
835
+ ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
836
+ state_dict["transformer.ln_f.weight"] = ln_weight
837
+ state_dict["transformer.ln_f.bias"] = ln_bias
838
+ for l in reversed(range(n_layers)):
839
+ ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
840
+ ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
841
+ state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
842
+ state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
843
+ if l > 0:
844
+ ln_weight = state_dict.pop(
845
+ f"transformer.layers.{l - 1}.norm2.weight"
846
+ )
847
+ ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
848
+ state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
849
+ state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
850
+ ln_weight = state_dict.pop("transformer.ln_0.weight")
851
+ ln_bias = state_dict.pop("transformer.ln_0.bias")
852
+ state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight
853
+ state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias
854
+ return super().load_state_dict(state_dict, strict=strict)
HybridTensor/models/helper.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import re
3
+ from collections import OrderedDict
4
+
5
+ from einops import rearrange
6
+
7
+
8
+ def remap_state_dict_gpt2(state_dict, config):
9
+ # Word embedding and position embedding
10
+ def key_mapping_pos_emb(key):
11
+ return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)
12
+
13
+ state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
14
+ word_embeddings = state_dict.pop("wte.weight")
15
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
16
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
17
+ vocab_size = (
18
+ math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
19
+ )
20
+ state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
21
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
22
+ )
23
+ state_dict["lm_head.weight"] = state_dict[
24
+ "transformer.embeddings.word_embeddings.weight"
25
+ ]
26
+
27
+ # LayerNorm
28
+ def key_mapping_ln(key):
29
+ key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
30
+ key = re.sub(
31
+ r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key
32
+ )
33
+ return key
34
+
35
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
36
+
37
+ # MLP
38
+ for d in range(config.num_hidden_layers):
39
+ W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
40
+ state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
41
+ W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
42
+ state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()
43
+
44
+ def key_mapping_mlp(key):
45
+ key = re.sub(
46
+ r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key
47
+ )
48
+ key = re.sub(
49
+ r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key
50
+ )
51
+ return key
52
+
53
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
54
+
55
+ # Attention
56
+ for d in range(config.num_hidden_layers):
57
+ state_dict.pop(f"h.{d}.attn.bias") # We don't store this bias
58
+ Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
59
+ state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
60
+ Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
61
+ state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()
62
+
63
+ def key_mapping_attn(key):
64
+ key = re.sub(
65
+ r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key
66
+ )
67
+ key = re.sub(
68
+ r"^h.(\d+).attn.c_proj.bias",
69
+ r"transformer.layers.\1.mixer.out_proj.bias",
70
+ key,
71
+ )
72
+ return key
73
+
74
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
75
+
76
+ return state_dict
77
+
78
+
79
+ def shard_state_dict_tp(state_dict, config, world_size, rank):
80
+ """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
81
+ with tensor parallel.
82
+ """
83
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
84
+ vocab_size = (
85
+ math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
86
+ )
87
+ assert vocab_size % world_size == 0
88
+ assert config.hidden_size % world_size == 0
89
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
90
+ assert inner_dim % world_size == 0
91
+
92
+ def shard_first_dim(state_dict, key):
93
+ x = state_dict[key]
94
+ dim = x.shape[0] // world_size
95
+ state_dict[key] = x[rank * dim : (rank + 1) * dim]
96
+
97
+ def shard_last_dim(state_dict, key):
98
+ x = state_dict[key]
99
+ dim = x.shape[-1] // world_size
100
+ state_dict[key] = x[..., rank * dim : (rank + 1) * dim]
101
+
102
+ def shard_qkv_headdim(state_dict, key):
103
+ x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
104
+ dim = x.shape[1] // world_size
105
+ state_dict[key] = rearrange(
106
+ x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..."
107
+ )
108
+
109
+ shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
110
+ if "lm_head.weight" in state_dict:
111
+ shard_first_dim(state_dict, "lm_head.weight")
112
+ if "transformer.embeddings.position_embeddings.weight" in state_dict:
113
+ shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
114
+ for i in range(config.num_hidden_layers):
115
+ shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
116
+ shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
117
+ shard_last_dim(state_dict, f"transformer.layers.{i}.mixer.out_proj.weight")
118
+ if rank != 0:
119
+ state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias")
120
+ shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
121
+ shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
122
+ shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
123
+ if rank != 0:
124
+ state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias")
125
+ return state_dict
HybridTensor/models/llama.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import LlamaConfig, LlamaTokenizer
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from HybridTensor.models.create_sparse_model import GPTLMHeadModel
6
+ from HybridTensor.modules.SelectiveRouters import create_mlp_router_state_dict, create_attn_router_state_dict
7
+
8
+ # from flash_attn.models.gpt import GPTLMHeadModel
9
+ from transformers import AutoConfig, AutoTokenizer
10
+ from flash_attn.utils.pretrained import state_dict_from_pretrained
11
+
12
+ from flash_attn.models.llama import (
13
+ config_from_checkpoint,
14
+ inv_remap_state_dict_hf_llama,
15
+ llama_config_to_gpt2_config,
16
+ remap_state_dict_hf_llama,
17
+ remap_state_dict_meta_llama,
18
+ state_dicts_from_checkpoint,
19
+ )
20
+
21
+ class SparseConfig:
22
+ def __init__(self):
23
+ self.mlp_low_rank_dim = 1024
24
+ self.attn_low_rank_dim = 128
25
+ self.mlp_act_th = 0.5
26
+ self.attn_topk = 0.3
27
+
28
+ def build_dense_llama(model_name: str, device = None, dtype=torch.float16, process_group = None, world_size = None, rank = None, **kwargs):
29
+ config = llama_config_to_gpt2_config(AutoConfig.from_pretrained(model_name, trust_remote_code=True))
30
+ config.use_flash_attn = True
31
+ config.fused_bias_fc = True
32
+ config.fused_mlp = False # We don't have fused GatedMLP yet
33
+ config.fused_dropout_add_ln = True
34
+ config.residual_in_fp32 = True
35
+ config.prenorm = True
36
+
37
+ state_dict = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype)
38
+ state_dict = remap_state_dict_hf_llama(state_dict, config)
39
+
40
+ model = GPTLMHeadModel(config, device=device, dtype=dtype)
41
+ model.load_state_dict(state_dict, strict=True)
42
+ model.eval()
43
+
44
+ return model
45
+
46
+ def build_sparse_llama(args, model_name: str, attn_ckpt_dir: str, device = None, dtype=torch.float16, process_group = None, world_size = None, rank = None, **kwargs):
47
+ config = llama_config_to_gpt2_config(AutoConfig.from_pretrained(model_name, trust_remote_code=True))
48
+ config.use_flash_attn = True
49
+ config.fused_bias_fc = True
50
+ config.fused_mlp = False # We don't have fused GatedMLP yet
51
+ config.fused_dropout_add_ln = True
52
+ config.residual_in_fp32 = True
53
+ config.prenorm = True
54
+
55
+ spconfig = SparseConfig()
56
+ spconfig.attn_topk = args.attn_topk
57
+ config.mlp_sparse = False
58
+ config.att_sparse = True
59
+
60
+ state_dict = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype)
61
+ state_dict = remap_state_dict_hf_llama(state_dict, config)
62
+
63
+ model = GPTLMHeadModel(config, sp_config= spconfig, device=device, dtype=dtype)
64
+
65
+ if attn_ckpt_dir is not None:
66
+ attn_router_state_dict = create_attn_router_state_dict(attn_ckpt_dir)
67
+ merged_state_dict = {**state_dict, **attn_router_state_dict}
68
+
69
+ # TODO: Add code for tensor parallel state dict sharding
70
+
71
+ model.load_state_dict(merged_state_dict, strict=True)
72
+ model.eval()
73
+
74
+ return model
HybridTensor/models/opt.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from HybridTensor.utils.activations import OPT_MODELS
2
+ import torch
3
+ import math
4
+ from einops import rearrange
5
+
6
+ from flash_attn.utils.pretrained import state_dict_from_pretrained
7
+ from flash_attn.models.opt import remap_state_dict_hf_opt
8
+ from HybridTensor.modules.SelectiveRouters import create_mlp_router_state_dict, create_attn_router_state_dict
9
+ from HybridTensor.models.create_sparse_model import GPTLMHeadModel as GPTLMHeadModelSparse
10
+ from flash_attn.models.gpt import GPTLMHeadModel
11
+
12
+ from transformers.models.opt import OPTConfig
13
+ from flash_attn.models.opt import opt_config_to_gpt2_config
14
+
15
class SparseConfig:
    """Hyper-parameters for the sparse MLP / attention routers.

    Generalized (backward-compatibly) to accept keyword overrides at
    construction time instead of forcing callers to mutate attributes after
    instantiation; a no-argument call produces exactly the old defaults.

    Attributes:
        mlp_low_rank_dim: hidden width of the low-rank MLP router.
        attn_low_rank_dim: hidden width of the low-rank attention router.
        mlp_act_th: activation threshold used by the MLP routers.
        attn_topk: top-k fraction used by the attention routers.
    """

    def __init__(self, mlp_low_rank_dim=1024, attn_low_rank_dim=128, mlp_act_th=0.5, attn_topk=0.3):
        self.mlp_low_rank_dim = mlp_low_rank_dim
        self.attn_low_rank_dim = attn_low_rank_dim
        self.mlp_act_th = mlp_act_th
        self.attn_topk = attn_topk
21
+
22
def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.

    Builds a NEW dict (the input is not mutated). Column-parallel tensors
    (fc1, Wqkv, embeddings, lm_head) are sharded on the first dim; row-parallel
    tensors (fc2.weight, out_proj.weight, position embeddings) on the last dim.
    Row-parallel biases are kept only on rank 0 (they are added once after the
    all-reduce); layernorm parameters are replicated on every rank.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    # Padded vocab size must divide evenly across ranks.
    vocab_size = (
        math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    # inner_dim is computed only for the divisibility check below.
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    shared_state_dict = {}

    # Take this rank's contiguous slice along dim 0 (column-parallel layers).
    def shard_first_dim(new, old, key):
        x = old[key]
        dim = x.shape[0] // world_size
        new[key] = x[rank * dim : (rank + 1) * dim]

    # Take this rank's contiguous slice along the last dim (row-parallel layers).
    def shard_last_dim(new, old, key):
        x = old[key]
        dim = x.shape[-1] // world_size
        new[key] = x[..., rank * dim : (rank + 1) * dim]

    # Wqkv packs Q, K, V along dim 0; split per-head within each of the three
    # sub-tensors so every rank gets matching Q/K/V head slices.
    def shard_qkv_headdim(new, old, key):
        x = rearrange(old[key], "(three d) ... -> three d ...", three=3)
        dim = x.shape[1] // world_size
        new[key] = rearrange(
            x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..."
        )

    shard_first_dim(shared_state_dict, state_dict, "transformer.embeddings.word_embeddings.weight")

    if "lm_head.weight" in state_dict:
        shard_first_dim(shared_state_dict, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        shard_last_dim(shared_state_dict, state_dict, "transformer.embeddings.position_embeddings.weight")

    for i in range(config.num_hidden_layers):
        # attention
        shard_qkv_headdim(shared_state_dict, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(shared_state_dict, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight")

        # mlp
        shard_first_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        shard_first_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mlp.fc2.weight")

        # Row-parallel output biases live only on rank 0 (added post-reduce);
        # other ranks simply omit the keys.
        if rank == 0:
            shared_state_dict[f"transformer.layers.{i}.mlp.fc2.bias"] = state_dict[f"transformer.layers.{i}.mlp.fc2.bias"]
            shared_state_dict[f"transformer.layers.{i}.mixer.out_proj.bias"] = state_dict[f"transformer.layers.{i}.mixer.out_proj.bias"]

        # LayerNorms are replicated on every rank.
        shared_state_dict[f"transformer.layers.{i}.norm1.weight"] = state_dict[f"transformer.layers.{i}.norm1.weight"]
        shared_state_dict[f"transformer.layers.{i}.norm1.bias"] = state_dict[f"transformer.layers.{i}.norm1.bias"]
        shared_state_dict[f"transformer.layers.{i}.norm2.weight"] = state_dict[f"transformer.layers.{i}.norm2.weight"]
        shared_state_dict[f"transformer.layers.{i}.norm2.bias"] = state_dict[f"transformer.layers.{i}.norm2.bias"]

        # routers

        # mlp router: fc1 replicated, fc2 sharded to match the sharded mlp.fc1
        # neurons it scores. NOTE(review): router bias terms (e.g.
        # mlp_router.fc2.bias) are not handled here — confirm the router
        # modules are bias-free or that missing keys are tolerated downstream.
        shared_state_dict[f"transformer.layers.{i}.mlp_router.fc1.weight"] = state_dict[f"transformer.layers.{i}.mlp_router.fc1.weight"]
        shard_first_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mlp_router.fc2.weight")

        # mha router
        shard_first_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mha_router.linear1.weight")
        shard_first_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mha_router.linear1.bias")

    shared_state_dict[f"transformer.ln_f.weight"] = state_dict["transformer.ln_f.weight"]
    shared_state_dict[f"transformer.ln_f.bias"] = state_dict["transformer.ln_f.bias"]

    # shared_state_dict[f"transformer.ln_f.weight"] = state_dict["transformer.final_layer_norm.weight"]
    # shared_state_dict[f"transformer.ln_f.bias"] = state_dict["transformer.final_layer_norm.bias"]

    return shared_state_dict
98
+
99
# NOTE: a legacy in-place variant of shard_state_dict_tp (it mutated the input
# state_dict and popped the out_proj/fc2 biases on rank != 0) used to be kept
# here inside a module-level triple-quoted string. Dead code removed — recover
# it from version control if it is ever needed again.
150
+
151
def build_sparse_opt(args, model_name, mlp_ckpt_dir, attn_ckpt_dir, device = None, dtype=torch.float16, process_group = None, world_size = None, rank = None):
    """Build an OPT model with sparse MLP and attention routers.

    Translates the HF OPT config to flash-attn's GPT-2-style config, enables
    fused/flash kernels (CUDA only), attaches a ``SparseConfig``, and loads the
    pretrained weights merged with the router checkpoints. Supports tensor
    parallelism via ``process_group`` / ``world_size`` / ``rank``.

    Args:
        args: namespace providing ``attn_topk``.
        model_name: HF hub id or local path of the OPT checkpoint.
        mlp_ckpt_dir / attn_ckpt_dir: router checkpoint directories; if either
            is None the routers keep their fresh initialization.
        device / dtype: placement and parameter dtype.
        process_group / world_size / rank: tensor-parallel context (optional).

    Returns:
        The sparse ``GPTLMHeadModel`` with weights loaded.
    """
    cfg = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))

    # Fused/flash kernels are CUDA-only; disable everything on CPU.
    on_cpu = device in ('cpu', torch.device('cpu'))
    cfg.fused_mlp = not on_cpu
    cfg.fused_dropout_add_ln = not on_cpu
    cfg.use_flash_attn = not on_cpu
    cfg.fused_bias_fc = not on_cpu
    if not on_cpu:
        cfg.sequence_parallel = False

    cfg.residual_in_fp32 = getattr(cfg, "prenorm", True)
    cfg.pad_vocab_size_multiple = 8
    cfg.mlp_sparse = True
    cfg.att_sparse = True

    cfg.use_heuristic = True
    if cfg.use_heuristic:
        print("Using pre-compiled heuristic")
    else:
        print("Compiling new heuristic during runtime")

    sparse_cfg = SparseConfig()
    sparse_cfg.mlp_act_th = 0.5  # threshold for the MLP routers of all layers
    sparse_cfg.attn_topk = args.attn_topk  # topk for the attention routers of all layers

    # build model
    print("Bulding Model with sparse routers")  # message kept verbatim
    model = GPTLMHeadModelSparse(config = cfg, sp_config = sparse_cfg, process_group = process_group, device = device, dtype=dtype)

    # load pretrained weights, remapped to flash-attn naming
    weights = remap_state_dict_hf_opt(
        state_dict_from_pretrained(model_name, device="cpu", dtype=dtype), cfg
    )

    # merge in the trained routers when both checkpoints are available
    if mlp_ckpt_dir is not None and attn_ckpt_dir is not None:
        merged = {
            **weights,
            **create_mlp_router_state_dict(mlp_ckpt_dir),
            **create_attn_router_state_dict(attn_ckpt_dir),
        }
        if process_group is not None:
            merged = shard_state_dict_tp(merged, cfg, world_size, rank)
        model.load_state_dict(merged, strict=True)
    else:
        # No routers: load non-strictly so router params keep their init.
        if process_group is not None:
            weights = shard_state_dict_tp(weights, cfg, world_size, rank)
        model.load_state_dict(weights, strict=False)

    return model
211
+
212
def build_dense_opt(model_name, device = None, dtype=torch.float16, process_group = None, world_size = None, rank = None):
    """Build a dense (non-sparse) OPT ``GPTLMHeadModel`` from a pretrained checkpoint.

    Bug fix: the previous version re-assigned ``dtype = torch.float16`` inside
    the body, silently ignoring any caller-supplied dtype. The parameter is now
    honored (the default is unchanged, so existing callers behave identically).

    Args:
        model_name: HF hub id or local path of the OPT checkpoint.
        device: target device for the model parameters.
        dtype: parameter dtype (default fp16, as before).
        process_group / world_size / rank: tensor-parallel context, forwarded
            to ``GPTLMHeadModel.from_pretrained``.

    Returns:
        The dense ``GPTLMHeadModel``.
    """
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    # config.fused_dropout_add_ln = True
    config.sequence_parallel = False
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.pad_vocab_size_multiple = 8

    # build model
    print("Building Dense Model")
    model = GPTLMHeadModel.from_pretrained(model_name, config, process_group = process_group, world_size = world_size, rank = rank, device=device, dtype=dtype)

    return model
HybridTensor/modules/SelectiveBlock.py ADDED
@@ -0,0 +1,960 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from torchvision.ops import StochasticDepth
9
+
10
+ try:
11
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
12
+ except ImportError:
13
+ layer_norm_fn, RMSNorm = None, None
14
+
15
+ class SelectBlock(nn.Module):
16
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        mlp_router=None,
        mha_router=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.

        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        Here we do: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, etc.

        If you want to do concurrency with CUDA graphs, your shapes must remain fixed
        (batch_size, seq_len, etc.) across captures and replays. Also avoid any operations
        that cause dynamic shape changes or memory allocations.

        Args:
            dim: hidden size of the block.
            mixer_cls / mlp_cls: factories (dim -> module) for attention and MLP;
                both are required (asserted below).
            mlp_router / mha_router: optional factories (dim -> module) for the
                sparse neuron/head selection routers.
            norm_cls / dropout_cls: layer-norm and dropout constructors.
            fused_dropout_add_ln: use the Triton fused dropout+add+LN kernel.
            sequence_parallel / mark_shared_params: tag norm parameters for
                tensor-parallel handling; mark_shared_params doubles as the
                tensor-parallel switch (see use_tensor_parallel below).

        NOTE: constructing this block creates CUDA streams/events, so a CUDA
        device must be available at construction time.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"

        assert mixer_cls is not None and mlp_cls is not None, (
            "mixer_cls and mlp_cls cannot be None in SelectBlock"
        )

        # MHA & MLP submodules
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        # Number of fc1 output neurons — used to decide when sparsity is not
        # worth it (mlp_topk close to total_neurons).
        self.total_neurons = self.mlp.fc1.weight.shape[0]

        # Routers
        if mlp_router is not None:
            self.mlp_router = mlp_router(dim)
            # NOTE(review): this looks like a copy-paste slip — the mlp_router
            # branch sets skip_attn_router (not skip_mlp_router). It is
            # harmless today because both skip flags are unconditionally reset
            # to False further down, but worth confirming the intent.
            self.skip_attn_router = False
        else:
            self.mlp_router = None
            self.skip_attn_router = True

        if mha_router is not None:
            self.mha_router = mha_router(dim)
        else:
            self.mha_router = None

        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton layer_norm_fn not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(self.dropout1, nn.Dropout)

        # Mark the norm parameters for sequence parallel / shared params if needed
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

        # Runtime sparsity knobs; mlp_topk is set per-forward (see forward()).
        # These resets override the branch assignments above.
        self.mlp_topk = None
        self.skip_mlp_router = False
        self.skip_attn_router = False

        # We'll use an extra stream for concurrency
        # (sparse_stream runs the MLP router while main_stream runs MHA).
        self.sparse_stream = torch.cuda.Stream(device="cuda", priority=0)
        self.main_stream = torch.cuda.Stream(device="cuda", priority=-5)
        # We'll record events to coordinate concurrency
        self.mha_event = torch.cuda.Event(enable_timing=False, blocking=False)
        self.mlp_event = torch.cuda.Event(enable_timing=False, blocking=False)

        # NOTE(review): mark_shared_params is reused as the tensor-parallel
        # switch — confirm callers always pass it exactly when running TP.
        self.use_tensor_parallel = mark_shared_params

        if self.use_tensor_parallel:
            # save the stream and events in the mixer and mlp classes
            self.mlp.router = self.mlp_router
            self.mixer.router = self.mha_router

        self.mlp_topk_layers = None  # this will be a dictionary of layer_idx -> topk value
        self.attn_topk_layers = None  # this will be a dictionary of layer_idx -> topk value
+ self.attn_topk_layers = None # this will be a dictionary of layer_idx -> topk value
123
+
124
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
125
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
126
+
127
    def prefill_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_kwargs=None, mixer_subset=None):
        """Dense prefill pass: full MHA, then dropout/add/LN, then full MLP.

        No routers are consulted during prefill — sparsity is applied only at
        decode time. Caller (forward()) has already applied the first
        dropout/add/norm, so this receives post-norm1 hidden states.

        Returns:
            (hidden_states, residual) for the prenorm residual stream.
        """
        hidden_states = self.mixer(hidden_states, **mixer_kwargs)

        if mixer_subset is not None:
            # Keep only the residual rows matching the mixer's output subset.
            residual = residual[:, mixer_subset]

        if not isinstance(self.mlp, nn.Identity):
            if not self.fused_dropout_add_ln:
                # Unfused path: dropout -> stochastic depth -> add -> norm2.
                dropped = self.drop_path2(self.dropout2(hidden_states))
                if dropped.shape != residual.shape:
                    dropped = dropped.view(residual.shape)
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                # Fused Triton kernel path. rowscale implements stochastic
                # depth by scaling whole rows; None when inactive.
                if self.drop_path2.p == 0 or not self.training:
                    rowscale2 = None
                else:
                    rowscale2 = self.drop_path2(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                if hidden_states.shape != residual.shape:
                    hidden_states = hidden_states.view(residual.shape)
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    eps=self.norm2.eps,
                    dropout_p=self.dropout2.p if self.training else 0.0,
                    rowscale=rowscale2,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )
            hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
169
+
170
    def decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
        """Single-GPU decode step with concurrent router/MHA execution.

        The MLP router (neuron selection) runs on ``sparse_stream`` while the
        attention router + MHA run on ``main_stream``; events serialize the
        hand-off before the sparse MLP consumes ``index_vec``.

        Args:
            hidden_states (Tensor): post-norm1 states, assumed (batch, 1, dim)
                during decode — squeezed to (batch, dim) for the routers.
            residual (Optional[Tensor], optional): prenorm residual stream.
            mixer_subset (_type_, optional): optional row subset. Defaults to None.
        """
        curr_stream = torch.cuda.current_stream()

        # We want to run MHA & mlp_router in parallel on different streams
        router_inputs = hidden_states.squeeze(1)  # shape (batch_size, dim)
        # Both side streams must observe all prior work on the current stream.
        self.main_stream.wait_stream(curr_stream)
        self.sparse_stream.wait_stream(curr_stream)
        main_stream = self.main_stream

        # if mlp_topk > th * total_neurons, skip mlp router

        # if self.mlp_topk > 0.8 * self.total_neurons:
        #     self.skip_mlp_router = True
        # else:
        #     self.skip_mlp_router = False

        # [Sparse stream] mlp_router — selects the active fc1 neurons.
        if not self.skip_mlp_router:
            with torch.cuda.stream(self.sparse_stream):
                index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk)
                self.sparse_stream.record_event(self.mlp_event)

        # [Main stream] MHA with head selection from the attention router.
        with torch.cuda.stream(main_stream):
            batch_head_idx = self.mha_router._select_heads(router_inputs)
            hidden_states = self.mixer(
                hidden_states,
                batch_head_idx=batch_head_idx,
                **mixer_kwargs
            )

            main_stream.record_event(self.mha_event)

        # Now we unify after both are done, then do the next steps
        with torch.cuda.stream(main_stream):
            # Wait on router & MHA
            # NOTE(review): curr_stream.wait_stream(main_stream) here makes the
            # current stream wait before the MLP below is issued on main_stream;
            # the trailing waits after the block repeat this — confirm intended.
            curr_stream.wait_stream(main_stream)
            main_stream.wait_event(self.mha_event)

            # normal residual / layernorm
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]

            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(
                        residual.to(dtype=self.norm2.weight.dtype)
                    )
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    # Fused dropout+add+LN path (Triton kernel).
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    if hidden_states.shape != residual.shape:
                        hidden_states = hidden_states.view(residual.shape)
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )

                # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size)
                if self.skip_mlp_router:
                    # Dense MLP fallback — no neuron selection.
                    hidden_states = self.mlp(hidden_states, index_vec=None)
                else:
                    # Make sure the router's index_vec is ready before the
                    # sparse MLP reads it.
                    curr_stream.wait_stream(self.sparse_stream)
                    main_stream.wait_event(self.mlp_event)
                    hidden_states = self.mlp(hidden_states, index_vec=index_vec)
        # Re-join both side streams with the caller's stream.
        curr_stream.wait_stream(main_stream)
        curr_stream.wait_stream(self.sparse_stream)

        return hidden_states, residual
266
+
267
    def tp_decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
        """
        Tensor Parallel Decode Forward.

        Unlike decode_forward, the attention router runs synchronously on the
        current stream (its result feeds the TP all-reduce path inside the
        mixer); only the MLP router is offloaded to ``sparse_stream``.
        """

        curr_stream = torch.cuda.current_stream()
        self.sparse_stream.wait_stream(curr_stream)
        # self.main_stream.wait_stream(curr_stream)

        router_inputs = hidden_states.squeeze(1)  # shape (batch_size, dim)

        # Sparsity is not worth it once topk approaches the full neuron count.
        if self.mlp_topk > 0.8 * self.total_neurons:
            self.skip_mlp_router = True
        else:
            self.skip_mlp_router = False

        # attention router is synchronous
        batch_head_idx = self.mha_router._select_heads(router_inputs)

        # mlp router is asynchronous
        if not self.skip_mlp_router:
            with torch.cuda.stream(self.sparse_stream):
                index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk)
                self.sparse_stream.record_event(self.mlp_event)

        hidden_states = self.mixer(hidden_states, **mixer_kwargs, batch_head_idx=batch_head_idx)

        if mixer_subset is not None:
            residual = residual[:, mixer_subset]

        if not isinstance(self.mlp, nn.Identity):
            if not self.fused_dropout_add_ln:
                # Unfused dropout -> add -> norm2 path.
                dropped = self.drop_path2(self.dropout2(hidden_states))
                if dropped.shape != residual.shape:
                    dropped = dropped.view(residual.shape)
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                # Fused Triton dropout+add+LN path.
                if self.drop_path2.p == 0 or not self.training:
                    rowscale2 = None
                else:
                    rowscale2 = self.drop_path2(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                if hidden_states.shape != residual.shape:
                    hidden_states = hidden_states.view(residual.shape)
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    eps=self.norm2.eps,
                    dropout_p=self.dropout2.p if self.training else 0.0,
                    rowscale=rowscale2,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )

            # curr_stream.wait_stream(self.sparse_stream)
            if self.skip_mlp_router:
                hidden_states = self.mlp(hidden_states, index_vec=None)
            else:
                # Only wait on the router's completion event (not the whole
                # sparse stream) before the sparse MLP reads index_vec.
                curr_stream.wait_event(self.mlp_event)
                hidden_states = self.mlp(hidden_states, index_vec=index_vec)

        return hidden_states, residual
341
+
342
    def attn_sparse_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
        """
        Decode Forward with Sparse Attention Router only (dense MLP).

        Fully synchronous — no side streams — so it works on both single-GPU
        and tensor-parallel setups. The attention router picks the active heads
        per batch row; the MLP runs dense.
        """

        # We want to run MHA & mlp_router in parallel on different streams
        router_inputs = hidden_states.squeeze(1)  # shape (batch_size, dim)

        batch_head_idx = self.mha_router._select_heads(router_inputs)

        # print(f"hidden_states shape: {hidden_states.shape}")
        # print(f"hidden states: {hidden_states}")
        hidden_states = self.mixer(hidden_states, batch_head_idx=batch_head_idx, **mixer_kwargs)

        # normal residual / layernorm
        if mixer_subset is not None:
            residual = residual[:, mixer_subset]

        if not isinstance(self.mlp, nn.Identity):
            if not self.fused_dropout_add_ln:
                # Unfused dropout -> add -> norm2 path.
                dropped = self.drop_path2(self.dropout2(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm2(
                    residual.to(dtype=self.norm2.weight.dtype)
                )
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                # Fused Triton dropout+add+LN path.
                if self.drop_path2.p == 0 or not self.training:
                    rowscale2 = None
                else:
                    rowscale2 = self.drop_path2(
                        torch.ones(hidden_states.shape[:-1], device=hidden_states.device, dtype=hidden_states.dtype,)
                    )
                if hidden_states.shape != residual.shape:
                    hidden_states = hidden_states.view(residual.shape)
                hidden_states, residual = layer_norm_fn(hidden_states, self.norm2.weight, self.norm2.bias, residual=residual,
                                                        eps=self.norm2.eps, dropout_p=self.dropout2.p if self.training else 0.0,
                                                        rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32,
                                                        is_rms_norm=isinstance(self.norm2, RMSNorm),)

            # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size)
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual
387
+
388
    def mlp_sparse_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
        """Single-GPU decode step with sparse MLP only (dense attention).

        Mirrors decode_forward's two-stream layout, but the MHA runs dense
        (``batch_head_idx=None``) while the MLP router runs concurrently on
        ``sparse_stream``.

        Args:
            hidden_states (Tensor): post-norm1 states, assumed (batch, 1, dim).
            residual (Optional[Tensor], optional): prenorm residual stream.
            mixer_subset (_type_, optional): optional row subset. Defaults to None.
        """
        curr_stream = torch.cuda.current_stream()

        # We want to run MHA & mlp_router in parallel on different streams
        router_inputs = hidden_states.squeeze(1)  # shape (batch_size, dim)
        self.main_stream.wait_stream(curr_stream)
        self.sparse_stream.wait_stream(curr_stream)
        main_stream = self.main_stream

        # if mlp_topk > th * total_neurons, skip mlp router

        # Sparsity is not worth it once topk approaches the full neuron count.
        if self.mlp_topk > 0.8 * self.total_neurons:
            self.skip_mlp_router = True
        else:
            self.skip_mlp_router = False

        # [Sparse stream] mlp_router — selects the active fc1 neurons.
        if not self.skip_mlp_router:
            with torch.cuda.stream(self.sparse_stream):
                index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk)
                self.sparse_stream.record_event(self.mlp_event)

        # [Main stream] MHA — dense (no head selection in this variant).
        with torch.cuda.stream(main_stream):
            # batch_head_idx = self.mha_router._select_heads(router_inputs)
            hidden_states = self.mixer(
                hidden_states,
                batch_head_idx=None,
                **mixer_kwargs
            )

            main_stream.record_event(self.mha_event)

        # Now we unify after both are done, then do the next steps
        with torch.cuda.stream(main_stream):
            # Wait on router & MHA
            curr_stream.wait_stream(main_stream)
            main_stream.wait_event(self.mha_event)

            # normal residual / layernorm
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]

            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(
                        residual.to(dtype=self.norm2.weight.dtype)
                    )
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    # Fused Triton dropout+add+LN path.
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    if hidden_states.shape != residual.shape:
                        hidden_states = hidden_states.view(residual.shape)
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )

                # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size)
                if self.skip_mlp_router:
                    # Dense MLP fallback — no neuron selection.
                    hidden_states = self.mlp(hidden_states, index_vec=None)
                else:
                    # Ensure the router's index_vec is ready before use.
                    curr_stream.wait_stream(self.sparse_stream)
                    main_stream.wait_event(self.mlp_event)
                    hidden_states = self.mlp(hidden_states, index_vec=index_vec)
        # Re-join both side streams with the caller's stream.
        curr_stream.wait_stream(main_stream)
        curr_stream.wait_stream(self.sparse_stream)

        return hidden_states, residual
484
+
485
    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
        mlp_topk=None,
        attn_topk=None,
    ):
        """
        Prenorm transformer block forward with concurrency logic in the decode branch.

        If you're capturing with a CUDA graph, the concurrency (two-stream usage) must be
        inside the captured region so that the replay reproduces the parallel streams.

        Args:
            hidden_states: block input (presumably (batch, seqlen, dim) — TODO confirm).
            residual: running residual stream from the previous block, or None for the
                first block.
            mixer_subset: optional subset passed through to the mixer via mixer_kwargs.
            mixer_kwargs: kwargs forwarded to the attention mixer; an
                "inference_params" key is guaranteed to exist after normalization below.
            mlp_topk: if given, overrides self.mlp_topk for this call (simulation knob).
            attn_topk: if given, overrides the attention router's topk (simulation knob).

        Returns:
            (hidden_states, residual) tuple, as produced by the selected
            prefill/decode sub-forward.

        Raises:
            NotImplementedError: for the post-norm (prenorm=False) architecture.
        """

        # simulation values: per-call overrides of the sparsity ratios
        if mlp_topk is not None:
            self.mlp_topk = mlp_topk

        if attn_topk is not None:
            self.mha_router.topk = attn_topk

        # Normalize mixer_kwargs so downstream code can always index "inference_params".
        if mixer_kwargs is None:
            mixer_kwargs = {"inference_params": None}
        else:
            # Ensure 'inference_params' key exists
            if "inference_params" not in mixer_kwargs:
                mixer_kwargs["inference_params"] = None

        if self.prenorm:
            # --- 1) Prenorm's dropout/add/layernorm
            if not self.fused_dropout_add_ln:
                # Unfused path: dropout -> residual add -> norm, with optional fp32 residual.
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                # fused dropout + add + layernorm (Triton layer_norm_fn)
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    # StochasticDepth as a per-row scale fed into the fused kernel.
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                if residual is not None and hidden_states.shape != residual.shape:
                    hidden_states = hidden_states.view(residual.shape)
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )

            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset

            # Check if we are in the prefill or decode stage
            # (no inference params, or a zero sequence offset, means prefill).
            prefill_stage = (
                mixer_kwargs["inference_params"] is None
                or mixer_kwargs["inference_params"].seqlen_offset == 0
            )

            if prefill_stage:
                # --- 2) Prefill stage (no concurrency): just do normal forward
                hidden_states, residual = self.prefill_forward(hidden_states, residual, mixer_kwargs, mixer_subset)

            else:
                # --- 3) Decode stage: dispatch on which routers exist and on tensor parallelism.
                if self.mlp_router is None:
                    # decode stage with only attention router, works with both single gpu and tensor parallel
                    hidden_states, residual = self.attn_sparse_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
                else:
                    if not self.use_tensor_parallel:
                        if self.mha_router is None:
                            # decode stage with mlp routers (opt models and single gpu)
                            hidden_states, residual = self.mlp_sparse_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
                        else:
                            # decode stage with mlp and attention routers (opt models and single gpu)
                            hidden_states, residual = self.decode_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
                    else:
                        # uses both mlp and attention routers in tensor parallel
                        hidden_states, residual = self.tp_decode_forward(hidden_states, residual, mixer_subset, mixer_kwargs)

            return hidden_states, residual

        else:
            # post-norm architecture not implemented here
            raise NotImplementedError
584
+
585
+
586
+ # class SelectBlock(nn.Module):
587
+ # def __init__(
588
+ # self,
589
+ # dim,
590
+ # mixer_cls=None,
591
+ # mlp_cls=None,
592
+ # mlp_router=None,
593
+ # mha_router=None,
594
+ # norm_cls=nn.LayerNorm,
595
+ # dropout_cls=nn.Dropout,
596
+ # prenorm=True,
597
+ # resid_dropout1=0.0,
598
+ # resid_dropout2=0.0,
599
+ # drop_path1=0.0,
600
+ # drop_path2=0.0,
601
+ # fused_dropout_add_ln=False,
602
+ # return_residual=False,
603
+ # residual_in_fp32=False,
604
+ # sequence_parallel=False,
605
+ # mark_shared_params=False,
606
+ # ):
607
+ # """
608
+ # For prenorm=True, this Block has a slightly different structure compared to a regular
609
+ # prenorm Transformer block.
610
+
611
+ # The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
612
+ # Here we do: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, etc.
613
+
614
+ # If you want to do concurrency with CUDA graphs, your shapes must remain fixed
615
+ # (batch_size, seq_len, etc.) across captures and replays. Also avoid any operations
616
+ # that cause dynamic shape changes or memory allocations.
617
+ # """
618
+ # super().__init__()
619
+ # self.prenorm = prenorm
620
+ # self.fused_dropout_add_ln = fused_dropout_add_ln
621
+ # self.return_residual = return_residual
622
+ # self.residual_in_fp32 = residual_in_fp32
623
+ # if self.residual_in_fp32:
624
+ # assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
625
+
626
+ # assert mixer_cls is not None and mlp_cls is not None, (
627
+ # "mixer_cls and mlp_cls cannot be None in SelectBlock"
628
+ # )
629
+
630
+ # # MHA & MLP submodules
631
+ # self.mixer = mixer_cls(dim)
632
+ # self.dropout1 = dropout_cls(resid_dropout1)
633
+ # self.drop_path1 = StochasticDepth(drop_path1, mode="row")
634
+ # self.norm1 = norm_cls(dim)
635
+ # self.mlp = mlp_cls(dim)
636
+
637
+ # # Routers
638
+ # self.mlp_router = mlp_router(dim)
639
+ # self.mha_router = mha_router(dim)
640
+
641
+ # if not isinstance(self.mlp, nn.Identity):
642
+ # self.dropout2 = dropout_cls(resid_dropout2)
643
+ # self.drop_path2 = StochasticDepth(drop_path2, mode="row")
644
+ # self.norm2 = norm_cls(dim)
645
+
646
+ # if self.fused_dropout_add_ln:
647
+ # assert layer_norm_fn is not None, "Triton layer_norm_fn not installed"
648
+ # assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(self.dropout1, nn.Dropout)
649
+
650
+ # # Mark the norm parameters for sequence parallel / shared params if needed
651
+ # if sequence_parallel:
652
+ # for p in self.norm1.parameters():
653
+ # p._sequence_parallel = True
654
+ # if hasattr(self, "norm2"):
655
+ # for p in self.norm2.parameters():
656
+ # p._sequence_parallel = True
657
+ # if mark_shared_params:
658
+ # for p in self.norm1.parameters():
659
+ # p._shared_params = True
660
+ # if hasattr(self, "norm2"):
661
+ # for p in self.norm2.parameters():
662
+ # p._shared_params = True
663
+
664
+ # self.mlp_topk = None
665
+ # self.skip_mlp_router = False
666
+ # self.skip_attn_router = False
667
+
668
+ # # We'll use an extra stream for concurrency
669
+ # self.sparse_stream = torch.cuda.Stream(device="cuda", priority=0)
670
+ # self.main_stream = torch.cuda.Stream(device="cuda", priority=-5)
671
+ # # We'll record events to coordinate concurrency
672
+ # self.mha_event = torch.cuda.Event(enable_timing=False, blocking=False)
673
+ # self.mlp_event = torch.cuda.Event(enable_timing=False, blocking=False)
674
+
675
+ # self.use_tensor_parallel = mark_shared_params
676
+
677
+ # if self.use_tensor_parallel:
678
+ # # TODO: save the routers in the mixer and mlp classes
679
+ # # save the stream and events in the mixer and mlp classes
680
+ # pass
681
+
682
+ # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
683
+ # return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
684
+
685
+ # def prefill_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_kwargs=None, mixer_subset=None):
686
+ # hidden_states = self.mixer(hidden_states, **mixer_kwargs)
687
+
688
+ # if mixer_subset is not None:
689
+ # residual = residual[:, mixer_subset]
690
+
691
+ # if not isinstance(self.mlp, nn.Identity):
692
+ # if not self.fused_dropout_add_ln:
693
+ # dropped = self.drop_path2(self.dropout2(hidden_states))
694
+ # if dropped.shape != residual.shape:
695
+ # dropped = dropped.view(residual.shape)
696
+ # residual = (dropped + residual) if residual is not None else dropped
697
+ # hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
698
+ # if self.residual_in_fp32:
699
+ # residual = residual.to(torch.float32)
700
+ # else:
701
+ # if self.drop_path2.p == 0 or not self.training:
702
+ # rowscale2 = None
703
+ # else:
704
+ # rowscale2 = self.drop_path2(
705
+ # torch.ones(
706
+ # hidden_states.shape[:-1],
707
+ # device=hidden_states.device,
708
+ # dtype=hidden_states.dtype,
709
+ # )
710
+ # )
711
+ # if hidden_states.shape != residual.shape:
712
+ # hidden_states = hidden_states.view(residual.shape)
713
+ # hidden_states, residual = layer_norm_fn(
714
+ # hidden_states,
715
+ # self.norm2.weight,
716
+ # self.norm2.bias,
717
+ # residual=residual,
718
+ # eps=self.norm2.eps,
719
+ # dropout_p=self.dropout2.p if self.training else 0.0,
720
+ # rowscale=rowscale2,
721
+ # prenorm=True,
722
+ # residual_in_fp32=self.residual_in_fp32,
723
+ # is_rms_norm=isinstance(self.norm2, RMSNorm),
724
+ # )
725
+ # hidden_states = self.mlp(hidden_states)
726
+ # return hidden_states, residual
727
+
728
+ # def decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
729
+ # """ Single GPU Decode Forward
730
+
731
+ # Args:
732
+ # hidden_states (Tensor): _description_
733
+ # residual (Optional[Tensor], optional): _description_. Defaults to None.
734
+ # mixer_subset (_type_, optional): _description_. Defaults to None.
735
+ # """
736
+ # curr_stream = torch.cuda.current_stream()
737
+
738
+ # # We want to run MHA & mlp_router in parallel on different streams
739
+ # router_inputs = hidden_states.squeeze(1) # shape (batch_size, dim)
740
+ # self.main_stream.wait_stream(curr_stream)
741
+ # self.sparse_stream.wait_stream(curr_stream)
742
+
743
+ # # We'll do MHA on the "main_stream" and the router on "sparse_stream"
744
+ # main_stream = self.main_stream
745
+ # # In a captured region, each 'with torch.cuda.stream(...)' block
746
+ # # is replayed in concurrency. The shape must remain consistent.
747
+
748
+ # # [Sparse stream] mlp_router
749
+ # if not self.skip_mlp_router:
750
+ # with torch.cuda.stream(self.sparse_stream): # <-- CHANGED
751
+ # # index_size, index_vec = self.mlp_router._select_neurons_cuda_safe(router_inputs) # need to fix this; make CUDA Graph safe
752
+ # # vec = self.mlp_router(router_inputs)
753
+ # index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk)
754
+ # self.sparse_stream.record_event(self.mlp_event)
755
+
756
+ # # [Main stream] MHA
757
+ # with torch.cuda.stream(main_stream): # <-- CHANGED
758
+ # batch_head_idx = self.mha_router._select_heads(router_inputs)
759
+ # hidden_states = self.mixer(
760
+ # hidden_states,
761
+ # batch_head_idx=batch_head_idx,
762
+ # # batch_head_idx=None,
763
+ # **mixer_kwargs
764
+ # )
765
+ # main_stream.record_event(self.mha_event)
766
+
767
+ # # Now we unify after both are done, then do the next steps
768
+ # with torch.cuda.stream(main_stream): # <-- CHANGED
769
+ # # Wait on router & MHA
770
+ # curr_stream.wait_stream(main_stream)
771
+ # main_stream.wait_event(self.mha_event)
772
+
773
+ # # normal residual / layernorm
774
+ # if mixer_subset is not None:
775
+ # residual = residual[:, mixer_subset]
776
+
777
+ # if not isinstance(self.mlp, nn.Identity):
778
+ # if not self.fused_dropout_add_ln:
779
+ # dropped = self.drop_path2(self.dropout2(hidden_states))
780
+ # residual = (dropped + residual) if residual is not None else dropped
781
+ # hidden_states = self.norm2(
782
+ # residual.to(dtype=self.norm2.weight.dtype)
783
+ # )
784
+ # if self.residual_in_fp32:
785
+ # residual = residual.to(torch.float32)
786
+ # else:
787
+ # if self.drop_path2.p == 0 or not self.training:
788
+ # rowscale2 = None
789
+ # else:
790
+ # rowscale2 = self.drop_path2(
791
+ # torch.ones(
792
+ # hidden_states.shape[:-1],
793
+ # device=hidden_states.device,
794
+ # dtype=hidden_states.dtype,
795
+ # )
796
+ # )
797
+ # if hidden_states.shape != residual.shape:
798
+ # hidden_states = hidden_states.view(residual.shape)
799
+ # hidden_states, residual = layer_norm_fn(
800
+ # hidden_states,
801
+ # self.norm2.weight,
802
+ # self.norm2.bias,
803
+ # residual=residual,
804
+ # eps=self.norm2.eps,
805
+ # dropout_p=self.dropout2.p if self.training else 0.0,
806
+ # rowscale=rowscale2,
807
+ # prenorm=True,
808
+ # residual_in_fp32=self.residual_in_fp32,
809
+ # is_rms_norm=isinstance(self.norm2, RMSNorm),
810
+ # )
811
+
812
+ # # Finally do MLP with the router's index vector
813
+ # curr_stream.wait_stream(self.sparse_stream)
814
+ # main_stream.wait_event(self.mlp_event)
815
+
816
+ # # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size)
817
+ # if self.skip_mlp_router:
818
+ # hidden_states = self.mlp(hidden_states, index_vec=None)
819
+ # else:
820
+ # hidden_states = self.mlp(hidden_states, index_vec=index_vec)
821
+ # curr_stream.wait_stream(main_stream)
822
+ # curr_stream.wait_stream(self.sparse_stream)
823
+
824
+ # return hidden_states, residual
825
+
826
+ # def tp_decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
827
+ # """
828
+ # Tensor Parallel Decode Forward
829
+
830
+ # Args:
831
+ # hidden_states (Tensor): _description_
832
+ # residual (Optional[Tensor], optional): _description_. Defaults to None.
833
+ # mixer_subset (_type_, optional): _description_. Defaults to None.
834
+ # """
835
+ # # TODO: need to add routing
836
+
837
+ # hidden_states = self.mixer(hidden_states, **mixer_kwargs)
838
+
839
+ # if mixer_subset is not None:
840
+ # residual = residual[:, mixer_subset]
841
+
842
+ # if not isinstance(self.mlp, nn.Identity):
843
+ # if not self.fused_dropout_add_ln:
844
+ # dropped = self.drop_path2(self.dropout2(hidden_states))
845
+ # if dropped.shape != residual.shape:
846
+ # dropped = dropped.view(residual.shape)
847
+ # residual = (dropped + residual) if residual is not None else dropped
848
+ # hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
849
+ # if self.residual_in_fp32:
850
+ # residual = residual.to(torch.float32)
851
+ # else:
852
+ # if self.drop_path2.p == 0 or not self.training:
853
+ # rowscale2 = None
854
+ # else:
855
+ # rowscale2 = self.drop_path2(
856
+ # torch.ones(
857
+ # hidden_states.shape[:-1],
858
+ # device=hidden_states.device,
859
+ # dtype=hidden_states.dtype,
860
+ # )
861
+ # )
862
+ # if hidden_states.shape != residual.shape:
863
+ # hidden_states = hidden_states.view(residual.shape)
864
+ # hidden_states, residual = layer_norm_fn(
865
+ # hidden_states,
866
+ # self.norm2.weight,
867
+ # self.norm2.bias,
868
+ # residual=residual,
869
+ # eps=self.norm2.eps,
870
+ # dropout_p=self.dropout2.p if self.training else 0.0,
871
+ # rowscale=rowscale2,
872
+ # prenorm=True,
873
+ # residual_in_fp32=self.residual_in_fp32,
874
+ # is_rms_norm=isinstance(self.norm2, RMSNorm),
875
+ # )
876
+ # hidden_states = self.mlp(hidden_states)
877
+ # return hidden_states, residual
878
+
879
+ # def forward(
880
+ # self,
881
+ # hidden_states: Tensor,
882
+ # residual: Optional[Tensor] = None,
883
+ # mixer_subset=None,
884
+ # mixer_kwargs=None,
885
+ # ):
886
+ # """
887
+ # This forward pass includes concurrency logic in the decode branch.
888
+ # If you're capturing with a CUDA graph, the concurrency (two-stream usage) must be
889
+ # inside the captured region so that the replay reproduces the parallel streams.
890
+ # """
891
+
892
+
893
+ # if mixer_kwargs is None:
894
+ # mixer_kwargs = {"inference_params": None}
895
+ # else:
896
+ # # Ensure 'inference_params' key exists
897
+ # if "inference_params" not in mixer_kwargs:
898
+ # mixer_kwargs["inference_params"] = None
899
+
900
+ # if self.prenorm:
901
+ # # --- 1) Prenorm’s dropout/add/layernorm
902
+ # if not self.fused_dropout_add_ln:
903
+ # dropped = self.drop_path1(self.dropout1(hidden_states))
904
+ # residual = (dropped + residual) if residual is not None else dropped
905
+ # hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
906
+ # if self.residual_in_fp32:
907
+ # residual = residual.to(torch.float32)
908
+ # else:
909
+ # # fused dropout + add + layernorm
910
+ # if self.drop_path1.p == 0 or not self.training:
911
+ # rowscale1 = None
912
+ # else:
913
+ # rowscale1 = self.drop_path1(
914
+ # torch.ones(
915
+ # hidden_states.shape[:-1],
916
+ # device=hidden_states.device,
917
+ # dtype=hidden_states.dtype,
918
+ # )
919
+ # )
920
+ # if residual is not None and hidden_states.shape != residual.shape:
921
+ # hidden_states = hidden_states.view(residual.shape)
922
+ # hidden_states, residual = layer_norm_fn(
923
+ # hidden_states,
924
+ # self.norm1.weight,
925
+ # self.norm1.bias,
926
+ # residual=residual,
927
+ # eps=self.norm1.eps,
928
+ # dropout_p=self.dropout1.p if self.training else 0.0,
929
+ # rowscale=rowscale1,
930
+ # prenorm=True,
931
+ # residual_in_fp32=self.residual_in_fp32,
932
+ # is_rms_norm=isinstance(self.norm1, RMSNorm),
933
+ # )
934
+
935
+ # if mixer_subset is not None:
936
+ # mixer_kwargs["mixer_subset"] = mixer_subset
937
+
938
+ # # Check if we are in the prefill or decode stage
939
+ # prefill_stage = (
940
+ # mixer_kwargs["inference_params"] is None
941
+ # or mixer_kwargs["inference_params"].seqlen_offset == 0
942
+ # )
943
+
944
+ # if prefill_stage:
945
+ # # --- 2) Prefill stage (no concurrency): just do normal forward
946
+ # hidden_states, residual = self.prefill_forward(hidden_states, residual, mixer_kwargs, mixer_subset)
947
+
948
+ # else:
949
+ # # # --- 3) Decode stage:
950
+ # if not self.use_tensor_parallel:
951
+ # hidden_states, residual = self.decode_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
952
+ # else:
953
+ # # routing is slightly different in tensor parallel; we overlap the router with allreduce
954
+ # hidden_states, residual = self.tp_decode_forward(hidden_states, residual, mixer_subset)
955
+
956
+ # return hidden_states, residual
957
+
958
+ # else:
959
+ # # post-norm architecture not implemented here
960
+ # raise NotImplementedError
HybridTensor/modules/SelectiveMHA.py ADDED
@@ -0,0 +1,1579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+
8
+ from flash_attn.utils.distributed import get_dim_for_local_rank
9
+ from flash_attn.utils.distributed import all_reduce
10
+
11
+ try:
12
+ from flash_attn import (
13
+ flash_attn_kvpacked_func,
14
+ flash_attn_qkvpacked_func,
15
+ flash_attn_varlen_kvpacked_func,
16
+ flash_attn_varlen_qkvpacked_func,
17
+ flash_attn_with_kvcache,
18
+ )
19
+ except ImportError:
20
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
21
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
22
+ flash_attn_with_kvcache = None
23
+
24
+ try:
25
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear, fused_dense_func
26
+ except ImportError:
27
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
28
+
29
+ try:
30
+ from flash_attn.layers.rotary import RotaryEmbedding
31
+ except ImportError:
32
+ RotaryEmbedding = None
33
+
34
+ from flash_attn.modules.mha import SelfAttention, FlashSelfAttention, LinearResidual, FlashCrossAttention, CrossAttention
35
+ from flash_attn.modules.mha import get_alibi_slopes #, _update_kv_cache
36
+ from flash_attn.utils.generation import InferenceParams
37
+
38
+ # from HybridTensor.modules.references.mha_dejavu import ParallelTracker # use this in the full implementation
39
+ # from HybridTensor.modules.references.mha_dejavu import ParallelMHASparseAttMlp
40
+ # from HybridTensor.triton.references.attention_proj_sparse import qkv_proj_sparse
41
+ # from HybridTensor.triton.select_attn import select_attn
42
+ # from HybridTensor.triton.select_attn_64b_kernel import select_attn
43
+ from HybridTensor.triton.attn_interface import flash_attn_with_kvcache_triton
44
+ from HybridTensor.triton.select_attn_v1 import select_attn
45
+ from HybridTensor.utils.utils import arg_parser, generate_BH_index, generate_random_BH_index
46
+ from HybridTensor.utils.profiling import cuda_profiler
47
+
48
+
49
class MHARouter(torch.nn.Module):
    """Predicts per-head importance scores and picks the top-k attention heads.

    A single linear layer maps the token embedding to one score per head;
    `_select_heads` keeps the highest-scoring heads.
    """

    def __init__(self, embed_dim, low_rank_dim = None, out_dim = None, top_k = 0.5, device = None, dtype = None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.model_dim = embed_dim       # input embedding width
        self.num_heads = out_dim         # one score per attention head
        self.topk = top_k                # default fraction of heads to keep
        # NOTE(review): low_rank_dim is accepted but unused here — presumably for a
        # future low-rank router variant; confirm before removing.
        self.linear1 = torch.nn.Linear(embed_dim, out_dim, bias = True, **factory_kwargs)

    def forward(self, x):
        """Return raw head scores of shape (batch, num_heads)."""
        return self.linear1(x)

    def _select_heads(self, x, topk = None):
        """Return indices of the top-k heads per row; `topk` is a fraction of num_heads."""
        fraction = self.topk if topk is None else topk
        k = int(fraction * self.num_heads)
        scores = self.forward(x)
        return torch.topk(scores, k, dim=1).indices
72
+
73
class ParallelMHARouter(torch.nn.Module):
    """Tensor-parallel variant of MHARouter.

    The scoring layer is a ColumnParallelLinear, so each rank only produces
    scores for its local shard of heads; `_select_heads` therefore selects a
    top-k among `local_heads`, not among all `num_heads`.
    """

    def __init__(self, embed_dim, low_rank_dim, out_dim, top_k, process_group, sequence_parallel=False, device = None, dtype = None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.model_dim = embed_dim
        self.num_heads = out_dim       # global head count across all ranks
        self.topk = top_k              # default fraction of local heads to keep
        world_size = torch.distributed.get_world_size(process_group)
        # Heads owned by this rank (column-parallel split of the score layer).
        self.local_heads = out_dim // world_size

        self.linear1 = ColumnParallelLinear(
            embed_dim,
            out_dim,
            process_group,
            bias=True,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        """Return this rank's head scores of shape (batch, local_heads)."""
        return self.linear1(x)

    def _select_heads(self, x, topk = None):
        """Return indices of the top-k local heads per row; `topk` is a fraction."""
        fraction = self.topk if topk is None else topk
        k = int(fraction * self.local_heads)
        scores = self.forward(x)
        return torch.topk(scores, k, dim=1).indices
107
+
108
+ def _update_kv_cache(kv, inference_params, layer_idx):
109
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
110
+ # Pre-allocate memory for key-values for inference.
111
+ num_heads, head_dim = kv.shape[-2:]
112
+ if layer_idx not in inference_params.key_value_memory_dict:
113
+ kv_cache = torch.empty(
114
+ inference_params.max_batch_size,
115
+ inference_params.max_seqlen,
116
+ 2,
117
+ num_heads,
118
+ head_dim,
119
+ dtype=kv.dtype,
120
+ device=kv.device,
121
+ )
122
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
123
+ else:
124
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
125
+ # Adjust key and value for inference
126
+ batch_start = inference_params.batch_size_offset
127
+ batch_end = batch_start + kv.shape[0]
128
+ sequence_start = inference_params.seqlen_offset
129
+ sequence_end = sequence_start + kv.shape[1]
130
+ assert batch_end <= kv_cache.shape[0]
131
+ assert sequence_end <= kv_cache.shape[1]
132
+ assert kv_cache is not None
133
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
134
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
135
+
136
+ class SMHA(nn.Module):
137
+ """Multi-head self-attention and cross-attention with Triton decode kernels + Selective Attention"""
138
+
139
    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        use_triton=True,
        device=None,
        dtype=None,
    ) -> None:
        """
        Build the selective multi-head attention module (projections, inner
        attention implementations, and optional rotary/ALiBi/dwconv extras).

        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        layer_idx: index used to key this layer's slot in the inference KV cache;
            required for generation (see _update_kv_cache).
        use_triton: flag stored on the module; presumably toggles the Triton
            decode kernels elsewhere in the class — TODO confirm.

        Raises:
            ImportError: if fused_bias_fc is requested but fused_dense is not installed.
            AssertionError: on invalid head/dim combinations or unsupported
                feature combinations (ALiBi/window without flash_attn, rotary
                with cross-attention).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        self.use_triton = use_triton
        # ALiBi slopes are only supported by the flash-attn code path.
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        # Packed projection sizes: Q uses num_heads, K and V use num_heads_kv (MQA/GQA).
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        # Pick linear / attention implementations based on the fused/flash flags.
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        # Self-attention packs Q, K, V into one projection; cross-attention
        # splits Q from the packed KV projection.
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        # Optional depthwise conv over the (packed) projections.
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
254
+
255
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
256
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
257
+ device = self.out_proj.weight.device
258
+ return torch.empty(
259
+ batch_size,
260
+ max_seqlen,
261
+ 2,
262
+ self.num_heads_kv,
263
+ self.head_dim,
264
+ dtype=dtype,
265
+ device=device,
266
+ )
267
+
268
+ def _update_kv_cache(self, kv, inference_params):
269
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
270
+ assert not self.dwconv, "Generation does not support dwconv yet"
271
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
272
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
273
+
274
    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update the kv cache,
        and apply attention — all inside a single flash_attn_with_kvcache call.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)

        Only valid during decoding (seqlen_offset > 0) on the flash-attn path.
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            # xPos scaling is not handled by the fused kernel's rotary support.
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            # Make sure cos/sin tables cover the full sequence length before the kernel runs.
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        # Pre-allocated cache slot for this layer, trimmed to the live batch.
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        # Per-sample lengths if available, otherwise a single shared offset.
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)

        # The kernel rotates q/kv, appends kv to kv_cache in place, and attends.
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context
314
+
315
+
316
def _update_kvcache_attention_triton(self, q, kv, inference_params, batch_head_idx=None):
    """
    The rotary embeddings have to be applied before calling this function. The KV cache is update here.
    q: (batch_size, seqlen_q, nheads, head_dim)
    kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
    batch_head_idx: optional per-batch head-selection indices forwarded to the
        triton kernel (Selective Head/Group Attention); None attends all heads.
    """
    if (
        inference_params.seqlen_offset == 0
        or flash_attn_with_kvcache is None
        or not self.use_flash_attn
    ):
        # Prefill (or flash-attn unavailable): update the cache in Python and
        # run regular attention over the full kv.
        # TODO: this only uses seqlen_offset and not lengths_per_sample.
        kv = self._update_kv_cache(kv, inference_params)
        return self.inner_cross_attn(q, kv)
    else:
        batch = q.shape[0]
        # Unlike the CUDA fast path, the cache is updated here in Python; the
        # triton kernel therefore receives k_new/v_new as None below.
        kv_cache = self._update_kv_cache(kv, inference_params)

        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)

        context = flash_attn_with_kvcache_triton(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            None,  # kv[:, :, 0] — already written into the cache above
            None,  # kv[:, :, 1]
            rotary_cos=None,  # rotary must be applied by the caller
            rotary_sin=None,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=False,  # self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
            batch_head_idx=batch_head_idx,
        )
        return context
357
+
358
def _update_kvcache_attention(self, q, kv, inference_params):
    """Write kv to inference_params, then do attention"""
    if (
        inference_params.seqlen_offset == 0
        or flash_attn_with_kvcache is None
        or not self.use_flash_attn
    ):
        # Prefill, or the fused kernel is unavailable: append kv in Python and
        # attend over the full cached sequence.
        # TODO: this only uses seqlen_offset and not lengths_per_sample.
        kv = self._update_kv_cache(kv, inference_params)
        return self.inner_cross_attn(q, kv)
    else:
        # Decode fast path: flash_attn_with_kvcache appends kv at
        # cache_seqlens inside the kernel and attends in a single call.
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        return flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            alibi_slopes=alibi_slopes,
        )
388
+
389
def forward(
    self,
    x,
    x_kv=None,
    key_padding_mask=None,
    cu_seqlens=None,
    max_seqlen=None,
    mixer_subset=None,
    inference_params=None,
    batch_head_idx=None,
    use_triton=True,
    **kwargs,
):
    """
    Arguments:
        x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
            cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
            is the is the sum of the sequence lengths in the batch.
        x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into x. Only applicable when using
            FlashAttention.
        max_seqlen: int. Maximum sequence length in the batch.
        key_padding_mask: boolean mask, True means to keep, False means to mask out.
            (batch, seqlen). Only applicable when not using FlashAttention.
        mixer_subset: for cross-attention only. If not None, will take a subset of x
            before applying the query projection. Useful for e.g., ViT where we only care
            about the CLS token in the last layer.
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        batch_head_idx: (batch, num_heads). The index of the heads to be selected. Only applicable for Selective Head/Group Attention.
        use_triton: whether to use triton kernels for attention in decode. If False, use the original flash attention implementation.
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
    """
    # Mutually-exclusive input formats: varlen (cu_seqlens), padded
    # (key_padding_mask), and generation (inference_params).
    if cu_seqlens is not None:
        assert max_seqlen is not None
        assert key_padding_mask is None
        assert self.use_flash_attn
        assert not self.dwconv
        assert self.rotary_emb_dim == 0
    if key_padding_mask is not None:
        assert cu_seqlens is None
        assert max_seqlen is None
        assert not self.use_flash_attn
    if inference_params is not None:
        assert key_padding_mask is None
        assert cu_seqlens is None and max_seqlen is None
        assert not self.dwconv
        # use_triton = self.use_triton if use_triton is None else use_triton
    # Route the mask/varlen arguments the inner attention implementation expects.
    kwargs = (
        {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
        if self.use_flash_attn
        else {"key_padding_mask": key_padding_mask, **kwargs}
    )
    # seqlen_offset may be a per-sample tensor (lengths_per_sample) or a scalar.
    seqlen_offset = (
        0
        if inference_params is None
        else (
            inference_params.lengths_per_sample
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
    )
    rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
    batch, seqlen = x.shape[:2]
    if not self.cross_attn and self.num_heads_kv == self.num_heads:
        # Standard self-attention (no MQA/GQA): single packed qkv projection.
        assert x_kv is None and mixer_subset is None
        if not self.return_residual:
            qkv = self.Wqkv(x)
        else:
            qkv, x = self.Wqkv(x)
        if self.dwconv:
            qkv = rearrange(
                self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
            ).contiguous()
        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
        if (
            inference_params is None
            or inference_params.seqlen_offset == 0
            or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
            or not self.use_flash_attn
        ):
            # prefill stage (also taken when the fused rotary kernel cannot be
            # used: rotary dim 0 or not a multiple of 16, or flash disabled)
            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(
                    qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )
            if inference_params is None:
                if not self.checkpointing:
                    context = self.inner_attn(qkv, **kwargs)
                else:
                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
            else:
                if use_triton:
                    # print("Using the (prefill) triton flash attention implementation")
                    context = self._update_kvcache_attention_triton(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx
                    )
                else:
                    # print("Using the (prefill) original flash attention implementation")
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
        else:
            # decode stage
            # print("Using triton kernels for attention")
            if use_triton:
                # Triton kernel expects rotary to be applied beforehand.
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                context = self._update_kvcache_attention_triton(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx
                )
            else:
                # Fused CUDA path applies rotary inside the kernel.
                # print("Using the original flash attention implementation")
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )

    else:  # cross-attention or MQA/GQA
        if self.cross_attn:
            if not self.return_residual:
                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                kv = self.Wkv(x_kv if x_kv is not None else x)
            else:
                if x_kv is not None:
                    kv, x_kv = self.Wkv(x_kv)
                else:
                    kv, x = self.Wkv(x)
                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
        else:
            # MQA/GQA: q and kv share one projection, split by head count.
            assert self.num_heads_kv != self.num_heads
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            q = qkv[..., : self.num_heads * self.head_dim]
            kv = qkv[..., self.num_heads * self.head_dim :]
        q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
        kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
        if self.dwconv:
            q = rearrange(
                self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
            ).contiguous()
            kv = rearrange(
                self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
            ).contiguous()
        if (
            inference_params is None
            or inference_params.seqlen_offset == 0
            or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
            or not self.use_flash_attn
        ):
            # prefill
            if self.rotary_emb_dim > 0:
                q, kv = self.rotary_emb(
                    q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )
            if inference_params is None:
                if not self.checkpointing:
                    context = self.inner_cross_attn(q, kv, **kwargs)
                else:
                    context = torch.utils.checkpoint.checkpoint(
                        self.inner_cross_attn, q, kv, **kwargs
                    )
            else:
                if use_triton:
                    context = self._update_kvcache_attention_triton(
                        q, kv, inference_params, batch_head_idx
                    )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
        else:
            # decode
            # print("Using triton kernels for attention")
            if use_triton:
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                context = self._update_kvcache_attention_triton(
                    q, kv, inference_params, batch_head_idx
                )
            else:
                # print("Using the original gqa flash attention implementation")
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
    # Merge heads back into the hidden dimension and project out.
    out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
    return out if not self.return_residual else (out, x)
579
+
580
class ParallelSMHA(nn.Module):
    """Multi-head self-attention and cross-attention, tensor-parallel across a
    process group: Q/K/V heads are sharded per rank via ColumnParallelLinear and
    the output is reduced via RowParallelLinear."""

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        # num_heads_kv toggles MQA/GQA; layer_idx keys the KV cache during generation.
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"

        # Head counts actually owned by this rank (handles uneven splits).
        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        # Global (unsharded) width of the packed q + k + v projection.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            # Each rank takes its contiguous slice of the global slope table.
            num_heads_local = math.ceil(self.num_heads / self.world_size)
            alibi_slopes = torch.tensor(
                get_alibi_slopes(num_heads)[
                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
                ],
                device=device,
            )
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            # Keep whole (q-group + k + v) units together on one rank.
            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
            **factory_kwargs,
        )
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate this rank's KV-cache buffer:
        (batch, max_seqlen, 2, num_heads_kv_per_rank, head_dim)."""
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv_per_rank,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        # Delegates to the module-level _update_kv_cache helper.
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            # Extend the cos/sin tables to cover the full generation length.
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        # Per-sample lengths take precedence over the shared scalar offset.
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        # Fused kernel: rotary + in-kernel cache append + attention.
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
            # Prefill or non-flash fallback: append kv in Python, attend over all of it.
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            # Decode fast path: the kernel appends kv and attends in one call.
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            context = flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )
            return context

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            # Restore the (batch, seqlen, ...) layout flattened for sequence parallel.
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        # seqlen_offset may be a per-sample tensor (lengths_per_sample) or a scalar.
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        if self.num_heads_kv == self.num_heads:
            # Standard MHA: packed (q, k, v) per head.
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                # Prefill, or the fused rotary-kvcache kernel is not usable.
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                # Decode fast path with fused rotary + kvcache attention.
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:  # GQA/MQA
            # Split this rank's projection into q heads and kv heads.
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            # Flatten back for the sequence-parallel row projection.
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out
867
+
868
class SelectMHA(nn.Module):
    """Multi-head, Group-query self-attention using select attention.

    During decode, a ``batch_head_idx`` tensor can restrict attention to a
    subset of heads per batch element (Selective Head Attention) via the
    ``select_attn`` kernel; prefill uses regular (flash) self-attention.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=True,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        # Flash attention is force-enabled for this class regardless of the
        # use_flash_attn argument (the select-attention decode path requires it).
        self.use_flash_attn = True  # use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        # FIX: _update_kvcache_attention falls back to self.inner_cross_attn when
        # flash_attn_with_kvcache is unavailable, but this attribute was never
        # constructed, raising AttributeError on that path. Construct it exactly
        # as the sibling MHA/ParallelSMHA classes do.
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )

        self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)

        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        # NOTE(review): may be None if the caller did not pass a scale; the
        # select_attn kernel receives it as-is — confirm it handles None.
        self.softmax_scale = softmax_scale
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate the KV-cache buffer: (batch, max_seqlen, 2, num_heads_kv, head_dim)."""
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """Update kv cache in inference_params."""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        # Delegates to the module-level _update_kv_cache helper.
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        batch_head_idx=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim)
            batch_head_idx: Tensor of indices specifying which batch and head indices to select.
                Shape: (batch_size, top_k)
            inference_params: for generation.
        """
        # seqlen_offset may be a per-sample tensor (lengths_per_sample) or a scalar.
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        batch, seqlen = x.shape[:2]

        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            # Self-attention, no MQA/GQA
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)

            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(
                    qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )

            if inference_params is None or inference_params.seqlen_offset == 0:
                # Training / prefill stage.
                if inference_params is not None:
                    # Update kv cache during prefill
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)

                context = self.inner_attn(qkv, **kwargs)

            else:
                # Generation (decode) stage — both branches update the kv cache.
                if batch_head_idx is None:
                    # Dense attention over all heads.
                    context = self._update_kvcache_attention(
                        q=qkv[:, :, 0], kv=qkv[:, :, 1:], inference_params=inference_params
                    )
                else:
                    # Selective head attention over the chosen heads.
                    context = self._update_kvcache_select_attn(
                        q=qkv[:, :, 0],
                        kv=qkv[:, :, 1:],
                        inference_params=inference_params,
                        batch_head_idx=batch_head_idx,
                    )

        else:
            raise NotImplementedError("SelectMHA currently supports only self-attention without MQA/GQA.")

        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)

    def _update_kvcache_select_attn(self, q, kv, inference_params, batch_head_idx):
        """
        Apply select attention during generation stage.

        q: (batch_size, seqlen=1, n_heads, head_dim)
        kv: (batch_size, seqlen=1, 2, n_heads, head_dim)
        batch_head_idx: Tensor of indices specifying which batch and head indices to select.
            Shape: (batch_size, top_k)

        # currently only supports batches with same seqlen
        # different seqlen requires a simple update in the select_attn kernel to load the seqlen, future work
        """
        assert batch_head_idx is not None, "batch_head_idx must not be None"

        # Append the current token's kv to the cache (Python-side update; the
        # select_attn kernel reads the cache only).
        kv_cache = self._update_kv_cache(kv, inference_params)
        # inference_params.seqlen_offset += 1  # if seqlen_offset is int

        batch = q.shape[0]

        # make sure seqlen_offset accounts for the current token
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset + 1  # +1 for the current token
        )

        # select_attn expects a group axis: (batch_size, seqlen, G=1, n_heads, head_dim).
        q = q.unsqueeze(2)
        k_cache = kv_cache[:, :, 0].unsqueeze(2)
        v_cache = kv_cache[:, :, 1].unsqueeze(2)

        context = select_attn(
            q,
            k_cache,
            v_cache,
            self.softmax_scale,
            batch_head_idx,
            cache_seqlens)

        # context: (batch_size, seqlen_q=1, G=1, H, head_dim) -> drop the G axis.
        batch_size = batch_head_idx.shape[0]
        context = context.view(batch_size, 1, self.num_heads, self.head_dim)

        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # Prefill or non-flash fallback: append kv in Python and run
            # regular cross-attention (q against the full cached kv).
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            # Decode fast path: the kernel appends kv and attends in one call.
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            # alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            alibi_slopes = None
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_attn.softmax_scale,
                causal=self.inner_attn.causal,
                alibi_slopes=alibi_slopes,
            )

    # dummy function for testing
    def _select_attn(self, q, kv, inference_params, batch_head_idx):
        """
        Apply select attention during generation stage (testing helper: reads
        the cache but does NOT update it).

        q: (batch_size, seqlen=1, n_heads, head_dim)
        kv: (batch_size, seqlen=1, 2, n_heads, head_dim)
        batch_head_idx: Tensor of indices specifying which batch and head indices to select.
            Shape: (N_selected, 2)

        # currently only supports batches with same seqlen
        # different seqlen requires a simple update in the select_attn kernel to load the seqlen, future work
        """
        assert batch_head_idx.shape[1] == 2, "batch_head_idx must have shape (N_selected, 2)"

        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]

        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )

        # select_attn expects a group axis: (batch_size, seqlen, G=1, n_heads, head_dim).
        q = q.unsqueeze(2)
        k_cache = kv_cache[:, :, 0].unsqueeze(2)
        v_cache = kv_cache[:, :, 1].unsqueeze(2)

        context = select_attn(
            q,
            k_cache,
            v_cache,
            self.softmax_scale,
            batch_head_idx,
            cache_seqlens)

        # context: (batch_size, seqlen_q=1, G=1, H, head_dim)
        context = context.squeeze(2)  # Remove G dimension

        return context
1184
+
1185
+
1186
+ # SelectiveGQA: Future work
1187
+
1188
class ParallelSelectMHA(nn.Module):
    """Tensor-parallel multi-head attention with selective (per-head) decoding.

    Prefill uses dense flash attention over all heads; during decode only the
    (batch, head) pairs named by ``batch_head_idx`` are attended, via the
    custom ``select_attn`` kernel. Heads are sharded across ranks through
    Column/RowParallelLinear projections.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=True,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=True,
        use_flash_attn=True,
        return_residual=False,
        checkpointing=False,
        sequence_parallel=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"

        # Per-rank head counts: heads are partitioned across the process group.
        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            # Each rank keeps only its slice of the global ALiBi slope table.
            num_heads_local = math.ceil(self.num_heads / self.world_size)
            alibi_slopes = torch.tensor(
                get_alibi_slopes(num_heads)[
                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
                ],
                device=device,
            )
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
            **factory_kwargs,
        )

        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )

        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.softmax_scale = softmax_scale

        # replace this with no reduce
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

        # Routers are attached externally after construction (None until then).
        self.mha_router = None
        self.mlp_router = None
        # We'll use an extra stream for concurrency
        self.current_stream = None
        self.sparse_stream = torch.cuda.Stream(device="cuda", priority=0)
        self.main_stream = torch.cuda.Stream(device="cuda", priority=-5)
        self.mha_router_event = torch.cuda.Event(enable_timing=False, blocking=False)
        self.mlp_router_event = torch.cuda.Event(enable_timing=False, blocking=False)
        self.main_event = torch.cuda.Event(enable_timing=False, blocking=False)

        # self.local_head_idx = generate_random_BH_index(1, self.num_heads_per_rank,self.num_heads_per_rank)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate this rank's kv cache:
        (batch, max_seqlen, 2, num_heads_kv_per_rank, head_dim), on the
        out_proj weight's device/dtype by default."""
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv_per_rank,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """Update kv cache in inference_params."""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def forward(
        self,
        x,
        seqlen=None,
        inference_params=None,
        batch_head_idx=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim)
            batch_head_idx: Tensor of indices specifying which batch and head indices to select.
                Shape: (N_selected,)
            inference_params: for generation.
        """

        router_inputs = x.squeeze(1)
        # Fork the helper streams off the current stream so kernels queued on
        # them later observe all work already enqueued here.
        self.current_stream = torch.cuda.current_stream()
        self.main_stream.wait_stream(self.current_stream)
        self.sparse_stream.wait_stream(self.current_stream)

        is_decode = inference_params is not None and inference_params.seqlen_offset > 0

        # (disabled) router overlap experiment:
        # if self.mha_router and is_decode:
        #     with torch.cuda.stream(self.sparse_stream):
        #         batch_head_idx = self.mha_router._select_heads(router_inputs)
        #         self.sparse_stream.record_event(self.mha_router_event)

        qkv = self.Wqkv(x)
        if seqlen is not None:
            # Inputs arrived flattened as (b*s, ...); restore the batch axis.
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)

        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        batch, seqlen = x.shape[:2]

        if self.num_heads_kv == self.num_heads:
            # Self-attention, no MQA/GQA
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)

            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(
                    qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )
            if inference_params is None or inference_params.seqlen_offset == 0:
                # Inference stage without inference_params, prefill stage
                if inference_params is not None:
                    # Update kv cache during prefill
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)

                context = self.inner_attn(qkv, **kwargs)
            else:
                # Generation stage

                # apply rotary embeddings
                # NOTE(review): rotary was already applied above, so the decode
                # path applies it a second time here — confirm intended (this is
                # a no-op only when rotary_emb_dim == 0, as in the benchmark).
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )

                # Apply select attention with kv cache update
                context = self._update_kvcache_select_attn(qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx)
        else:  # cross-attention, MQA/GQA
            raise NotImplementedError("SelectMHA currently supports only self-attention without MQA/GQA.")

        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        # out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        # out = self.out_proj(context)

        # Bias-fused local matmul; the all_reduce below completes the
        # row-parallel output projection across ranks.
        out = fused_dense_func(context, self.out_proj.weight, self.out_proj.bias)

        # (disabled) experiment overlapping the MLP router with the all_reduce
        # on separate streams; kept in git history.
        out = all_reduce(out, self.process_group)
        return out if not self.return_residual else (out, x)
        # return out

    def _update_kvcache_select_attn(self, q, kv, inference_params, batch_head_idx=None):
        """
        Apply select attention during generation stage.

        q: (batch_size, seqlen=1, n_heads, head_dim)
        kv: (batch_size, seqlen=1, 2, n_heads, head_dim)
        batch_head_idx: Tensor of indices specifying which batch and head indices to select.
            Shape: (batch_size, top_k)
        """
        # update kv cache; the returned view already contains the token just written
        kv_cache = self._update_kv_cache(kv, inference_params)

        batch = q.shape[0]
        # kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset + 1  # +1 for the current token
        )
        # need to reshape or view keys and value with shape (batch_size, seqlen, 1, n_heads, head_dim)
        q = q.unsqueeze(2)
        k_cache = kv_cache[:, :, 0].unsqueeze(2)
        v_cache = kv_cache[:, :, 1].unsqueeze(2)

        # Wait until the router (running on sparse_stream) has produced
        # batch_head_idx before launching the select kernel.
        self.current_stream.wait_event(self.mha_router_event)

        assert batch_head_idx is not None, "batch_head_idx must not be None"
        # Call select_attn
        context = select_attn(
            q,
            k_cache,
            v_cache,
            self.softmax_scale,
            batch_head_idx,
            cache_seqlens
        )

        # context: (batch_size, seqlen_q=1, G=1, H, head_dim)
        # context = context.squeeze(2)  # Remove G dimension
        # NOTE(review): views with num_heads_kv_per_rank; identical to
        # num_heads_per_rank here since MQA/GQA is rejected in forward().
        context = context.view(batch, 1, self.num_heads_kv_per_rank, self.head_dim)
        return context
1498
+
1499
+ '''
1500
+ PYTHONWARNINGS="ignore" python -m HybridTensor.modules.SelectiveMHA --batch_size 8 --in_features 8192 --seq_len 512 --head_density 0.25
1501
+ '''
1502
+
1503
+
1504
if __name__ == "__main__":
    # Micro-benchmark: compare one decode step of standard MHA vs SelectMHA
    # on synthetic data with a pre-populated kv cache.
    args = arg_parser()

    max_seqlen = args.seq_len + 128  # headroom beyond the prefill length
    max_batch_size = args.batch_size
    device = torch.device(f"cuda:{args.device}")

    # simulates SelectiveMHA inference generation stage
    inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size)
    nheads = args.in_features // 128  # head_dim is fixed at 128 throughout
    softmax_scale = 1 / (128 ** 0.5)
    rotary_emb_dim = 0

    # build SelectiveMHA
    select_mha = SelectMHA(
        embed_dim=args.in_features,
        num_heads=nheads,
        num_heads_kv=None,
        causal=True,
        layer_idx=0,
        use_flash_attn=True,
        softmax_scale=softmax_scale,
        return_residual=False,
        rotary_emb_dim=rotary_emb_dim,
        device=device,
        dtype=torch.float16,
    )

    # Dense baseline with identical configuration.
    standard_mha = SMHA(
        embed_dim=args.in_features,
        num_heads=nheads,
        num_heads_kv=None,
        causal=True,
        layer_idx=0,
        use_flash_attn=True,
        softmax_scale=softmax_scale,
        return_residual=False,
        rotary_emb_dim=rotary_emb_dim,
        device=device,
        dtype=torch.float16,
    )
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()

    with torch.no_grad():
        # prefill stage to generate kv cache for all batches
        og_x = torch.randn(args.batch_size, args.seq_len, args.in_features, device=device, dtype=torch.float16, requires_grad=False)

        # out, time_ms = cuda_profiler(select_mha, og_x, inference_params=inference_params)
        # print(f"MHA Prefill time: {time_ms:.3f} ms")
        # out = select_mha(og_x, inference_params=inference_params)

        # simulate kv cache, bug in flash_attn for larger batches
        kv = torch.randn(args.batch_size, args.seq_len, 2, nheads, 128, device=device, dtype=torch.float16, requires_grad=False)
        _ = _update_kv_cache(kv, inference_params, 0)

        # increment the sequence length to move to the generation stage
        inference_params.seqlen_offset += args.seq_len

        input_x = torch.randn(args.batch_size, 1, args.in_features, device=device, dtype=torch.float16, requires_grad=False)
        # Number of heads each batch element attends to, per the density knob.
        selected_heads = math.ceil(nheads * args.head_density)

        # generate batch_head_idx for SelectiveMHA
        # batch_head_index = generate_BH_index(args.batch_size, nheads, selected_heads, device=device)
        batch_head_index = generate_random_BH_index(args.batch_size, nheads, selected_heads, device=device)
        # generation stage: standard MHA baseline
        out, standard_time_ms = cuda_profiler(standard_mha, input_x, inference_params=inference_params)
        print(f"Standard MHA time: {standard_time_ms:.3f} ms")

        # generation stage: SelectiveMHA
        out, select_time_ms = cuda_profiler(select_mha, input_x, inference_params=inference_params, batch_head_idx=batch_head_index)
        print(f"SelectMHA time: {select_time_ms:.3f} ms")

        speedup = standard_time_ms / select_time_ms
        print(f"Speedup: {speedup:.3f}")
1579
+
HybridTensor/modules/SelectiveMLP.py ADDED
@@ -0,0 +1,580 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python -m HybridTensor.modules.SelectiveMLP --batch_size 8 --index_size 512
2
+ from typing import Optional
3
+ from functools import partial
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+ from torch.distributed import ProcessGroup
10
+ import torch.distributed as dist
11
+
12
+ # import fused_dense_cuda # from apex
13
+
14
+ import fused_dense_lib as fused_dense_cuda
15
+
16
+ from flash_attn.utils.distributed import reduce_scatter, all_reduce
17
+ from einops import rearrange
18
+
19
+ # from HybridTensor.modules.MLP import SelectiveMLPFunc
20
+ from HybridTensor.modules.references.fused_dense import ColumnParallelLinear, RowParallelLinear, fused_mlp_func
21
+ from HybridTensor.modules.references.MLP import SelectiveMLPTriton
22
+ from HybridTensor.utils.utils import arg_parser, sparse_index
23
+ from HybridTensor.utils.profiling import cuda_profiler
24
+
25
+ # compiles the kernels for the first time, takes time
26
+ from HybridTensor.triton.gather_gemm_col import gather_matmul_col
27
+ from HybridTensor.triton.gather_gemm_row import gather_matmul_row
28
+
29
+ # needs to be compiled before running
30
+ from HybridTensor.triton.heuristics.gather_gemm_col_h import gather_matmul_col as gather_matmul_col_h
31
+ from HybridTensor.triton.heuristics.gather_gemm_row_h import gather_matmul_row as gather_matmul_row_h
32
+
33
+ # from HybridTensor.triton.cg_safe.gather_gemm_col_cg import gather_matmul_col
34
+ # from HybridTensor.triton.cg_safe.gather_gemm_row_cg import gather_matmul_row
35
+
36
+
37
def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, bias1=None, bias2=None, activation='relu', use_heuristic=True):
    """Sparse two-layer MLP over the hidden neurons named by ``index_vec``.

    Gathers the selected fc1 columns (fused with the activation), then gathers
    the matching fc2 rows. ``use_heuristic`` selects the heuristic-tuned
    kernels over the plain ones.
    """
    # Pick the kernel pair once, then run the col-gather -> row-gather pipeline.
    col_mm = gather_matmul_col_h if use_heuristic else gather_matmul_col
    row_mm = gather_matmul_row_h if use_heuristic else gather_matmul_row
    hidden = col_mm(x, fc1_w, index_vec, bias=bias1, activations=activation)
    return row_mm(hidden, fc2_w, index_vec, bias=bias2)
45
+
46
+
47
+ # cg safe version
48
+ # def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, index_size, bias1 = None, bias2 = None, activation='relu', use_heuristic=True):
49
+ # out = gather_matmul_col(x, fc1_w, index_vec, index_size, bias = bias1, activations=activation)
50
+ # out = gather_matmul_row(out, fc2_w, index_vec, index_size, bias = bias2)
51
+ # return out
52
+
53
class MLPRouter(nn.Module):
    """Low-rank router that predicts which MLP hidden neurons will activate.

    Scores each of ``out_dim`` neurons through a rank-``low_rank_dim``
    bottleneck (two bias-free linear layers); selection helpers then pick
    neurons by top-k or by thresholding.
    """

    def __init__(self, embed_dim, low_rank_dim, out_dim, act_th, device=None, dtype=None):
        """
        Initializes the MLPRouter class.

        Args:
            embed_dim (int): Dimensionality of the input embeddings.
            low_rank_dim (int): Dimensionality of the intermediate layer.
            out_dim (int): Number of neurons.
            act_th: Activation threshold used by the threshold-based selectors.
        """
        super(MLPRouter, self).__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.fc1 = nn.Linear(embed_dim, low_rank_dim, bias=False, **factory_kwargs)
        self.fc2 = nn.Linear(low_rank_dim, out_dim, bias=False, **factory_kwargs)
        self.act_th = act_th
        self.num_neurons = out_dim
        # Sentinel index strictly larger than any valid neuron id; used to push
        # unselected slots to the end after sorting (see _select_neurons_cuda_safe).
        self.largest = self.num_neurons + 1

    def forward(self, x):
        """
        Forward pass of the MLPRouter.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, embed_dim).

        Returns:
            torch.Tensor: Neuron scores of shape (batch_size, num_neurons).
        """
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def _select_neurons_topk(self, x, topk=None):
        """Pick the ``topk`` neurons with the largest batch-summed ReLU scores.
        Indices are returned unsorted (CUDA-graph friendly fixed size)."""
        neurons = self.forward(x)

        neurons_nonzero = torch.nn.ReLU()(neurons)
        _, index_vec = neurons_nonzero.sum(dim=0).topk(topk, dim=0, sorted=False)
        # index_vec, _ = index_vec.sort()
        return index_vec

    def _select_neurons(self, x, th=None):
        '''
        Threshold based selection of neurons, not CG safe
        (output size depends on the data, so it cannot be CUDA-graph captured).
        '''
        if th is None:
            th = self.act_th

        neurons = self.forward(x)
        # A neuron is kept if its score exceeds th for at least one batch element.
        activated = (neurons > th).sum(dim=0)
        index_vec = activated.nonzero().flatten()
        return index_vec

    def _select_neurons_cuda_safe(self, x, th=None):
        '''
        This function is used with threshold and is used for CG safe version of the code:
        it always returns a fixed-size, sorted index vector (unselected slots hold
        the `largest` sentinel) plus the count of valid entries.
        '''
        if th is None:
            th = self.act_th
        neurons = self.forward(x)
        activated = (neurons > th).sum(dim=0)

        indices = torch.arange(self.num_neurons, device=activated.device)
        # NOTE(review): this compares the per-neuron activation *count* against
        # the activation threshold th, whereas _select_neurons keeps any neuron
        # with count > 0 — confirm the intended semantics.
        selected = torch.where(activated > th, indices, torch.full_like(indices, self.largest))

        # Sort so valid indices come first; index_size counts the valid prefix.
        index_vec, _ = torch.sort(selected)
        index_size = ((index_vec < self.largest).sum()).to(torch.int32)

        return index_size, index_vec
121
+
122
+
123
+
124
class ParallelMLPRouter(nn.Module):
    """
    Parallel Sparse Predictor for the MLP layer: tensor-parallel variant of
    MLPRouter. fc1 is replicated on every rank; fc2 is column-parallel, so each
    rank scores (and selects among) only its own shard of the neurons.
    """

    def __init__(
        self,
        embed_dim,
        low_rank_dim,
        out_dim,
        act_th,
        process_group,
        sequence_parallel=False,
        device=None,
        dtype=None,
    ):
        """
        Initializes the ParallelMLPRouter class.

        Args:
            embed_dim (int): Dimensionality of the input embeddings.
            low_rank_dim (int): Dimensionality of the intermediate layer.
            out_dim (int): Output dimensionality (typically number of neurons).
            act_th: Activation threshold for the threshold-based selector.
            process_group (torch.distributed.ProcessGroup): Process group for parallelism.
            sequence_parallel (bool, optional): Whether to use sequence parallelism. Defaults to False.
            device (torch.device, optional): Device to run the module on. Defaults to None.
            dtype (torch.dtype, optional): Data type of the module parameters. Defaults to None.
        """
        super(ParallelMLPRouter, self).__init__()
        assert process_group is not None, "ParallelMHARouter requires a process group."

        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.embed_dim = embed_dim
        self.act_th = act_th

        # fc1 is replicated; fc2 shards the neuron dimension across ranks.
        self.fc1 = nn.Linear(
            embed_dim, low_rank_dim, bias=False, **factory_kwargs
        )
        self.fc2 = ColumnParallelLinear(
            low_rank_dim,
            out_dim,
            process_group,
            bias=False,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        """
        Forward pass of the ParallelMLPRouter.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, embed_dim).

        Returns:
            torch.Tensor: This rank's shard of neuron scores,
                shape (batch_size, seq_len, out_dim // world_size).
        """
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def _select_neurons(self, x, th=None):
        """Threshold-based selection over this rank's neuron shard
        (variable-size output; not CUDA-graph safe)."""
        if th is None:
            th = self.act_th

        neurons = self.forward(x)
        # Keep a neuron if its score exceeds th for at least one batch element.
        activated = (neurons > th).sum(dim=0)
        index_vec = activated.nonzero().flatten()
        return index_vec

    def _select_neurons_topk(self, x, topk=None):
        """Top-k selection (by batch-summed ReLU score) over this rank's shard.
        Indices are local to the shard and returned unsorted."""
        neurons = self.forward(x)

        neurons_nonzero = torch.nn.ReLU()(neurons)  # .squeeze(1)
        # print(f"neurons_nonzero shape: {neurons_nonzero.shape}")
        # print(f"Top k neurons: {topk}")
        _, index_vec = neurons_nonzero.sum(dim=0).topk(topk, dim=0, sorted=False)
        # index_vec, _ = index_vec.sort()
        return index_vec
211
+
212
class SelectiveMLP(nn.Module):
    """Two-layer MLP that can run densely or sparsely over a subset of hidden
    neurons selected by ``index_vec`` (project gather-GEMM kernels).

    On the first sparse call, fc2's weight is cached transposed
    (``fc2_weight_t``, row-major over hidden neurons) and the original
    parameter is freed to save memory; subsequent dense calls must therefore
    use the transposed copy.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation='relu',
        layer_idx=None,
        bias1=True,
        bias2=True,
        return_residual=False,
        checkpoint_lvl=0,
        use_heuristic=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.activation_fn = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
        # Transposed fc2 weight, built lazily (see _init_weights / forward).
        self.fc2_weight_t = None
        self.use_heuristic = use_heuristic

    def _init_weights(self):
        # if weights are updated, we need to update the transpose
        self.fc2_weight_t = self.fc2.weight.t().contiguous()

    def forward(self, x, index_vec=None, index_size=None):
        """Run the MLP.

        Args:
            x: input tensor; flattened to 2-D on the sparse path.
            index_vec: 1-D tensor of active hidden-neuron indices for the
                sparse path, or None for a dense forward.
            index_size: unused here (kept for CUDA-graph-safe kernel variants).
        """
        if index_vec is not None:
            # sparse forward

            # On the first sparse run, cache the transposed weight and free the
            # original fc2 parameter to save memory.
            if self.fc2_weight_t is None:
                self.fc2_weight_t = self.fc2.weight.t().contiguous()
                self.fc2.weight = None
                del self.fc2._parameters['weight']

            x = x.view(-1, x.size(-1))
            y = SelectiveMLPFunc(x=x, fc1_w=self.fc1.weight,
                                 fc2_w=self.fc2_weight_t, index_vec=index_vec,
                                 bias1=self.fc1.bias, bias2=self.fc2.bias,
                                 activation=self.activation, use_heuristic=self.use_heuristic)

        else:
            # dense forward
            y = self.fc1(x)
            y = self.activation_fn(y)

            if self.fc2_weight_t is not None:
                # fc2.weight may have been freed by the sparse path, so use the
                # transposed copy. BUGFIX: fc2's bias was previously dropped on
                # this branch, making dense results differ after _init_weights()
                # or any sparse call.
                y = torch.matmul(y, self.fc2_weight_t)
                if self.fc2.bias is not None:
                    y = y + self.fc2.bias
            else:
                y = self.fc2(y)

        return y if not self.return_residual else (y, x)
277
+
278
class ParallelSelectiveMLP(nn.Module):
    """Tensor-parallel MLP with an optional sparse (neuron-selective) decode
    path driven by ``index_vec``. Dense path uses the fused MLP kernel; both
    paths finish with an all-reduce (or reduce-scatter) across the group.
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        out_features=None,
        activation="relu",
        layer_idx=None,
        process_group: ProcessGroup = None,
        bias1=True,
        bias2=True,
        return_residual=False,
        sequence_parallel=False,
        use_heuristic=True,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        """
        process_group is required. We're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu"]
        assert process_group is not None
        # assert sp_kwargs != None, "sparse predictor parameters are not passed in."
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.activation = activation
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, process_group, bias=bias1, **factory_kwargs
        )
        self.fc2 = RowParallelLinear(
            hidden_features, out_features, process_group, bias=bias2, **factory_kwargs
        )
        self.layer_idx = layer_idx

        # NOTE(review): register_buffer returns None (and the buffer name has a
        # typo, "fc2_weigth_t"); this assignment is effectively a no-op and is
        # overwritten two lines below — confirm the buffer was intended.
        self.fc2_weight_t = self.register_buffer("fc2_weigth_t", None)
        self.return_residual = return_residual
        self.fc2_weight_t = None
        self.use_heuristic = use_heuristic
        self.reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce

        # self._init_weights()

    def _init_weights(self):
        # ffn2 weights needs to be in row major format to select from rows
        self.fc2_weight_t = self.fc2.weight.t().contiguous()

    def forward(self, x, residual=None, index_vec=None):
        """Dense forward when index_vec is None; otherwise sparse gather-GEMM
        over the selected hidden neurons. Ends with an all_reduce."""

        # do_token_generation = x.size(1) == 1
        # index_vec = None
        # with torch.cuda.stream(self.curr_stream):
        if index_vec is not None:
            # assert x.size(1) == 1
            # Lazily cache the transposed fc2 weight for row gathering.
            if self.fc2_weight_t is None:
                self.fc2_weight_t = self.fc2.weight.t().contiguous()

            x = x.view(-1, x.size(-1))
            # x = rearrange(x, "b 1 d -> b d")  # slightly more expensive to use rearrange

            out = SelectiveMLPFunc(x=x, fc1_w=self.fc1.weight,
                                   fc2_w=self.fc2_weight_t, index_vec=index_vec,
                                   bias1=self.fc1.bias, bias2=self.fc2.bias,
                                   activation=self.activation, use_heuristic=self.use_heuristic)
            # out = rearrange(out, "b d -> b 1 d")
            # out = out.view(-1, 1, out.size(-1))

        else:  # normal mlp
            if self.heuristic == "auto":
                # Pick the fused-kernel heuristic from dtype and CUDA version
                # (see the class docstring).
                dtype = (
                    x.dtype
                    if not torch.is_autocast_enabled()
                    else torch.get_autocast_gpu_dtype()
                )
                if self.activation == "gelu_approx":
                    cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                    heuristic = (
                        0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
                    )
                else:
                    heuristic = 0
            else:
                heuristic = self.heuristic
            out = fused_mlp_func(
                x,
                self.fc1.weight,
                self.fc2.weight,
                self.fc1.bias,
                self.fc2.bias,
                activation=self.activation,
                save_pre_act=self.training,
                checkpoint_lvl=self.checkpoint_lvl,
                heuristic=heuristic,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
            )

        if self.process_group.size() > 1:
            # out = self.reduce_fn(out, self.process_group)  # has some overhead,
            dist.all_reduce(out, op=dist.ReduceOp.SUM, group=self.process_group)

        return out if not self.return_residual else (out, x)
        # return out

    def sp_forward(self, x, residual=None, index_vec=None):
        """Variant of forward that overlaps the sparse-predictor (router) with
        the cross-rank reduction on a side stream.

        NOTE(review): relies on self.sp_router, self.sp, self.sp_stream,
        self.event_mlp and self.event_mlp_sp, none of which are set in
        __init__ — confirm they are attached externally before use.
        """
        if self.heuristic == "auto":
            dtype = (
                x.dtype
                if not torch.is_autocast_enabled()
                else torch.get_autocast_gpu_dtype()
            )
            if self.activation == "gelu_approx":
                cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                heuristic = (
                    0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
                )
            else:
                heuristic = 0
        else:
            heuristic = self.heuristic
        curr_stream = torch.cuda.current_stream()
        do_token_generation = x.size(1) == 1
        # mlp_logit = None

        # with torch.cuda.stream(self.curr_stream):
        if index_vec != None:
            assert x.size(1) == 1

            if self.fc2_weight_t is None:
                self.fc2_weight_t = self.fc2.weight.t().contiguous()

            out = SelectiveMLPFunc(
                rearrange(x, "b 1 d -> b d"),
                self.fc1.weight,
                self.fc2_weight_t,
                index_vec,
                self.fc1.bias,
                self.fc2.bias,
                activation=self.activation,
            )
            out = rearrange(out, "b d -> b 1 d")
        else:
            out = fused_mlp_func(
                x,
                self.fc1.weight,
                self.fc2.weight,
                self.fc1.bias,
                self.fc2.bias,
                activation=self.activation,
                save_pre_act=self.training,
                checkpoint_lvl=self.checkpoint_lvl,
                heuristic=heuristic,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
            )

        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        if self.sp_router:
            # Mark the point the router stream must wait for before it can
            # safely read the residual.
            curr_stream.record_event(self.event_mlp)

        # handle = torch.distributed.all_reduce(out, op=torch.distributed.ReduceOp.SUM, group=self.process_group, async_op=True)
        out = reduce_fn(out, self.process_group)

        if self.sp_router:
            # Run the sparse predictor on a side stream, overlapped with the
            # reduction above.
            with torch.cuda.stream(self.sp_stream):
                self.sp_stream.wait_event(self.event_mlp)
                if do_token_generation:
                    mlp_logit = self.sp(rearrange(residual, "b 1 d -> b d"))
                self.sp_stream.record_event(self.event_mlp_sp)

            # check this again, we might not have to synchronize here, we can synchronize in the next layer
            curr_stream.wait_event(self.event_mlp_sp)

        return out
475
+
476
class SimpleMLP(nn.Module):
    """Minimal two-layer MLP: fc1 -> activation -> fc2.

    Fix: ``forward`` previously hard-coded ``F.relu`` and silently ignored the
    ``activation`` constructor argument. It now applies the configured
    activation; the default ("relu") keeps the original behavior.

    Args:
        in_features: input dimension.
        hidden_features: hidden dimension of fc1.
        out_features: output dimension of fc2.
        bias: whether the linear layers carry biases.
        activation: "relu" (default) or "gelu".
    """

    def __init__(self, in_features, hidden_features, out_features, bias=False, activation="relu"):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.activation = activation

    def forward(self, x):
        # Respect the configured activation instead of unconditionally using ReLU.
        h = self.fc1(x)
        if self.activation == "gelu":
            h = F.gelu(h)
        else:  # default: "relu" (matches the original hard-coded behavior)
            h = F.relu(h)
        return self.fc2(h)
488
if __name__ == "__main__":
    args = arg_parser()

    # Fix: `True if args.bias > 0 else False` was redundant — the comparison
    # already yields a bool.
    bias = args.bias > 0
    x = torch.randn(args.batch_size, args.in_features, device="cuda", dtype=torch.float16)
    # Hidden dimension is 4x the model dimension throughout this benchmark.
    index_vec, _ = sparse_index(args.index_size, args.in_features * 4)

    '''
    selective_mlp = SelectiveMLPTriton(args.in_features, args.hidden_features, bias=bias, device="cuda", dtype=torch.float16, activation="relu")

    out, mlp_time = cuda_profiler(selective_mlp, x, index_vec)

    out_col, col_time = cuda_profiler(gather_matmul_col, x, selective_mlp.fc1_w, index_vec, activations=selective_mlp.activation)
    out_row, row_time = cuda_profiler(gather_matmul_row, out_col, selective_mlp.fc2_w, index_vec)
    sum_time = col_time + row_time

    print(f"Index size {args.index_size}, Activated {args.index_size/(args.in_features * 4)*100}% neurons")

    print(f"Gather Col Time: {col_time} ms")
    print(f"Gather Row Time: {row_time} ms")
    # print(f"Sum Time: {sum_time} ms")

    print(f"SelectiveMLP Time: {mlp_time} ms")
    '''

    in_features = args.in_features
    hidden_features = in_features * 4
    out_features = in_features
    device = torch.device("cuda")

    model = SelectiveMLP(
        in_features, hidden_features, out_features, device=device, dtype=torch.float16, activation="relu", use_heuristic=True
    ).to(device)

    router = MLPRouter(in_features, 1024, hidden_features, act_th=0.5, device=device, dtype=torch.float16).to(device)

    # Warm-up GPU so timed runs are not skewed by lazy kernel compilation /
    # first-touch allocations.
    def warmup():
        for _ in range(10):
            _ = model(x, index_vec)
            _ = model(x, None)
            _ = router._select_neurons_topk(x, args.index_size)

    warmup()

    # Measure SelectiveMLPFunc speed
    _, router_time = cuda_profiler(router._select_neurons_topk, x, args.index_size)
    _, selective_time = cuda_profiler(model, x, index_vec)
    # Measure dense forward speed
    _, dense_time = cuda_profiler(model, x, None)

    print(f"Router time per run: {router_time:.6f} ms")
    print(f"SelectiveMLPFunc time per run: {selective_time:.6f} ms")
    print(f"Dense forward time per run: {dense_time:.6f} ms")
    print(f"Speedup: {dense_time / selective_time:.2f}x")
    router_selective_time = router_time + selective_time
    print(f"Router + SelectiveMLPFunc time per run: {router_selective_time:.6f} ms")
    print(f"Speedup: {dense_time / router_selective_time:.2f}x")
    ############################################
    # CUDA Graph capture tests for the MLP model
    ############################################
    print("\n=== CUDA Graph Tests ===")
    # --- Selective forward (sparse mode) ---
    print("Testing CUDA Graph for Selective forward (with index_vec)...")
    static_x = x.clone()
    static_index_vec = index_vec.clone()
    # Warm-up run to allocate memory before capture (graph capture cannot
    # allocate from the default caching-allocator path).
    static_out_sel = model(static_x, index_vec=static_index_vec)
    torch.cuda.synchronize()

    # Capture on a non-default stream, as required by CUDA graph capture.
    capture_stream = torch.cuda.Stream()
    with torch.cuda.stream(capture_stream):
        g_sel = torch.cuda.CUDAGraph()
        g_sel.capture_begin()
        static_out_sel = model(static_x, index_vec=static_index_vec)
        g_sel.capture_end()
    torch.cuda.synchronize()

    # Replay and compare against an eager-mode run for correctness.
    g_sel.replay()
    torch.cuda.synchronize()
    cuda_sel_out = static_out_sel.clone()
    regular_sel_out = model(x, index_vec=index_vec)
    if torch.allclose(cuda_sel_out, regular_sel_out, atol=1e-3):
        print("Selective forward CUDA Graph output matches regular output")
    else:
        print("Selective forward CUDA Graph output does NOT match regular output")

    def replay_sel():
        g_sel.replay()
    _, selective_time_cuda = cuda_profiler(replay_sel)
    print(f"Selective forward CUDA Graph time per run: {selective_time_cuda:.6f} ms")
HybridTensor/modules/SelectiveRouters.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import torch
4
+ from collections import OrderedDict
5
+
6
def create_mlp_router_state_dict(router_files_dir):
    """
    Loads all mlp_router weight files from the specified directory and creates a router_state_dict
    with keys formatted as 'transformer.layers.{layer_num}.mlp_router.{param_name}'.

    Filenames must match ``mlp_router_{layer}-{a}-{b}-{c}.pt``.

    Args:
        router_files_dir (str): Path to the directory containing mlp_router_*.pt files.

    Returns:
        OrderedDict: A state dictionary suitable for loading into a transformer model,
        or None when the directory is missing or contains no matching files.
    """
    # Regular expression to extract layer number from filename
    router_file_pattern = re.compile(r'mlp_router_(\d+)-[\d.]+-[\d.]+-[\d.]+\.pt$')

    router_state_dict = OrderedDict()

    # List all files in the directory
    try:
        all_files = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None

    # Match once per file; filtering and sorting reuse the same match objects
    # (the old code re-ran the regex in the filter, the sort key, and the loop).
    router_files = [
        (m, f) for m, f in ((router_file_pattern.match(f), f) for f in all_files) if m
    ]

    if not router_files:
        print(f"No router files found in directory '{router_files_dir}'.")
        return None

    loaded_layers = set()
    for match, file_name in sorted(router_files, key=lambda mf: int(mf[0].group(1))):
        layer_num = int(match.group(1))
        file_path = os.path.join(router_files_dir, file_name)

        try:
            # Load the router's state dict.
            # NOTE(review): torch.load on untrusted files can execute arbitrary
            # code; consider weights_only=True if the torch version supports it.
            router_weights = torch.load(file_path, map_location='cpu')
            if not isinstance(router_weights, dict):
                print(f"Warning: The file '{file_path}' does not contain a state dictionary. Skipping.")
                continue
        except Exception as e:
            print(f"Error loading '{file_path}': {e}")
            continue

        # Re-key each parameter under the transformer layer namespace.
        for param_name, param_tensor in router_weights.items():
            new_key = f"transformer.layers.{layer_num}.mlp_router.{param_name}"
            router_state_dict[new_key] = param_tensor
        loaded_layers.add(layer_num)

    # Fix: the old count divided the key total by a guessed params-per-router
    # (2, while the adjacent comment claimed 4); counting distinct layers is
    # exact regardless of how many parameters each router has.
    print(f"Total routers loaded: {len(loaded_layers)}")
    return router_state_dict
67
def create_attn_router_state_dict(router_files_dir):
    """Collect attn_router weight files from a directory into one state dict.

    Each file named ``attn_router_{layer}-{x}-{y}.pt`` is loaded and its
    parameters are re-keyed as
    ``transformer.layers.{layer}.mha_router.{param_name}``. When several files
    exist for the same layer, only the first (in layer-sorted order) is kept.

    Args:
        router_files_dir (str): Path to the directory containing attn_router_*.pt files.

    Returns:
        OrderedDict: A state dictionary suitable for loading into a transformer
        model, or None when the directory is missing or holds no matching files.
    """
    # Filename pattern: attn_router_{layer_num}-{value1}-{value2}.pt
    pattern = re.compile(r'attn_router_(\d+)-[\d.]+-[\d.]+\.pt$')
    router_state_dict = OrderedDict()

    try:
        candidates = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None

    router_files = [name for name in candidates if pattern.match(name)]
    if not router_files:
        print(f"No attn_router files found in directory '{router_files_dir}'.")
        return None

    # Process in ascending layer order; remember layers already loaded so
    # duplicate files are skipped.
    router_files.sort(key=lambda name: int(pattern.match(name).group(1)))
    loaded_layers = set()

    for file_name in router_files:
        match = pattern.match(file_name)
        if not match:
            print(f"Skipping file '{file_name}' as it does not match the pattern.")
            continue

        layer_num = int(match.group(1))
        if layer_num in loaded_layers:
            print(f"Warning: Multiple router files found for layer {layer_num}. Skipping '{file_name}'.")
            continue

        file_path = os.path.join(router_files_dir, file_name)
        try:
            router_weights = torch.load(file_path, map_location='cpu')
        except Exception as e:
            print(f"Error loading '{file_path}': {e}")
            continue
        if not isinstance(router_weights, dict):
            print(f"Warning: The file '{file_path}' does not contain a state dictionary. Skipping.")
            continue

        # Prefix every parameter with the transformer-layer namespace.
        key_prefix = f"transformer.layers.{layer_num}.mha_router."
        for param_name, param_tensor in router_weights.items():
            router_state_dict[key_prefix + param_name] = param_tensor
        loaded_layers.add(layer_num)

    print(f"Total MHA routers loaded: {len(loaded_layers)}")
    return router_state_dict
HybridTensor/modules/__init__.py ADDED
File without changes
HybridTensor/modules/__pycache__/MLP.cpython-39.pyc ADDED
Binary file (5.3 kB). View file
 
HybridTensor/modules/__pycache__/ParallelMLP.cpython-39.pyc ADDED
Binary file (5.61 kB). View file
 
HybridTensor/modules/__pycache__/SelectiveBlock.cpython-39.pyc ADDED
Binary file (10.6 kB). View file
 
HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-310.pyc ADDED
Binary file (6.56 kB). View file
 
HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-39.pyc ADDED
Binary file (6.77 kB). View file
 
HybridTensor/modules/__pycache__/SelectiveMHA.cpython-310.pyc ADDED
Binary file (23.3 kB). View file
 
HybridTensor/modules/__pycache__/SelectiveMHA.cpython-39.pyc ADDED
Binary file (30.1 kB). View file
 
HybridTensor/modules/__pycache__/SelectiveMLP.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
HybridTensor/modules/__pycache__/SelectiveMLP.cpython-39.pyc ADDED
Binary file (14.1 kB). View file
 
HybridTensor/modules/__pycache__/SelectiveRouters.cpython-310.pyc ADDED
Binary file (3.96 kB). View file
 
HybridTensor/modules/__pycache__/SelectiveRouters.cpython-39.pyc ADDED
Binary file (4 kB). View file
 
HybridTensor/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (198 Bytes). View file