Upload folder using huggingface_hub
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- HybridTensor/__init__.py +0 -0
- HybridTensor/__pycache__/__init__.cpython-310.pyc +0 -0
- HybridTensor/__pycache__/__init__.cpython-39.pyc +0 -0
- HybridTensor/benchmarks/generation/__pycache__/gen_util.cpython-39.pyc +0 -0
- HybridTensor/benchmarks/generation/__pycache__/llama_sparse_generation.cpython-39.pyc +0 -0
- HybridTensor/benchmarks/generation/__pycache__/model_sparse_generation.cpython-39.pyc +0 -0
- HybridTensor/benchmarks/generation/__pycache__/opt_gen_tp.cpython-39.pyc +0 -0
- HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-310.pyc +0 -0
- HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-39.pyc +0 -0
- HybridTensor/benchmarks/generation/__pycache__/opt_sparse_gen_tp.cpython-39.pyc +0 -0
- HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-310.pyc +0 -0
- HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-39.pyc +0 -0
- HybridTensor/benchmarks/generation/gen_util.py +38 -0
- HybridTensor/benchmarks/generation/model_sparse_generation.py +71 -0
- HybridTensor/benchmarks/generation/opt_gen_tp.py +133 -0
- HybridTensor/benchmarks/generation/opt_generation.py +289 -0
- HybridTensor/benchmarks/generation/opt_sparse_gen_tp.py +112 -0
- HybridTensor/benchmarks/generation/opt_sparse_generation.py +182 -0
- HybridTensor/benchmarks/model_eval.py +313 -0
- HybridTensor/benchmarks/model_perplexity.py +165 -0
- HybridTensor/benchmarks/opt_attn_sparse_topk_perplexity.py +264 -0
- HybridTensor/benchmarks/select_block_decode.py +218 -0
- HybridTensor/models/__pycache__/create_sparse_model.cpython-310.pyc +0 -0
- HybridTensor/models/__pycache__/create_sparse_model.cpython-39.pyc +0 -0
- HybridTensor/models/__pycache__/helper.cpython-310.pyc +0 -0
- HybridTensor/models/__pycache__/helper.cpython-39.pyc +0 -0
- HybridTensor/models/__pycache__/llama.cpython-39.pyc +0 -0
- HybridTensor/models/__pycache__/opt.cpython-310.pyc +0 -0
- HybridTensor/models/__pycache__/opt.cpython-39.pyc +0 -0
- HybridTensor/models/create_sparse_model.py +854 -0
- HybridTensor/models/helper.py +125 -0
- HybridTensor/models/llama.py +74 -0
- HybridTensor/models/opt.py +229 -0
- HybridTensor/modules/SelectiveBlock.py +960 -0
- HybridTensor/modules/SelectiveMHA.py +1579 -0
- HybridTensor/modules/SelectiveMLP.py +580 -0
- HybridTensor/modules/SelectiveRouters.py +136 -0
- HybridTensor/modules/__init__.py +0 -0
- HybridTensor/modules/__pycache__/MLP.cpython-39.pyc +0 -0
- HybridTensor/modules/__pycache__/ParallelMLP.cpython-39.pyc +0 -0
- HybridTensor/modules/__pycache__/SelectiveBlock.cpython-39.pyc +0 -0
- HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-310.pyc +0 -0
- HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-39.pyc +0 -0
- HybridTensor/modules/__pycache__/SelectiveMHA.cpython-310.pyc +0 -0
- HybridTensor/modules/__pycache__/SelectiveMHA.cpython-39.pyc +0 -0
- HybridTensor/modules/__pycache__/SelectiveMLP.cpython-310.pyc +0 -0
- HybridTensor/modules/__pycache__/SelectiveMLP.cpython-39.pyc +0 -0
- HybridTensor/modules/__pycache__/SelectiveRouters.cpython-310.pyc +0 -0
- HybridTensor/modules/__pycache__/SelectiveRouters.cpython-39.pyc +0 -0
- HybridTensor/modules/__pycache__/__init__.cpython-310.pyc +0 -0
HybridTensor/__init__.py
ADDED
|
File without changes
|
HybridTensor/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (190 Bytes). View file
|
|
|
HybridTensor/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (171 Bytes). View file
|
|
|
HybridTensor/benchmarks/generation/__pycache__/gen_util.cpython-39.pyc
ADDED
|
Binary file (869 Bytes). View file
|
|
|
HybridTensor/benchmarks/generation/__pycache__/llama_sparse_generation.cpython-39.pyc
ADDED
|
Binary file (2.34 kB). View file
|
|
|
HybridTensor/benchmarks/generation/__pycache__/model_sparse_generation.cpython-39.pyc
ADDED
|
Binary file (2.68 kB). View file
|
|
|
HybridTensor/benchmarks/generation/__pycache__/opt_gen_tp.cpython-39.pyc
ADDED
|
Binary file (3.75 kB). View file
|
|
|
HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-310.pyc
ADDED
|
Binary file (5.22 kB). View file
|
|
|
HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-39.pyc
ADDED
|
Binary file (6.12 kB). View file
|
|
|
HybridTensor/benchmarks/generation/__pycache__/opt_sparse_gen_tp.cpython-39.pyc
ADDED
|
Binary file (3.44 kB). View file
|
|
|
HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-310.pyc
ADDED
|
Binary file (3.36 kB). View file
|
|
|
HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-39.pyc
ADDED
|
Binary file (4.56 kB). View file
|
|
|
HybridTensor/benchmarks/generation/gen_util.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
+
|
| 6 |
+
def tokenize_dataset(dataset, tokenizer):
    """Tokenize every example in *dataset* and return one flat list of token ids.

    Special tokens are deliberately omitted (``add_special_tokens=False``) so
    that all examples concatenate into a single continuous token stream.

    Args:
        dataset: iterable of dicts, each with a ``"text"`` field.
        tokenizer: HF-style callable returning a dict with ``"input_ids"``.

    Returns:
        list[int]: all token ids, in dataset order.
    """
    return [
        token
        for example in dataset
        for token in tokenizer(example["text"], add_special_tokens=False)["input_ids"]
    ]
|
| 13 |
+
|
| 14 |
+
def get_random_batch(tokens, batch_size, seq_length):
    """Sample ``batch_size`` random contiguous windows of length ``seq_length``.

    Args:
        tokens: flat sequence of token ids (e.g. from ``tokenize_dataset``).
        batch_size: number of windows to draw (windows may overlap).
        seq_length: length of each window.

    Returns:
        torch.Tensor of shape ``(batch_size, seq_length)``.

    Raises:
        ValueError: if the token stream is shorter than ``seq_length``.
            (Previously this surfaced as an opaque ``random.randint`` error.)
    """
    total = len(tokens)
    if total < seq_length:
        raise ValueError(
            f"token stream has {total} tokens, need at least seq_length={seq_length}"
        )
    batch = []
    for _ in range(batch_size):
        start = random.randint(0, total - seq_length)
        batch.append(tokens[start : start + seq_length])
    return torch.tensor(batch)
|
| 21 |
+
|
| 22 |
+
'''
|
| 23 |
+
# Load dataset and tokenizer
|
| 24 |
+
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
| 25 |
+
model_name = "facebook/opt-6.7b"
|
| 26 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 27 |
+
|
| 28 |
+
tokens = tokenize_dataset(dataset, tokenizer)
|
| 29 |
+
|
| 30 |
+
# Define parameters
|
| 31 |
+
batch_size = 8
|
| 32 |
+
seq_length = 2000
|
| 33 |
+
|
| 34 |
+
random_batch = get_random_batch(tokens, batch_size, seq_length)
|
| 35 |
+
print("Batch shape:", random_batch.shape) # Expected: (8, 128)
|
| 36 |
+
|
| 37 |
+
'''
|
| 38 |
+
|
HybridTensor/benchmarks/generation/model_sparse_generation.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
from HybridTensor.utils.utils import _get_device
|
| 5 |
+
from HybridTensor.utils.activations import MODELS
|
| 6 |
+
from HybridTensor.models.opt import build_sparse_opt
|
| 7 |
+
from HybridTensor.models.llama import build_sparse_llama
|
| 8 |
+
from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def update_router_config(model, num_layers, mlp_topk_lookup, attn_topk):
    """Apply per-layer sparsity settings to every transformer layer.

    Args:
        model: model exposing ``model.transformer.layers``.
        num_layers: number of layers to configure.
        mlp_topk_lookup: optional mapping layer-index -> MLP top-k;
            when ``None`` the MLP top-k values are left untouched.
        attn_topk: attention-router top-k written to every layer.

    Layer 0's attention router is then forced back to dense (topk = 1.0).
    """
    layers = model.transformer.layers
    for idx in range(num_layers):
        layer = layers[idx]
        if mlp_topk_lookup is not None:
            layer.mlp_topk = mlp_topk_lookup[idx]
        layer.mha_router.topk = attn_topk

    # dense attention in layer 0
    layers[0].mha_router.topk = 1.0
|
| 21 |
+
|
| 22 |
+
def _str2bool(value):
    """Parse a CLI boolean.

    argparse's ``type=bool`` treats ANY non-empty string (including "False")
    as True, so boolean flags silently parsed wrong; map common spellings
    explicitly instead.
    """
    return str(value).strip().lower() in ("1", "true", "t", "yes", "y")


def arg_parser():
    """Build and parse the CLI arguments for sparse-generation benchmarking.

    Defaults are unchanged from the original; only boolean parsing is fixed
    (previously ``--use_cuda_graph False`` evaluated to True).
    """
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--print_results', type=_str2bool, default=True)
    parser.add_argument('--iterations', type=int, default=1)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model')
    parser.add_argument('--mlp_ckpt_dir', type=str, default='/home/grads/s/<name>/nvme/HybridTensor/checkpoint/opt-6.7b-routers/mlp')
    parser.add_argument('--attn_ckpt_dir', type=str, default='/home/grads/s/<name>/nvme/HybridTensor/checkpoint/opt-6.7b-routers/mha_linear')
    parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router/opt-6.7b')
    parser.add_argument('--delta', type=int, default=256, help='Delta value for MLP topk calculation')
    parser.add_argument('--use_cuda_graph', type=_str2bool, default=False, help='Use CUDA graph for inference')

    return parser.parse_args()
|
| 37 |
+
|
| 38 |
+
if __name__ == "__main__":
    args = arg_parser()
    model_name = MODELS[args.model_index - 1]
    print(f"Model name: {model_name}")
    dtype = torch.float16
    device = _get_device(args.gpu)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Build the sparse model for the detected family, then push a single
    # router configuration to every layer.
    if "llama" in model_name:
        model = build_sparse_llama(
            args, model_name, args.attn_ckpt_dir, device=device, dtype=dtype
        )
        # this sets the router config for all layers using a single config
        update_router_config(model, model.config.n_layer, None, args.attn_topk)
    else:
        mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, args.batch_size)
        model = build_sparse_opt(
            args, model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir,
            device=device, dtype=dtype,
        )
        # this sets the router config for all layers using a single config
        update_router_config(model, model.config.n_layer, mlp_topk_lookup, args.attn_topk)

    model.eval()
    print(model)

    # test input
    input_text = "Once upon a time in a land far, far away, there lived a"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    # Generate output
    with torch.no_grad():
        output = model.generate(input_ids, max_length=50)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
|
HybridTensor/benchmarks/generation/opt_gen_tp.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.models.opt import OPTConfig
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
from flash_attn.models.opt import opt_config_to_gpt2_config
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import argparse
|
| 8 |
+
from apex.transformer import parallel_state
|
| 9 |
+
|
| 10 |
+
from HybridTensor.utils.utils import arg_parser, _get_device
|
| 11 |
+
from HybridTensor.utils.activations import OPT_MODELS
|
| 12 |
+
from HybridTensor.models.opt import SparseConfig, build_sparse_opt, build_dense_opt
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def initialize_distributed_environment():
    """Join the NCCL process group and pin this process to its own GPU.

    Reads rank/world-size from the standard ``env://`` environment variables.

    Returns:
        tuple[str, int]: the per-rank device string ("cuda:<rank>") and the
        world size.
    """
    # NCCL environment knobs used by this benchmark setup.
    os.environ.update({
        "TORCH_NCCL_ASYNC_ERROR_HANDLING": "0",
        "NCCL_GRAPH_MIXING_SUPPORT": "0",
    })

    # Join the process group; rank/world size come from the launcher's env.
    torch.distributed.init_process_group(backend="nccl", init_method="env://")

    device = f"cuda:{torch.distributed.get_rank()}"
    world_size = torch.distributed.get_world_size()

    # Pin the current CUDA device so ops don't land on the wrong GPU.
    torch.cuda.set_device(device)

    return device, world_size
|
| 32 |
+
|
| 33 |
+
def _turn_bias_off(model, num_layers):
|
| 34 |
+
for i in range(num_layers):
|
| 35 |
+
model.transformer.layers[i].mlp.fc1.bias = None
|
| 36 |
+
model.transformer.layers[i].mlp.fc2.bias = None
|
| 37 |
+
|
| 38 |
+
def _str2bool(value):
    """Parse a CLI boolean.

    argparse's ``type=bool`` treats ANY non-empty string (including "False")
    as True; map the common spellings explicitly instead.
    """
    return str(value).strip().lower() in ("1", "true", "t", "yes", "y")


def arg_parser():
    """Build and parse the tensor-parallel OPT benchmark CLI arguments.

    Defaults are unchanged; only boolean parsing is fixed (previously
    ``--bias False`` evaluated to True).
    """
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--seq_len', type=int, default=25)
    parser.add_argument('--index_size', type=int, default=8192)
    parser.add_argument('--head_density', type=float, default=0.25)
    parser.add_argument('--print_results', type=_str2bool, default=True)
    parser.add_argument('--iterations', type=int, default=2)
    parser.add_argument('--check_results', type=_str2bool, default=False)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--bias', type=_str2bool, default=False)
    parser.add_argument('--mlp_ckpt_dir', type=str, default='/home/grads/s/<name>/nvme/HybridTensor/checkpoint/opt-6.7b-routers/mlp')
    parser.add_argument('--attn_ckpt_dir', type=str, default='/home/grads/s/<name>/nvme/HybridTensor/checkpoint/opt-6.7b-routers/mha_linear')

    return parser.parse_args()
|
| 55 |
+
|
| 56 |
+
if __name__ == "__main__":
    args = arg_parser()
    model_name = OPT_MODELS[args.model_index - 1]

    device, world_size = initialize_distributed_environment()
    dtype = torch.float16

    # Tensor-model-parallel setup: one rank per GPU in the TP group.
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = build_dense_opt(model_name, process_group=process_group,
                            world_size=world_size, rank=rank,
                            device=device, dtype=dtype)
    model.eval()

    input_texts = ["In a distant galaxy, a spaceship"]
    tokenized_inputs = tokenizer(input_texts, return_tensors="pt",
                                 padding=True, truncation=True).to(device)
    input_ids = tokenized_inputs["input_ids"]

    max_length = args.seq_len
    eos_token_id = tokenizer.eos_token_id
    num_layers = model.config.n_layer

    # turn bias off for mlp layers
    if not args.bias:
        _turn_bias_off(model, num_layers)

    # Warm-up pass (excluded from timing).
    _ = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
    )

    # Time args.iterations full generations with CUDA events.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(args.iterations):
        out = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            eos_token_id=eos_token_id,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
        )
    end_event.record()
    torch.cuda.synchronize()

    # Only rank 0 reports, to avoid duplicated output across TP ranks.
    if rank == 0:
        elapsed_time = start_event.elapsed_time(end_event) / args.iterations
        # NOTE: fixed typo in the original message ("genearation").
        print(f"Average time per generation : {elapsed_time} ms")

        # Compute throughput and latency per token
        num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
        throughput = num_tokens_generated / (elapsed_time / 1000)  # tokens per second
        latency_per_token = elapsed_time / num_tokens_generated  # ms per token

        print(f"Number of tokens generated: {num_tokens_generated}")
        print(f"Throughput: {throughput} tokens/second")
        print(f"Latency per token: {latency_per_token} ms")
        print(tokenizer.batch_decode(out.sequences.tolist()))
|
| 133 |
+
|
HybridTensor/benchmarks/generation/opt_generation.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
import torch
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
|
| 10 |
+
from HybridTensor.benchmarks.generation.gen_util import tokenize_dataset, get_random_batch
|
| 11 |
+
from HybridTensor.utils.activations import OPT_MODELS
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
+
|
| 14 |
+
from flash_attn.models.gpt import GPTLMHeadModel
|
| 15 |
+
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
|
| 16 |
+
from flash_attn.utils.generation import update_graph_cache
|
| 17 |
+
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
| 18 |
+
from transformers import AutoTokenizer, OPTConfig
|
| 19 |
+
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
| 20 |
+
|
| 21 |
+
def test_opt_generation(model_name):
    """Check that our implementation of OPT generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    print(f"\nMODEL: {model_name}")
    verbose = False
    dtype = torch.float16
    device = "cuda"
    # Tolerances for the fp16-vs-reference score comparison below.
    rtol, atol = 3e-3, 3e-1
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    # OPT tokenizer requires use_fast=False
    # https://huggingface.co/docs/transformers/model_doc/opt
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id

    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference: greedy decoding one token at a time by
    # re-running the model on the growing sequence. Used as ground truth for
    # the fast generate() path below.
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
                break
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    if verbose:
        print(out.sequences)
        print(tokenizer.batch_decode(out.sequences.tolist()))
    if getattr(config, "use_flash_attn", False):
        # Capture graph outside the timing loop
        batch_size, seqlen_og = input_ids.shape
        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
        print("With CUDA graph")
        torch.cuda.synchronize()
        start = time.time()
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=True,
        )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        if verbose:
            print(out_cg.sequences)
            print(tokenizer.batch_decode(out_cg.sequences.tolist()))

    # Free GPU memory before loading the HF reference models.
    del model

    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    # fp32 HF run is the gold reference for score comparisons.
    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_ref.eval()
    print("HF fp32")
    torch.cuda.synchronize()
    start = time.time()
    out_ref = model_ref.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_ref
    print(tokenizer.batch_decode(out_ref.sequences.tolist()))

    if verbose:
        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )

    # Fast generate() must reproduce the slow reference exactly (greedy).
    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    assert torch.all(out.sequences == out_ref.sequences)
    assert torch.all(out.sequences == out_hf.sequences)

    # Our fp16 scores should deviate from fp32 by no more than 3x the
    # deviation of HF's own fp16 run.
    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
    ).abs().max().item()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _str2bool(value):
    """Parse a CLI boolean.

    argparse's ``type=bool`` treats ANY non-empty string (including "False")
    as True; map the common spellings explicitly instead.
    """
    return str(value).strip().lower() in ("1", "true", "t", "yes", "y")


def arg_parser():
    """Build and parse the OPT generation benchmark CLI arguments.

    Defaults are unchanged; only boolean parsing is fixed (previously
    ``--print_results False`` evaluated to True).
    """
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--seq_len', type=int, default=1024)
    parser.add_argument('--index_size', type=int, default=8192)
    parser.add_argument('--head_density', type=float, default=0.25)
    parser.add_argument('--print_results', type=_str2bool, default=False)
    parser.add_argument('--iterations', type=int, default=1)
    parser.add_argument('--check_results', type=_str2bool, default=False)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--gpu', type=int, default=0)

    return parser.parse_args()
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":

    args = arg_parser()
    model_name = OPT_MODELS[args.model_index-1]
    # test_opt_generation(model_name)

    # Benchmark script: time batched OPT generation without and with the
    # CUDA-graph decoding path, reporting throughput and per-token latency.
    print(f"\nMODEL: {model_name}\n")
    verbose = False
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    # OPT tokenizer requires use_fast=False
    # https://huggingface.co/docs/transformers/model_doc/opt
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id

    # Prompts are random windows drawn from wikitext-2 rather than a fixed string.
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    tokens = tokenize_dataset(dataset, tokenizer)
    input_ids = get_random_batch(tokens, args.batch_size, args.seq_len)
    input_ids = input_ids.to(device=device)
    # Generate 20 new tokens beyond the prompt.
    max_length = args.seq_len + 20

    # warm up (excluded from timing)
    _ = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
    )

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
    )
    torch.cuda.synchronize()
    elapsed_time = (time.time() - start) * 1000
    print(f"Prompt processing + decoding time: {elapsed_time:.0f} ms")

    # Compute throughput and latency per token
    num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
    throughput = (args.batch_size * num_tokens_generated) / (elapsed_time / 1000)
    latency_per_token = elapsed_time / num_tokens_generated  # ms per token

    print(f"Throughput: {throughput:.1f} tokens/second")
    print(f"Latency per token: {latency_per_token:.1f} ms")

    if args.print_results:
        print(tokenizer.batch_decode(out.sequences.tolist()))

    # ============================================================================= #

    print("\n")
    if getattr(config, "use_flash_attn", False):
        # Capture graph outside the timing loop
        batch_size, seqlen_og = input_ids.shape
        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
        print("With CUDA graph")
        torch.cuda.synchronize()
        start = time.time()
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
        )
        torch.cuda.synchronize()
        elapsed_time = (time.time() - start) * 1000
        print(f"Prompt processing + decoding time: {elapsed_time:.0f} ms")

        # Compute throughput and latency per token
        num_tokens_generated = out_cg.sequences.shape[1] - input_ids.shape[1]
        latency_per_token = elapsed_time / num_tokens_generated  # ms per token
        throughput = (args.batch_size * num_tokens_generated) / (elapsed_time / 1000)

        print(f"Throughput: {throughput:.1f} tokens/second")
        print(f"Latency per token: {latency_per_token:.1f} ms")

        if args.print_results:
            print(tokenizer.batch_decode(out_cg.sequences.tolist()))
|
HybridTensor/benchmarks/generation/opt_sparse_gen_tp.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.models.opt import OPTConfig
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
from flash_attn.models.opt import opt_config_to_gpt2_config
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import argparse
|
| 8 |
+
from apex.transformer import parallel_state
|
| 9 |
+
|
| 10 |
+
from HybridTensor.utils.utils import arg_parser, _get_device
|
| 11 |
+
from HybridTensor.utils.activations import OPT_MODELS
|
| 12 |
+
from HybridTensor.models.opt import SparseConfig, build_sparse_opt
|
| 13 |
+
|
| 14 |
+
def update_router_config(model, num_layers, mlp_act_th, attn_topk, layer_config = None):
    """Apply one MLP activation threshold and one attention top-k to every layer.

    Args:
        model: sparse OPT model whose ``model.transformer.layers[i]`` expose
            ``mlp_router`` and ``mha_router`` modules.
        num_layers: number of transformer layers to update.
        mlp_act_th: activation threshold written to each layer's MLP router.
        attn_topk: top-k fraction written to each layer's attention router.
        layer_config: currently unused; reserved for per-layer overrides.
    """
    layers = model.transformer.layers
    for layer_idx in range(num_layers):
        layer = layers[layer_idx]
        layer.mlp_router.act_th = mlp_act_th
        layer.mha_router.topk = attn_topk
|
| 18 |
+
|
| 19 |
+
def initialize_distributed_environment():
    """Initialise torch.distributed (NCCL) and pin this process to its GPU.

    Expects the usual torchrun/launcher environment variables (RANK,
    WORLD_SIZE, MASTER_ADDR, ...) to be present for ``init_method="env://"``.

    Returns:
        tuple: ``(device, world_size)`` where ``device`` is the local
        ``"cuda:<rank>"`` string.
    """
    # NCCL knobs: disable async error handling and graph mixing support
    # (needed for CUDA-graph-compatible collectives).
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
    os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0"

    # Join the process group described by the launcher environment.
    torch.distributed.init_process_group(backend="nccl", init_method="env://")

    rank = torch.distributed.get_rank()
    device = f"cuda:{rank}"
    world_size = torch.distributed.get_world_size()

    # Pin the CUDA context so later allocations land on this rank's GPU.
    torch.cuda.set_device(device)

    return device, world_size
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _str2bool(value):
    """argparse converter for booleans.

    ``type=bool`` is a trap: ``bool("False")`` is True because any non-empty
    string is truthy. This converter interprets the usual spellings and
    rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def arg_parser():
    """Parse CLI arguments for the tensor-parallel sparse-OPT benchmark.

    NOTE(review): this local definition shadows the ``arg_parser`` imported
    from ``HybridTensor.utils.utils`` at the top of the file.
    """
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--seq_len', type=int, default=28)
    parser.add_argument('--index_size', type=int, default=8192)
    parser.add_argument('--head_density', type=float, default=0.25)
    # was type=bool: "--print_results False" used to parse as True
    parser.add_argument('--print_results', type=_str2bool, default=True)
    parser.add_argument('--iterations', type=int, default=100)
    parser.add_argument('--check_results', type=_str2bool, default=False)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--mlp_ckpt_dir', type=str, default='<PATH_TO_MLP_ROUTER_CHECKPOINTS>')
    parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model')
    parser.add_argument('--attn_ckpt_dir', type=str, default='<PATH_TO_ATTENTION_CHECKPOINTS>')

    return parser.parse_args()
|
| 55 |
+
|
| 56 |
+
if __name__ == "__main__":
    # Tensor-parallel generation benchmark for sparse OPT models.
    # Launch under torchrun (or similar) so RANK/WORLD_SIZE are set for NCCL.

    args = arg_parser()
    model_name = OPT_MODELS[args.model_index-1]

    device, world_size = initialize_distributed_environment()
    dtype = torch.float16

    # One tensor-parallel group spanning all ranks.
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = build_sparse_opt(model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device = device, dtype=dtype, process_group = process_group, world_size = world_size, rank = rank)
    model.eval()
    print("Model loaded with sparse routers")

    # Same sparsity settings for every layer (hard-coded for this benchmark).
    mlp_act_th = 0.5
    attn_topk = 0.5

    update_router_config(model, model.config.n_layer, mlp_act_th, attn_topk)
    print("Router config updated")

    # print router configs from all layers
    # for i in range(model.config.n_layer):
    #     print(f"Layer {i}: mlp_act_th = {model.transformer.layers[i].mlp_router.act_th}, attn_topk = {model.transformer.layers[i].mha_router.topk}")

    input_texts = ["Hello, my dog is cute and", "The future of AI is", "In a distant galaxy, a spaceship", "The cat is sleeping on the "]
    # input_texts = ["Hello, my dog is cute and", "Hello, my rat is cute and"]

    tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(device)
    input_ids=tokenized_inputs["input_ids"]

    # input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to(device=device)
    # NOTE(review): max_length appears to be the TOTAL sequence length, not
    # the number of new tokens — confirm against model.generate's contract.
    max_length = args.seq_len
    position_ids = None
    eos_token_id = tokenizer.eos_token_id
    num_layers = model.config.n_layer

    # print all the model weights and check the accuracy
    # if rank == 0:
    #     print(model.state_dict())

    # out = model(input_ids)
    # print(out)

    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    # Only rank 0 prints, to avoid duplicated output across TP ranks.
    if rank == 0:
        print(tokenizer.batch_decode(out.sequences.tolist()))
|
| 112 |
+
|
HybridTensor/benchmarks/generation/opt_sparse_generation.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
from HybridTensor.utils.utils import _get_device
|
| 5 |
+
from HybridTensor.utils.activations import OPT_MODELS
|
| 6 |
+
from HybridTensor.models.opt import SparseConfig, build_sparse_opt
|
| 7 |
+
from HybridTensor.benchmarks.generation.gen_util import tokenize_dataset, get_random_batch
|
| 8 |
+
from HybridTensor.utils.activations import build_mlp_topk_lookup
|
| 9 |
+
from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv
|
| 10 |
+
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
+
|
| 13 |
+
from transformers.models.opt import OPTConfig
|
| 14 |
+
from transformers import AutoTokenizer
|
| 15 |
+
from flash_attn.models.opt import opt_config_to_gpt2_config
|
| 16 |
+
from flash_attn.utils.generation import update_graph_cache
|
| 17 |
+
|
| 18 |
+
def _str2bool(value):
    """argparse converter for booleans.

    ``type=bool`` is a trap: ``bool("False")`` is True because any non-empty
    string is truthy. This converter interprets the usual spellings and
    rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def arg_parser():
    """Parse CLI arguments for the single-GPU sparse-OPT generation benchmark."""
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--seq_len', type=int, default=1024)
    parser.add_argument('--index_size', type=int, default=8192)
    parser.add_argument('--head_density', type=float, default=0.5)
    # was type=bool: "--print_results False" used to parse as True
    parser.add_argument('--print_results', type=_str2bool, default=True)
    parser.add_argument('--iterations', type=int, default=1)
    parser.add_argument('--check_results', type=_str2bool, default=False)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model')
    parser.add_argument('--mlp_ckpt_dir', type=str, default='<PATH_TO_MLP_ROUTER_CHECKPOINTS>')
    parser.add_argument('--attn_ckpt_dir', type=str, default='<PATH_TO_ATTENTION_CHECKPOINTS>')
    parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router/opt-6.7b')
    parser.add_argument('--delta', type=int, default=256, help='Delta value for MLP topk calculation')
    parser.add_argument('--use_cuda_graph', type=_str2bool, default=False, help='Use CUDA graph for inference')

    return parser.parse_args()
|
| 38 |
+
|
| 39 |
+
def update_router_config(model, num_layers, mlp_topk_lookup, attn_topk):
    """Apply per-layer MLP top-k values and one global attention top-k.

    Args:
        model: sparse OPT model; layers expose ``mlp_topk`` and ``mha_router``.
        num_layers: number of transformer layers to update.
        mlp_topk_lookup: mapping from layer index to that layer's MLP top-k.
        attn_topk: top-k fraction applied to every layer's attention router.

    Layer 0 is forced back to dense attention (topk = 1.0) after the sweep.
    """
    layers = model.transformer.layers
    for layer_idx in range(num_layers):
        layer = layers[layer_idx]
        layer.mlp_topk = mlp_topk_lookup[layer_idx]
        layer.mha_router.topk = attn_topk

    # The first layer always runs dense attention.
    layers[0].mha_router.topk = 1.0
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
    # Single-GPU generation benchmark for sparse OPT, with an optional
    # CUDA-graph decoding pass at the end.
    args = arg_parser()
    model_name = OPT_MODELS[args.model_index-1]
    dtype = torch.float16
    device= _get_device(args.gpu)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # args.mlp_ckpt_dir = None
    # args.attn_ckpt_dir = None

    model = build_sparse_opt(args, model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device = device, dtype=dtype)
    model.eval()
    print(model)
    print("Model loaded with sparse routers")

    # Per-layer MLP top-k values calibrated for this batch size.
    # mlp_topk_lookup = build_mlp_topk_lookup("results/mlp_results/batch_activations/opt-6.7b", args.batch_size, args.delta)
    mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, args.batch_size)
    print("MLP topk values updated: ", mlp_topk_lookup)
    update_router_config(model, model.config.n_layer, mlp_topk_lookup, args.attn_topk)  # this sets the router config for all layers using a single config
    # update_router_config(model, model.config.n_layer, 2048, args.attn_topk)
    print("Router config updated \n")

    # Generate 20 new tokens beyond the prompt length.
    max_length = args.seq_len + 20
    batch_size = args.batch_size

    # input_texts = ["Hello, my dog is cute and", "The future of AI is", "In a distant galaxy, a spaceship", "The cat is sleeping on the "]
    # input_texts = ["In a distant galaxy, a spaceship"]
    # tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=False).to(device)
    # input_ids=tokenized_inputs["input_ids"]

    # Real prompts: random fixed-length windows sampled from wikitext-2 test.
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    tokens = tokenize_dataset(dataset, tokenizer)
    input_ids = get_random_batch(tokens, args.batch_size, args.seq_len).to(device)

    print("Input ids generated, starting inference")

    # input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(device)
    position_ids = None
    eos_token_id = tokenizer.eos_token_id

    # CUDA events give device-side timing without host syncs inside the loop.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        # warm up
        _ = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            eos_token_id=eos_token_id,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
            cg=False,
        )

        print("Warm up done")

        start_event.record()
        for i in range(args.iterations):
            out = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                eos_token_id=eos_token_id,
                return_dict_in_generate=True,
                output_scores=True,
                enable_timing=False,
                cg=False,
            )

        end_event.record()

        torch.cuda.synchronize()
        print("Without CUDA graph")
        elapsed_time = start_event.elapsed_time(end_event) / args.iterations
        print(f"Average time per genearation : {elapsed_time:.1f} ms")

        # Compute throughput and latency per token
        num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
        throughput = batch_size * num_tokens_generated / (elapsed_time / 1000) # tokens per second
        latency_per_token = elapsed_time / num_tokens_generated # ms per token

        print(f"Number of tokens generated: {num_tokens_generated}")
        print(f"Throughput: {throughput:.1f} tokens/second")
        print(f"Latency per token: {latency_per_token:.1f} ms")

        # print(tokenizer.batch_decode(out.sequences.tolist()))
        print("\n")

        # print only the new tokens generated
        print("New tokens generated:")
        print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist()))

        # ====================== With CUDA graph ======================
        if args.use_cuda_graph:
            # Capture the decode step into a CUDA graph before timing (cg=True).
            batch_size, seqlen_og = input_ids.shape
            model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
            print("With CUDA graph")
            torch.cuda.synchronize()

            start_event.record()

            for i in range(args.iterations):
                out = model.generate(
                    input_ids=input_ids,
                    max_length=max_length,
                    cg=True,
                    eos_token_id=eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True,
                    enable_timing=False,
                )

            end_event.record()

            torch.cuda.synchronize()

            elapsed_time = start_event.elapsed_time(end_event) / args.iterations
            print(f"Average time per genearation : {elapsed_time:.1f} ms")

            # Compute throughput and latency per token
            num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
            throughput = batch_size * num_tokens_generated / (elapsed_time / 1000) # tokens per second
            latency_per_token = elapsed_time / num_tokens_generated # ms per token

            print(f"Number of tokens generated: {num_tokens_generated}")
            print(f"Throughput: {throughput:.1f} tokens/second")
            print(f"Latency per token: {latency_per_token:.1f} ms")

            # print(tokenizer.batch_decode(out.sequences.tolist()))
            print("New tokens generated:")
            print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist()))
|
| 182 |
+
|
HybridTensor/benchmarks/model_eval.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import numpy as np
|
| 7 |
+
import csv
|
| 8 |
+
|
| 9 |
+
# from hf_models.opt.modeling_opt_routers import (
|
| 10 |
+
# SparseOPTForCausalLM,
|
| 11 |
+
# create_hf_mha_router_state_dict,
|
| 12 |
+
# create_hf_mlp_router_state_dict
|
| 13 |
+
# )
|
| 14 |
+
|
| 15 |
+
from hf_models.opt.modeling_opt_routers_topk import (
|
| 16 |
+
SparseOPTForCausalLM,
|
| 17 |
+
create_hf_mha_router_state_dict,
|
| 18 |
+
create_hf_mlp_router_state_dict
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from hf_models.llama.modeling_sparse_llama_routers import (
|
| 22 |
+
SparseLlamaForCausalLM,
|
| 23 |
+
create_hf_attn_router_state_dict
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from hf_models.opt.modeling_sparse_opt_topk import SparseOPTForCausalLM as SparseOPTTopKAttn
|
| 27 |
+
from hf_models.llama.modeling_sparse_llama_mha_topk import SparseLlamaForCausalLM as SparseLlamaTopKAttn
|
| 28 |
+
from HybridTensor.benchmarks.opt_attn_sparse_topk_perplexity import _update_model_attn_thresholds
|
| 29 |
+
from HybridTensor.benchmarks.model_perplexity import compute_attn_layer_sparsity, compute_average_activation
|
| 30 |
+
from HybridTensor.utils.activations import ActivationThresholds, build_mlp_topk_lookup, _update_hf_mlp_topk, CONFIGS, MODELS
|
| 31 |
+
from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv
|
| 32 |
+
from HybridTensor.utils.utils import extract_model_name
|
| 33 |
+
|
| 34 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 35 |
+
|
| 36 |
+
from lm_eval.models.huggingface import HFLM
|
| 37 |
+
from lm_eval.tasks import TaskManager
|
| 38 |
+
import lm_eval
|
| 39 |
+
|
| 40 |
+
import pandas as pd
|
| 41 |
+
from tabulate import tabulate
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
import logging
|
| 45 |
+
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
|
| 46 |
+
|
| 47 |
+
import warnings
|
| 48 |
+
warnings.simplefilter(action='ignore', category=FutureWarning)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
from huggingface_hub import login
|
| 52 |
+
|
| 53 |
+
def read_and_print_results(filepath='results.csv'):
    """Load an evaluation-results CSV and print it as an aligned text table.

    Prints a notice and returns early when the file does not exist, so it is
    safe to call at the end of a run that produced no results.
    """
    if not os.path.exists(filepath):
        print(f"File '{filepath}' not found.")
        return

    frame = pd.read_csv(filepath)
    rendered = tabulate(frame, headers='keys', tablefmt='psql', showindex=False)
    print(rendered)
|
| 63 |
+
|
| 64 |
+
def save_results_to_csv(results, attn_topk, filepath='eval_results.csv'):
    """
    Extracts benchmark accuracies from results and saves them along with the attn_topk config.

    Parameters:
      results: dict, evaluation results with structure results['results'][<benchmark>]['acc,none']
      attn_topk: float, the attention top-k value used for this run
      filepath: str, CSV file to write to (appends if it exists)
    """
    # One row per call: the sweep value first, then each benchmark's accuracy
    # (None when a task reports no 'acc,none' metric).
    row = {'attn_topk': attn_topk}
    row.update({name: metrics.get('acc,none') for name, metrics in results['results'].items()})

    # Only the very first write emits the header line.
    write_header = not os.path.isfile(filepath)
    with open(filepath, 'a', newline='') as fh:
        writer = csv.DictWriter(fh, fieldnames=list(row.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(row)
|
| 86 |
+
|
| 87 |
+
def _update_model_attn_sparsity(model, attn_th, model_name=None):
    """Write per-layer attention sparsity thresholds into *model*.

    Computes a layer-importance-weighted sparsity map for the target overall
    attention sparsity ``attn_th`` and assigns it to each layer's
    ``self_attn.sp_threshold``.

    Parameters:
      model: HF causal-LM with either ``model.model.decoder.layers`` (OPT)
        or ``model.model.layers`` (Llama).
      attn_th: float, target overall attention sparsity.
      model_name: optional checkpoint name used by the sparsity calculation.
        The original implementation silently read the module-global
        ``model_name`` set in ``__main__``; passing it explicitly is now
        supported while keeping the old fallback for existing callers.

    Returns the same ``model`` (mutated in place).
    """
    num_layers = model.config.num_hidden_layers

    if model_name is None:
        # Backward-compatible fallback: module global first, then the
        # checkpoint name recorded on the config by from_pretrained.
        model_name = globals().get("model_name", getattr(model.config, "_name_or_path", None))

    # Use the 'decoder' attribute if it exists (OPT); otherwise model.model.layers (Llama)
    layers = model.model.decoder.layers if hasattr(model.model, 'decoder') else model.model.layers
    attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=attn_th)

    for i in range(num_layers):
        layers[i].self_attn.sp_threshold = attn_sparsity_map[i]

    average_act = compute_average_activation(attn_sparsity_map)
    print(f"Attention sparsity {attn_th}: {attn_sparsity_map}")
    print(f"Average activation: {average_act:.2f}")

    return model
|
| 102 |
+
|
| 103 |
+
def _evaluate_model(model, tokenizer, benchmarks: list, device: str, batch_size: int = 8):
    """Run lm-eval-harness over `benchmarks` for an in-memory HF model.

    Returns the raw dict from ``lm_eval.simple_evaluate``; per-task metrics
    live under ``results['results'][<task>]``.
    """
    logging.info("Evaluating on benchmarks: %s", benchmarks)
    # Wrap the already-loaded model/tokenizer so lm_eval can drive them
    # without re-downloading anything.
    lm_obj = HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        device=device,
        batch_size=batch_size
    )
    task_manager = TaskManager()
    num_fewshot = 5  # fixed few-shot count applied to every task
    print(f"Number of fewshot examples: {num_fewshot}")
    results = lm_eval.simple_evaluate(
        model=lm_obj,
        tasks=benchmarks,
        num_fewshot=num_fewshot, # change this
        task_manager=task_manager
    )
    logging.info("Evaluation complete.")
    for benchmark, benchmark_results in results['results'].items():
        logging.info("Results for %s: %s", benchmark.upper(), benchmark_results)
    return results
|
| 124 |
+
|
| 125 |
+
def _load_model(model_name, num_layers, device, args):
    """Construct the evaluation model according to ``args.mode``.

    Modes:
      * 'sparse'      - routers for both MLP and attention sparsity
                        (OPT when ``args.model_index <= 8``, Llama otherwise).
      * 'sparse_attn' - top-k sparse attention only, dense MLP.
      * anything else - plain dense ``AutoModelForCausalLM``.

    Returns the model placed according to ``device`` ('auto' or 'cuda:N').
    """
    if args.mode == 'sparse':
        logging.info("Loading sparse model...")
        sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th= args.attn_topk, mlp_th=args.mlp_topk)

        if args.model_index <=8:
            # OPT models
            model = SparseOPTForCausalLM.from_pretrained(
                model_name,
                device_map=device,
                torch_dtype=torch.float16,
                sp_thresholds=sp_thresholds.activation_threshold,
                mlp_thresholds=sp_thresholds.mlp_threshold,
                attn_implementation="flash_attention_2"
            )
            logging.info("Loading router states...")
            # Merge the trained router weights into the model's state dict.
            mlp_router_state = create_hf_mlp_router_state_dict(args.mlp_ckpt_dir)
            mha_router_state = create_hf_mha_router_state_dict(args.attn_ckpt_dir)
            model_state = model.state_dict()
            model_state.update(mlp_router_state)
            model_state.update(mha_router_state)
            model.load_state_dict(model_state)
            logging.info("Sparse model loaded with routers!")

            # load topk values for mlp and attn here
            # mlp_topk_lookup = build_mlp_topk_lookup(args.batch_stats_dir, args.batch_size, args.delta)
            # mlp_topk_lookup = build_mlp_topk_lookup(args.batch_stats_dir, 1, args.delta)
            # Batch-size-1 statistics are used for evaluation runs.
            mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, 1)

            _update_hf_mlp_topk(model, mlp_topk_lookup)
            # print("MLP topk values updated.")
            # print("MLP topk values: ", mlp_topk_lookup)
            logging.info("Using MLP topk values: %s", mlp_topk_lookup)

            # print("Using delta value: ", args.delta)

            # the first layer should use dense attention
            model.model.decoder.layers[0].self_attn.sp_threshold = 1.0
        else:
            # Llama models

            if not args.static_thresholds:
                # Derive per-layer thresholds weighted by layer importance.
                attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=args.attn_topk)
                sp_thresholds.load_thresholds(attn_sparsity_map)
                average_act = compute_average_activation(attn_sparsity_map)
                print(f"Layer imporatance weights attention activations {sp_thresholds.activation_threshold}")
                print(f"Average activation: {average_act:.2f}")

            model = SparseLlamaForCausalLM.from_pretrained(model_name,
                                                           device_map = device,
                                                           torch_dtype=torch.float16,
                                                           sp_thresholds = sp_thresholds.activation_threshold,
                                                           attn_implementation="flash_attention_2")
            logging.info("Loading router states...")
            model_state = model.state_dict()
            attn_router_states = create_hf_attn_router_state_dict(args.attn_ckpt_dir)
            model_state.update(attn_router_states)
            model.load_state_dict(model_state)
            logging.info("Sparse model loaded with routers!")

            # the first layer should use dense attention
            _update_model_attn_thresholds(model, args.attn_topk)

            # load topk values for mha here
            # TODO: create a function to update the topk values for mha

    elif args.mode == 'sparse_attn':
        logging.info("Loading model with sparse attention")
        sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=args.attn_topk)

        if not args.static_thresholds:
            attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=args.attn_topk)
            sp_thresholds.load_thresholds(attn_sparsity_map)
            average_act = compute_average_activation(attn_sparsity_map)
            print(f"Layer imporatance weights attention activations {sp_thresholds.activation_threshold}")
            print(f"Average activation: {average_act:.2f}")

        if args.model_index <= 8:
            # opt models
            model = SparseOPTTopKAttn.from_pretrained(model_name, device_map = device, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
        else:
            # llama models
            model = SparseLlamaTopKAttn.from_pretrained(model_name, device_map = device, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
    else:
        logging.info("Loading dense model...")
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.float16)
    return model
|
| 212 |
+
|
| 213 |
+
def _str2bool(value):
    """argparse converter for booleans.

    ``type=bool`` is a trap: ``bool("False")`` is True because any non-empty
    string is truthy. This converter interprets the usual spellings and
    rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def arg_parser():
    """Parse CLI arguments for the lm-eval benchmark driver."""
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--model_index', type=int, default=5)
    # was type=bool: "--print_results False" used to parse as True
    parser.add_argument('--print_results', type=_str2bool, default=True)
    parser.add_argument('--results_dir', type=str, default='results/eval')
    # 100 is a sentinel meaning "device_map='auto'" downstream.
    parser.add_argument('--device', type=int, default=100)
    parser.add_argument('--mode', type=str, default='sparse', choices=['sparse', 'dense', 'sparse_attn'])
    parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model')
    parser.add_argument('--mlp_topk', type=int, default=2048, help='MLP topk for sparse model')
    parser.add_argument('--delta', type=int, default=128, help='Delta value for MLP topk calculation')
    parser.add_argument('--mlp_ckpt_dir', type=str, default='<PATH_TO_MLP_ROUTER_CHECKPOINTS>')
    parser.add_argument('--attn_ckpt_dir', type=str, default='<PATH_TO_ATTENTION_CHECKPOINTS>')
    parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router')
    parser.add_argument('--data_collection', type=_str2bool, default=False, help='Collect data for different activation thresholds')
    parser.add_argument('--benchmark', type=str, default='all', help='Options: all, or a single benchmark name')
    parser.add_argument('--note', type=str, default='', help='Note to add to the results filename')
    parser.add_argument('--static_thresholds', type=_str2bool, default=True, help='Use static thresholds for attention layers')
    return parser.parse_args()
|
| 232 |
+
|
| 233 |
+
if __name__ == "__main__":
    # lm-eval driver: evaluates a dense / sparse / sparse-attention model on a
    # suite of benchmarks, optionally sweeping attention top-k values.
    args = arg_parser()

    # Hugging Face authentication (required for gated checkpoints, e.g. Llama).
    login_token = None # insert your token here
    assert login_token is not None, "Please provide a valid Hugging Face token."
    login(token=login_token)

    # Setup logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    model_name = MODELS[args.model_index - 1]
    # print(f"Evaluating Model: {model_name}")
    logging.info("Evaluating Model: %s", model_name)
    logging.info("Mode: %s", args.mode)

    num_layers = CONFIGS[model_name]['num_layer']
    # --device 100 is a sentinel meaning "let HF shard across all GPUs".
    device = 'auto' if args.device == 100 else f'cuda:{args.device}'

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = _load_model(model_name, num_layers, device, args)
    model.eval()

    # Determine benchmarks to evaluate
    if args.benchmark == 'all':
        benchmarks = ["piqa", "winogrande", "copa", "rte", "openbookqa", "arc_easy", "arc_challenge", "mmlu", "hellaswag"]
    else:
        benchmarks = [args.benchmark]

    model_name_clean = extract_model_name(model_name)

    if args.data_collection:
        # Sweep mode: re-evaluate the same loaded model at several attention
        # top-k values, appending one CSV row per setting.
        # make sure the model is not dense
        assert args.mode != 'dense', "Data collection is only available for sparse models"
        logging.info("Data collection mode enabled.")
        if args.mode == 'sparse':
            filepath = f"{args.results_dir}/eval_results_{model_name_clean}_sparse_sweep_dpsd.csv"
        else:  # sparse_attn
            filepath = f"{args.results_dir}/eval_results_{model_name_clean}_attn_sweep_dpsd.csv"

        if args.note != '':
            filepath = filepath.replace('.csv', f"_{args.note}.csv")
        # attn_topk_values = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] # MHA
        attn_topk_values = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1]
        # attn_topk_values = [7/8, 6/8, 5/8, 4/8, 3/8, 2/8, 1/8] # GQA
        for attn_topk in attn_topk_values:
            logging.info("Evaluating with attention top-k value: %s", attn_topk)
            # Static thresholds apply one value everywhere; otherwise compute
            # a layer-importance-weighted sparsity map.
            if args.static_thresholds:
                _update_model_attn_thresholds(model, attn_topk, mode=args.mode)
            else:
                _update_model_attn_sparsity(model, attn_topk)

            results = _evaluate_model(
                model=model,
                tokenizer=tokenizer,
                benchmarks=benchmarks,
                device=device,
                batch_size=args.batch_size
            )
            save_results_to_csv(results, attn_topk, filepath = filepath)
    else:
        # Single-configuration evaluation at args.attn_topk.
        logging.info("Evaluating with attention top-k value: %s", args.attn_topk)
        if args.mode == 'dense':
            filepath = f"{args.results_dir}/eval_results_{model_name_clean}_dense.csv"
        elif args.mode == 'sparse_attn':
            filepath = f"{args.results_dir}/eval_results_{model_name_clean}_sparse_attn_{args.attn_topk}_dpsd.csv"
        else:
            filepath = f"{args.results_dir}/eval_results_{model_name_clean}_test_attn_{args.attn_topk}_dpsd.csv"
        if args.note != '':
            filepath = filepath.replace('.csv', f"_{args.note}.csv")
        results = _evaluate_model(
            model=model,
            tokenizer=tokenizer,
            benchmarks=benchmarks,
            device=device,
            batch_size=args.batch_size
        )
        save_results_to_csv(results, args.attn_topk, filepath = filepath)

    if args.print_results:
        read_and_print_results(filepath=filepath)
|
HybridTensor/benchmarks/model_perplexity.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# python -m HybridTensor.benchmarks.model_perplexity --model_index 14 --batch_size 4 --max_length 512 --attn_th 1 --static_thresholds True
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
| 7 |
+
|
| 8 |
+
from hf_models.opt.modeling_sparse_opt_topk import SparseOPTForCausalLM as SparseOPTTopkAttn
|
| 9 |
+
from hf_models.llama.modeling_sparse_llama_mha_topk import SparseLlamaForCausalLM as SparseLlamaTopKAttn
|
| 10 |
+
from HybridTensor.utils.activations import ActivationThresholds, identify_model_type, MODELS, CONFIGS
|
| 11 |
+
from HybridTensor.utils.utils import extract_model_name, compute_perplexity
|
| 12 |
+
import argparse
|
| 13 |
+
from datasets import load_dataset
|
| 14 |
+
import json
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
import pandas as pd
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
from HybridTensor.benchmarks.opt_attn_sparse_topk_perplexity import (_update_model_attn_thresholds,
|
| 21 |
+
build_data_loader,
|
| 22 |
+
compute_sparse_perplexity,
|
| 23 |
+
compute_perplexity_data_collection,
|
| 24 |
+
display_model_menu,
|
| 25 |
+
_interactive_mode,
|
| 26 |
+
arg_parser,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
results_dir = "results/activations"
|
| 31 |
+
|
| 32 |
+
def compute_attn_layer_sparsity(model_name, min_th, critical_th, attn_sparsity):
    """Derive a per-layer attention sparsity map weighted by layer importance.

    Layers whose importance score is at least ``critical_th`` are kept fully
    dense (fraction 1.0). The remaining "sparse" layers share an attention
    budget of ``attn_sparsity * num_sparse_layers``, distributed in proportion
    to each layer's importance and clamped to the range [min_th, 1.0].

    Args:
        model_name: Key into CONFIGS describing the model.
        min_th: Lower bound for any sparse layer's activation fraction.
        critical_th: Importance score at or above which a layer stays dense.
        attn_sparsity: Target average activation fraction for sparse layers.

    Returns:
        dict mapping layer index -> activation fraction in [min_th, 1.0].
    """
    model_config = CONFIGS[model_name]
    num_layers = model_config['num_layer']

    # Importance scores are precomputed and stored on disk per model.
    with open(model_config['layer_imp'], 'r') as f:
        layer_importance = json.load(f)['importance_scores']

    # Partition layers into dense (critical) and sparse sets.
    critical = {i for i, imp in enumerate(layer_importance) if imp >= critical_th}
    sparse = [i for i, imp in enumerate(layer_importance) if imp < critical_th]

    # Total importance mass of the sparse layers and the shared budget.
    sparse_mass = sum(layer_importance[i] for i in sparse)
    budget = attn_sparsity * len(sparse)

    layer_sparsity_map = {}
    for idx in range(num_layers):
        if idx in critical:
            layer_sparsity_map[idx] = 1.0  # critical layers stay fully dense
            continue
        if sparse_mass > 0:
            share = (layer_importance[idx] / sparse_mass) * budget
        else:
            # Degenerate case: no importance mass — fall back to a flat value.
            share = attn_sparsity
        # Clamp into [min_th, 1.0].
        layer_sparsity_map[idx] = min(max(share, min_th), 1.0)

    return layer_sparsity_map
|
| 68 |
+
|
| 69 |
+
def compute_average_activation(layer_sparsity_map):
    """Return the mean activation fraction over all layers in the map."""
    return sum(layer_sparsity_map.values()) / len(layer_sparsity_map)
|
| 79 |
+
|
| 80 |
+
def compute_sparse_perplexity(model_name='facebook/opt-125m',
                              dataset_name='wikitext',
                              dataset_config='wikitext-2-raw-v1',
                              batch_size=8,
                              max_length=512,
                              attn_th=0.0,
                              static_thresholds=True,
                              device_map="cuda:0"):
    """Compute perplexity of a sparse-attention model on a text dataset.

    Args:
        model_name: Hugging Face identifier of the model to evaluate.
        dataset_name: Dataset to load (e.g. 'wikitext').
        dataset_config: Dataset configuration (e.g. 'wikitext-2-raw-v1').
        batch_size: Evaluation batch size.
        max_length: Maximum tokenized sequence length.
        attn_th: Attention activation threshold applied to every layer.
        static_thresholds: If True, use the flat threshold everywhere;
            otherwise derive per-layer thresholds from layer importance.
        device_map: Device placement passed to ``from_pretrained``.

    Returns:
        float: Perplexity of the model on the dataset's test split.

    Raises:
        ValueError: If the model family is neither OPT nor Llama.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # load the activation thresholds
    num_layers = CONFIGS[model_name]['num_layer']
    sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=attn_th)

    print(f"Static attention activations: {sp_thresholds.activation_threshold}")
    if not static_thresholds:
        # Replace the flat thresholds with importance-weighted per-layer ones.
        attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=attn_th)
        sp_thresholds.load_thresholds(attn_sparsity_map)
        average_act = compute_average_activation(attn_sparsity_map)
        print(f"Layer imporatance weights attention activations {sp_thresholds.activation_threshold}")
        print(f"Average activation: {average_act:.2f}")

    # Load the sparse model class matching the architecture family.
    model_type = identify_model_type(model_name)
    if model_type == 'OPT':
        print(f"Loading OPT model: {model_name}")
        model = SparseOPTTopkAttn.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
    elif model_type == 'Llama':
        print(f"Loading Llama model: {model_name}")
        model = SparseLlamaTopKAttn.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
    else:
        # Previously an unsupported family left `model` unbound and crashed
        # with a confusing NameError at model.eval(); fail fast instead.
        raise ValueError(f"Unsupported model type '{model_type}' for {model_name}")
    model.eval()

    # Build the evaluation DataLoader and compute perplexity.
    dataloader = build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length)
    perplexity = compute_perplexity(model, dataloader, device)
    return perplexity
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _parse_bool(value):
    """Interpret a command-line string as a boolean.

    argparse's ``type=bool`` is broken for flags: any non-empty string
    (including 'False') is truthy, so ``--flag False`` would enable the flag.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

def arg_parser():
    """Parse command-line arguments for the sparse perplexity evaluation."""
    parser = argparse.ArgumentParser(description='Sparse Perplexity Evaluation')
    parser.add_argument('--model_index', type=int, default=5, help='Index of the model to evaluate')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation')
    parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length')
    parser.add_argument('--attn_th', type=float, default=0.0, help='Activation threshold for attention layers')
    # _parse_bool (not type=bool): '--data_collection False' must parse as False.
    parser.add_argument('--data_collection', type=_parse_bool, default=False, help='Collect data for different activation thresholds')
    parser.add_argument('--device_map', type=str, default='auto', help='Device to use for evaluation')
    parser.add_argument('--interactive', type=_parse_bool, default=False, help='Interactive mode for model selection')
    parser.add_argument('--static_thresholds', type=_parse_bool, default=False, help='Use static thresholds for attention layers')

    return parser.parse_args()
|
| 134 |
+
|
| 135 |
+
def main():
    """Entry point: evaluate perplexity for a user- or CLI-selected model."""
    print("=== OPT Models Perplexity Evaluation ===\n")
    args = arg_parser()

    if args.interactive:
        (selected_model, batch_size, max_length,
         data_collection, device_map, attn_th) = _interactive_mode()
    else:
        selected_model = MODELS[args.model_index - 1]
        batch_size = args.batch_size
        max_length = args.max_length
        data_collection = args.data_collection
        device_map = args.device_map
        attn_th = args.attn_th
        print(f"Selected model: {selected_model}, batch size: {batch_size}, max length: {max_length}, attn_th: {attn_th}, data_collection: {data_collection}, device: {device_map}")

    if data_collection:
        # Sweep thresholds and record a CSV of perplexities.
        print("\nStarting data collection...\n")
        compute_perplexity_data_collection(model_name=selected_model, batch_size=batch_size, max_length=max_length, device_map=device_map)
        print("\nData collection complete.\n")
    else:
        # Single perplexity run at the chosen threshold.
        print("\nStarting perplexity computation...\n")
        perplexity = compute_sparse_perplexity(model_name=selected_model, batch_size=batch_size, max_length=max_length,
                                               attn_th=attn_th,
                                               device_map=device_map,
                                               static_thresholds=args.static_thresholds)
        print(f"\n=== Perplexity Results ===")
        print(f"Model: {selected_model}")
        print(f"Perplexity: {perplexity:.2f}\n")
|
| 163 |
+
|
| 164 |
+
if __name__ == "__main__":
|
| 165 |
+
main()
|
HybridTensor/benchmarks/opt_attn_sparse_topk_perplexity.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
| 6 |
+
# from hf_models.opt.modeling_sparse_opt import SparseOPTForCausalLM
|
| 7 |
+
from hf_models.opt.modeling_sparse_opt_topk import SparseOPTForCausalLM
|
| 8 |
+
from HybridTensor.utils.activations import ActivationThresholds, MODELS, CONFIGS
|
| 9 |
+
from HybridTensor.utils.utils import extract_model_name, compute_perplexity
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
+
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
import pandas as pd
|
| 17 |
+
|
| 18 |
+
results_dir = "results/activations"
|
| 19 |
+
|
| 20 |
+
def _update_model_attn_thresholds(model, attn_th, mode='sparse'):
|
| 21 |
+
num_layers = model.config.num_hidden_layers
|
| 22 |
+
|
| 23 |
+
# Use the 'decoder' attribute if it exists; otherwise use model.model.layers
|
| 24 |
+
layers = model.model.decoder.layers if hasattr(model.model, 'decoder') else model.model.layers
|
| 25 |
+
|
| 26 |
+
for i in range(num_layers):
|
| 27 |
+
layers[i].self_attn.sp_threshold = attn_th
|
| 28 |
+
|
| 29 |
+
# For non-sparse attention, layer 0 should use a threshold of 1.0
|
| 30 |
+
# if mode != 'sparse_attn':
|
| 31 |
+
# layers[0].self_attn.sp_threshold = 1.0
|
| 32 |
+
layers[0].self_attn.sp_threshold = 1.0
|
| 33 |
+
|
| 34 |
+
return model
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length, split='test'):
    """Create a padded DataLoader over a tokenized text dataset.

    Args:
        model_name: Hugging Face model identifier (used for its tokenizer).
        dataset_name: Name of the dataset to load.
        dataset_config: Configuration/subset of the dataset.
        batch_size: Batch size for the DataLoader.
        max_length: Maximum tokenized sequence length (truncation point).
        split: Dataset split to use ('train', 'validation' or 'test').

    Returns:
        DataLoader yielding dicts with padded 'input_ids' and 'attention_mask'.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    if tokenizer.pad_token_id is None:
        # Some checkpoints ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    raw_dataset = load_dataset(dataset_name, dataset_config, split=split)
    # Drop documents shorter than max_length characters.
    raw_dataset = raw_dataset.filter(lambda example: len(example["text"]) >= max_length)

    def _tokenize(examples):
        return tokenizer(examples['text'], return_special_tokens_mask=True, truncation=True, max_length=max_length)

    tokenized = raw_dataset.map(_tokenize, batched=True, remove_columns=['text'])

    def _collate(batch):
        ids = [torch.tensor(sample['input_ids']) for sample in batch]
        masks = [torch.tensor(sample['attention_mask']) for sample in batch]

        # Right-pad every sequence in the batch to a common length.
        ids = torch.nn.utils.rnn.pad_sequence(ids, batch_first=True, padding_value=tokenizer.pad_token_id)
        masks = torch.nn.utils.rnn.pad_sequence(masks, batch_first=True, padding_value=0)

        return {'input_ids': ids, 'attention_mask': masks}

    return DataLoader(tokenized, batch_size=batch_size, shuffle=False, collate_fn=_collate)
|
| 80 |
+
|
| 81 |
+
def compute_sparse_perplexity(model_name='facebook/opt-125m', dataset_name='wikitext', dataset_config='wikitext-2-raw-v1', batch_size=8, max_length=512, attn_th=0.0, static_thresholds=True, device_map="cuda:0"):
    """Load a sparse top-k OPT model and compute its perplexity on a dataset.

    When ``static_thresholds`` is False, per-layer thresholds are loaded from
    the model's sparse-config file instead of the flat ``attn_th`` value.

    Returns:
        float: Perplexity of the model on the dataset's chosen split.
    """
    # Pick GPU if available, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Construct the per-layer activation thresholds.
    num_layers = CONFIGS[model_name]['num_layer']
    sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=attn_th)

    if not static_thresholds:
        # Override the flat thresholds with the per-layer config on disk.
        act_threshold_filepath = CONFIGS[model_name]['sp_config']
        sp_thresholds.load_thresholds(act_threshold_filepath)
        print(f'Activation thresholds loaded from {act_threshold_filepath}')

    print(sp_thresholds.activation_threshold)

    # Load the sparse model in fp16 with flash attention.
    model = SparseOPTForCausalLM.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
    model.eval()

    # Build the evaluation DataLoader and compute perplexity.
    dataloader = build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length)
    return compute_perplexity(model, dataloader, device)
|
| 109 |
+
|
| 110 |
+
def compute_perplexity_data_collection(model_name='facebook/opt-125m', dataset_name='wikitext', dataset_config='wikitext-2-raw-v1', batch_size=8, max_length=512, device_map="cuda:0"):
    """Sweep attention top-k thresholds and record perplexity for each.

    Loads the sparse model once, then for every threshold in the sweep
    patches the per-layer attention thresholds in place, computes perplexity
    on the dataset's test split, and writes all results to a CSV under
    ``results_dir``.

    Args:
        model_name: Hugging Face identifier of the model to evaluate.
        dataset_name: Dataset to load (e.g. 'wikitext').
        dataset_config: Dataset configuration (e.g. 'wikitext-2-raw-v1').
        batch_size: Evaluation batch size.
        max_length: Maximum tokenized sequence length.
        device_map: Device placement passed to ``from_pretrained``.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # build_data_loader loads and filters the dataset itself; the previous
    # standalone load_dataset/filter here was dead code (its result was
    # discarded) and has been removed.
    dataloader = build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length)

    attn_thresholds = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    # attn_thresholds = [1, 0.5, 0.2]

    print(f"Computing perplexity for the following attention thresholds: {attn_thresholds}")

    # Load the model once; thresholds are patched in place per sweep step.
    num_layers = CONFIGS[model_name]['num_layer']
    sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=0.1)
    model = SparseOPTForCausalLM.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
    model.eval()

    results = []
    for attn_th in attn_thresholds:
        print(f'Computing perplexity for attn top k: {attn_th}')

        # update the model with new threshold
        model = _update_model_attn_thresholds(model, attn_th)

        # compute and store the perplexity
        perplexity = compute_perplexity(model, dataloader, device)
        print(f'Perplexity: {perplexity:.2f}\n')
        results.append({
            'model': model_name,
            'top_k': attn_th,
            'perplexity': perplexity
        })

    # save the results to a csv file in the results directory
    results_df = pd.DataFrame(results)
    model_name_str = extract_model_name(model_name)
    results_df.to_csv(f'{results_dir}/sparse_perplexity_results_{model_name_str}_topk.csv', index=False)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def display_model_menu():
    """Print a numbered menu of available models and return the user's choice.

    The prompt loops until a valid number is entered; 'q'/'quit'/'exit'
    terminates the program.

    Returns:
        selected_model (str): Hugging Face identifier of the chosen model.
    """
    print("Available OPT Models:")
    for position, model_id in enumerate(MODELS, 1):
        print(f"{position}. {model_id}")

    while True:
        try:
            raw = input("\nEnter the number corresponding to the model you want to evaluate (e.g., 1): ")
            if raw.lower() in ['q', 'quit', 'exit']:
                print("Exiting the program.")
                sys.exit(0)
            selection = int(raw)
            if 1 <= selection <= len(MODELS):
                chosen = MODELS[selection - 1]
                print(f"\nYou have selected: {chosen}\n")
                return chosen
            print(f"Please enter a number between 1 and {len(MODELS)}.")
        except ValueError:
            print("Invalid input. Please enter a valid number.")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _interactive_mode():
    """Interactively gather evaluation settings from stdin.

    Returns:
        Tuple of (selected_model, batch_size, max_length, data_collection,
        device_map, attn_th).
    """
    selected_model = display_model_menu()

    # Batch size (defaults to 8 on empty or invalid input).
    try:
        raw_batch = input("Enter batch size (default=8): ").strip()
        batch_size = int(raw_batch) if raw_batch else 8
    except ValueError:
        print("Invalid input for batch size. Using default value of 8.")
        batch_size = 8

    max_length = 512

    # Whether to sweep thresholds and collect data.
    try:
        answer = input("Do you want to collect data for different activation thresholds? (y/n): ").strip()
        data_collection = answer.lower() == 'y'
    except ValueError:
        print("Invalid input for data collection. Using default value of False.")
        data_collection = False

    # Device selection (empty input falls back to cuda:0).
    device_map = input("Enter device (cuda:0/auto) [default=cuda:0]: ").strip()
    if not device_map:
        device_map = "cuda:0"

    # Attention threshold — only asked for when not collecting data.
    attn_th = 0.0
    if not data_collection:
        try:
            raw_th = input("Enter activation threshold for attention layers: ").strip()
            attn_th = float(raw_th) if raw_th else 0.0
        except ValueError:
            print("Invalid input for attention threshold. Using default value of 0.")
            attn_th = 0.0

    return selected_model, batch_size, max_length, data_collection, device_map, attn_th
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def _cli_bool(value):
    """Interpret a command-line string as a boolean.

    argparse's ``type=bool`` is broken for flags: any non-empty string
    (including 'False') is truthy, so ``--flag False`` would enable the flag.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

def arg_parser():
    """Parse command-line arguments for the sparse perplexity evaluation."""
    parser = argparse.ArgumentParser(description='Sparse Perplexity Evaluation')
    parser.add_argument('--model_index', type=int, default=5, help='Index of the model to evaluate')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation')
    parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length')
    parser.add_argument('--attn_th', type=float, default=0.0, help='Activation threshold for attention layers')
    # _cli_bool (not type=bool): '--data_collection False' must parse as False.
    parser.add_argument('--data_collection', type=_cli_bool, default=False, help='Collect data for different activation thresholds')
    parser.add_argument('--device_map', type=str, default='cuda:0', help='Device to use for evaluation')
    parser.add_argument('--interactive', type=_cli_bool, default=False, help='Interactive mode for model selection')

    return parser.parse_args()
|
| 236 |
+
|
| 237 |
+
def main():
    """Entry point: evaluate perplexity for a user- or CLI-selected model."""
    print("=== OPT Models Perplexity Evaluation ===\n")
    args = arg_parser()

    if args.interactive:
        (selected_model, batch_size, max_length,
         data_collection, device_map, attn_th) = _interactive_mode()
    else:
        selected_model = MODELS[args.model_index - 1]
        batch_size = args.batch_size
        max_length = args.max_length
        data_collection = args.data_collection
        device_map = args.device_map
        attn_th = args.attn_th
        print(f"Selected model: {selected_model}, batch size: {batch_size}, max length: {max_length}, attn_th: {attn_th}, data_collection: {data_collection}, device: {device_map}")

    if data_collection:
        # Sweep thresholds and record a CSV of perplexities.
        print("\nStarting data collection...\n")
        compute_perplexity_data_collection(model_name=selected_model, batch_size=batch_size, max_length=max_length, device_map=device_map)
        print("\nData collection complete.\n")
    else:
        # Single perplexity run at the chosen threshold.
        print("\nStarting perplexity computation...\n")
        perplexity = compute_sparse_perplexity(model_name=selected_model, batch_size=batch_size, max_length=max_length, attn_th=attn_th, device_map=device_map)
        print(f"\n=== Perplexity Results ===")
        print(f"Model: {selected_model}")
        print(f"Perplexity: {perplexity:.2f}\n")
| 262 |
+
|
| 263 |
+
if __name__ == "__main__":
|
| 264 |
+
main()
|
HybridTensor/benchmarks/select_block_decode.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from tests.test_select_block import create_block, Config, SparseConfig
|
| 4 |
+
import csv
|
| 5 |
+
import time
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from flash_attn.utils.generation import InferenceParams
|
| 9 |
+
from HybridTensor.utils.utils import arg_parser, _get_device, sparse_index, generate_random_BH_index, get_gpu_name
|
| 10 |
+
from HybridTensor.utils.profiling import cuda_profiler
|
| 11 |
+
import math
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
def run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype):
    """Benchmark one decode step of a sparse vs. a regular transformer block.

    Builds a sparsified and a dense block from the same base config, runs a
    prefill pass of length `seq_len`, then captures each block's single-token
    decode step in a CUDA graph and times graph replays.

    Args:
        args: Parsed CLI args; reads `in_features`, `index_size`, `batch_size`.
        batch_size: Batch size for the simulated prefill/decode.
        seq_len: Prompt length used for the prefill stage.
        index_size: Number of active MLP neurons (sparse block's mlp_topk).
        attn_topk: Attention top-k fraction applied to the sparse config.
        device: Target device for block construction.
        dtype: Parameter dtype for block construction.

    Returns:
        (regular_graph_time, sparse_graph_time, speedup): mean CUDA-graph
        replay latencies in milliseconds and their ratio.
    """
    config = Config()
    sp_config = SparseConfig()
    sp_config.attn_topk = attn_topk

    config.hidden_size = args.in_features
    # NOTE(review): assumes a fixed head dim of 128 — confirm for all models.
    config.num_attention_heads = args.in_features // 128
    config.use_heuristic = False # use pre-compiled heuristic or complie new one during runtime

    # Test create_block
    sparse_block = create_block(config, sp_config, layer_idx=0, process_group=None, device=device, dtype=dtype)
    sparse_block.eval()
    sparse_block.mlp_topk = index_size

    # NOTE(review): this aliases `config` rather than copying it, so clearing
    # these flags also mutates the config object used for the sparse block —
    # verify the sparse block does not read these flags after construction.
    regular_config = config
    regular_config.att_sparse = False
    regular_config.mlp_sparse = False
    regular_block = create_block(regular_config, None, layer_idx=0, process_group=None, device=device, dtype=dtype)
    regular_block.eval()

    # inference simulation with select block
    max_seqlen = seq_len + 16
    max_batch_size = batch_size
    in_features = args.in_features
    head_dim = 128  # currently unused in this function

    inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size)
    process_group = None
    sequence_parallel = False

    # for testing and debugging
    heads = config.num_attention_heads
    selected_heads = heads // 2  # half the heads are marked active below

    # Create a static index vector (length equals total columns in B).
    total_neurons = args.in_features * 4
    test_index_vec = torch.empty((total_neurons,), device='cuda', dtype=torch.int32)
    active_indices = sparse_index(args.index_size, total_neurons)[0]
    test_index_vec[:args.index_size] = active_indices
    if args.index_size < total_neurons:
        test_index_vec[args.index_size:] = 0 # Fill the rest with dummy values.

    # test_index_vec = sparse_index(args.in_features, args.in_features*4)[0].cuda()
    test_bh_idx = generate_random_BH_index(args.batch_size, heads, selected_heads)
    test_index_size = args.index_size

    # Sequence-parallel runs need an explicit seqlen kwarg; plain runs do not.
    mixer_kwargs = (
        {"seqlen": seq_len}
        if process_group is not None and sequence_parallel
        else {}
    )
    if inference_params is not None:
        mixer_kwargs["inference_params"] = inference_params

    with torch.no_grad():
        # prefill stage
        original_seq = torch.randn(batch_size, seq_len, in_features, device='cuda', dtype=torch.float16)

        # Test prefill
        output_sparse = sparse_block(original_seq, mixer_kwargs=mixer_kwargs)
        output_regular = regular_block(original_seq, mixer_kwargs=mixer_kwargs)

        # need to update inference_params to reflect the new sequence length
        mixer_kwargs["inference_params"].seqlen_offset = seq_len

        # Decode stage
        input_x = torch.randn(batch_size, 1, in_features, device='cuda', dtype=torch.float16)

        out_decode_sparse = sparse_block(input_x, mixer_kwargs=mixer_kwargs)

        # Reset the offset so the regular block sees the same cache position.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len

        out_decode_regular = regular_block(input_x, mixer_kwargs=mixer_kwargs)

        # mesure decode stage time in ms
        # print("Without CUDA Graphs")
        # out_decode_regular, regular_time = cuda_profiler(regular_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2)
        # print(f"Regular time: {regular_time} ms")

        # out_decode_sparse, sparse_time = cuda_profiler(sparse_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2)
        # print(f"Sparse time: {sparse_time} ms")

        # speedup = regular_time / sparse_time
        # print(f"Speedup: {speedup}")

        # --- CUDA Graph Capture for Decode Stage ---
        # Allocate static buffer for regular block (shape assumed fixed)
        input_x_static = input_x.clone()
        output_regular_static = torch.empty((batch_size, 1, in_features), device=device, dtype=dtype)

        # Capture regular block graph
        _ = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)
        torch.cuda.synchronize()
        graph_regular = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph_regular):
            res = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)
            if isinstance(res, tuple):
                res = res[0]
            output_regular_static.copy_(res)

        # For the sparse block, run a dummy call to determine its output shape.
        # Also, reset the inference parameter to ensure consistent behavior.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        temp = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)
        if isinstance(temp, tuple):
            temp = temp[0]
        # print("Captured sparse block output shape:", temp.shape)
        # Allocate static buffer matching the dummy run's shape.
        output_sparse_static = torch.empty_like(temp)
        # print("output_sparse_static shape:", output_sparse_static.shape)
        torch.cuda.synchronize()

        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        graph_sparse = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph_sparse):
            res = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)
            if isinstance(res, tuple):
                res = res[0]
            output_sparse_static.copy_(res)

        # Warmup CUDA Graph replays
        for _ in range(5):
            graph_regular.replay()
            graph_sparse.replay()
        torch.cuda.synchronize()

        # --- Measure CUDA Graph Replay Latency ---
        num_replays = 10

        start = time.time()
        for _ in range(num_replays):
            graph_regular.replay()
        torch.cuda.synchronize()
        regular_graph_time = (time.time() - start) * 1000 / num_replays

        start = time.time()
        for _ in range(num_replays):
            graph_sparse.replay()
        torch.cuda.synchronize()
        sparse_graph_time = (time.time() - start) * 1000 / num_replays
        speedup = regular_graph_time / sparse_graph_time
        # print()
        # print("With CUDA Graphs")
        # print(f"Regular block time (CUDA Graphs): {regular_graph_time} ms")
        # print(f"Sparse block time (CUDA Graphs): {sparse_graph_time} ms")
        # print(f"Speedup (CUDA Graphs): {speedup}")

    return regular_graph_time, sparse_graph_time, speedup
|
| 162 |
+
|
| 163 |
+
if __name__ == "__main__":
|
| 164 |
+
|
| 165 |
+
args = arg_parser()
|
| 166 |
+
device = _get_device(0)
|
| 167 |
+
dtype = torch.float16
|
| 168 |
+
gpu_name = get_gpu_name()
|
| 169 |
+
|
| 170 |
+
# Parameter grids.
|
| 171 |
+
# batch_sizes = [1, 4, 8, 16]
|
| 172 |
+
# seq_lengths = [128, 512]
|
| 173 |
+
# index_sizes = [512, 1024, 2048, 4096]
|
| 174 |
+
# attn_topks = [0.3, 0.4, 0.5]
|
| 175 |
+
|
| 176 |
+
batch_sizes = [1, 8, 16, 32]
|
| 177 |
+
seq_lengths = [1024, 2048]
|
| 178 |
+
# index_sizes = [512, 1024, 2048, 4096, 8192]
|
| 179 |
+
index_size_p = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
|
| 180 |
+
total_neurons = args.in_features * 4
|
| 181 |
+
|
| 182 |
+
# Calculate initial index_size values
|
| 183 |
+
index_sizes = [int(total_neurons * i) for i in index_size_p]
|
| 184 |
+
|
| 185 |
+
# Round up to the nearest multiple of 128 if necessary
|
| 186 |
+
index_sizes = [math.ceil(size / 128) * 128 if size % 128 != 0 else size for size in index_sizes]
|
| 187 |
+
|
| 188 |
+
attn_topks = [0.3, 0.4, 0.5]
|
| 189 |
+
|
| 190 |
+
# Calculate total number of simulations.
|
| 191 |
+
total_runs = len(batch_sizes) * len(seq_lengths) * len(index_sizes) * len(attn_topks)
|
| 192 |
+
output_file = f"results/simulations/{gpu_name}_select_block_{args.in_features}_inference_sim.csv"
|
| 193 |
+
|
| 194 |
+
with open(output_file, mode='w', newline='') as csv_file:
|
| 195 |
+
fieldnames = ["in_features", "batch_size", "seq_len", "index_size", "neuron_activation", "attn_topk",
|
| 196 |
+
"regular_graph_time_ms", "sparse_graph_time_ms", "speedup"]
|
| 197 |
+
writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
|
| 198 |
+
writer.writeheader()
|
| 199 |
+
|
| 200 |
+
# Iterate over all combinations with tqdm progress bar.
|
| 201 |
+
for batch_size in tqdm(batch_sizes, desc="Batch Sizes"):
|
| 202 |
+
for seq_len in seq_lengths:
|
| 203 |
+
for index_size in index_sizes:
|
| 204 |
+
for attn_topk in attn_topks:
|
| 205 |
+
reg_time, spa_time, speedup = run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype)
|
| 206 |
+
writer.writerow({
|
| 207 |
+
"in_features": args.in_features,
|
| 208 |
+
"batch_size": batch_size,
|
| 209 |
+
"seq_len": seq_len,
|
| 210 |
+
"index_size": index_size,
|
| 211 |
+
"neuron_activation": index_size / total_neurons,
|
| 212 |
+
"attn_topk": attn_topk,
|
| 213 |
+
"regular_graph_time_ms": reg_time,
|
| 214 |
+
"sparse_graph_time_ms": spa_time,
|
| 215 |
+
"speedup": speedup
|
| 216 |
+
})
|
| 217 |
+
csv_file.flush()
|
| 218 |
+
print(f"Simulation complete. Results saved to {output_file}")
|
HybridTensor/models/__pycache__/create_sparse_model.cpython-310.pyc
ADDED
|
Binary file (15.3 kB). View file
|
|
|
HybridTensor/models/__pycache__/create_sparse_model.cpython-39.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
HybridTensor/models/__pycache__/helper.cpython-310.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
HybridTensor/models/__pycache__/helper.cpython-39.pyc
ADDED
|
Binary file (4.65 kB). View file
|
|
|
HybridTensor/models/__pycache__/llama.cpython-39.pyc
ADDED
|
Binary file (2.51 kB). View file
|
|
|
HybridTensor/models/__pycache__/opt.cpython-310.pyc
ADDED
|
Binary file (4.9 kB). View file
|
|
|
HybridTensor/models/__pycache__/opt.cpython-39.pyc
ADDED
|
Binary file (5.19 kB). View file
|
|
|
HybridTensor/models/create_sparse_model.py
ADDED
|
@@ -0,0 +1,854 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
|
| 8 |
+
from transformers import GPT2Config
|
| 9 |
+
from collections import namedtuple
|
| 10 |
+
from HybridTensor.modules.SelectiveMHA import SMHA, SelectMHA, ParallelSelectMHA, MHARouter, ParallelMHARouter
|
| 11 |
+
from HybridTensor.modules.SelectiveMLP import SelectiveMLP, ParallelSelectiveMLP, MLPRouter, ParallelMLPRouter
|
| 12 |
+
from HybridTensor.modules.SelectiveBlock import SelectBlock
|
| 13 |
+
# from HybridTensor.modules.SelectiveBlock_v1 import SelectBlock
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from flash_attn.utils.distributed import (
|
| 16 |
+
all_gather,
|
| 17 |
+
all_gather_raw,
|
| 18 |
+
get_dim_for_local_rank,
|
| 19 |
+
sync_shared_params,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from collections.abc import Sequence
|
| 23 |
+
from flash_attn.modules.mha import MHA, ParallelMHA
|
| 24 |
+
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP, GatedMlp, ParallelGatedMlp, Mlp, ParallelMLP
|
| 25 |
+
from flash_attn.ops.activations import sqrelu_fwd
|
| 26 |
+
from flash_attn.modules.block import Block
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
|
| 30 |
+
except ImportError:
|
| 31 |
+
layer_norm_fn, RMSNorm = None, None
|
| 32 |
+
|
| 33 |
+
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
|
| 34 |
+
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
|
| 35 |
+
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
| 36 |
+
from flash_attn.utils.generation import GenerationMixin
|
| 37 |
+
from flash_attn.models.opt import remap_state_dict_hf_opt
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
from flash_attn.ops.fused_dense import ColumnParallelLinear
|
| 41 |
+
except ImportError:
|
| 42 |
+
ColumnParallelLinear = None
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
|
| 46 |
+
except ImportError:
|
| 47 |
+
FusedDenseSqreluDense = None
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
|
| 51 |
+
except ImportError:
|
| 52 |
+
layer_norm_fn, RMSNorm = None, None
|
| 53 |
+
|
| 54 |
+
from HybridTensor.models.helper import remap_state_dict_gpt2, shard_state_dict_tp
|
| 55 |
+
|
| 56 |
+
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
|
| 57 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 58 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 59 |
+
attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0
|
| 60 |
+
softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power))
|
| 61 |
+
softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
|
| 62 |
+
if config.scale_attn_by_inverse_layer_idx:
|
| 63 |
+
assert layer_idx is not None
|
| 64 |
+
softmax_scale /= float(layer_idx + 1)
|
| 65 |
+
dwconv = getattr(config, "attn_dwconv", False)
|
| 66 |
+
if dwconv:
|
| 67 |
+
assert process_group is None, "TensorParallel MHA does not support dwconv yet"
|
| 68 |
+
qkv_proj_bias = getattr(config, "qkv_proj_bias", True)
|
| 69 |
+
out_proj_bias = getattr(config, "out_proj_bias", True)
|
| 70 |
+
rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim)
|
| 71 |
+
rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
|
| 72 |
+
rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
|
| 73 |
+
rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
|
| 74 |
+
use_alibi = getattr(config, "use_alibi", False)
|
| 75 |
+
use_triton = getattr(config, "use_triton", True) # toggle cuda or triton decode kernels
|
| 76 |
+
window_size = getattr(config, "window_size", (-1, -1))
|
| 77 |
+
use_flash_attn = getattr(config, "use_flash_attn", False)
|
| 78 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 79 |
+
if not fused_bias_fc:
|
| 80 |
+
assert process_group is None, "TensorParallel MHA requires fused_bias_fc"
|
| 81 |
+
|
| 82 |
+
mlp_sparse = getattr(config, "mlp_sparse", False)
|
| 83 |
+
att_sparse = getattr(config, "att_sparse", False)
|
| 84 |
+
num_heads = getattr(config, "num_attention_heads", None)
|
| 85 |
+
n_head_kv = getattr(config, "n_head_kv", num_heads)
|
| 86 |
+
|
| 87 |
+
if num_heads != n_head_kv:
|
| 88 |
+
att_sparse = False
|
| 89 |
+
|
| 90 |
+
if process_group is None:
|
| 91 |
+
mha_cls = SMHA # SelectMHA if att_sparse else MHA
|
| 92 |
+
else:
|
| 93 |
+
mha_cls = ParallelSelectMHA if att_sparse else ParallelMHA
|
| 94 |
+
|
| 95 |
+
# mha_cls = SelectMHA if process_group is None else ParallelSelectMHA
|
| 96 |
+
serial_kwargs = (
|
| 97 |
+
{"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {}
|
| 98 |
+
)
|
| 99 |
+
parallel_kwargs = (
|
| 100 |
+
{
|
| 101 |
+
"process_group": process_group,
|
| 102 |
+
"sequence_parallel": getattr(config, "sequence_parallel", False),
|
| 103 |
+
}
|
| 104 |
+
if process_group is not None
|
| 105 |
+
else {}
|
| 106 |
+
)
|
| 107 |
+
num_heads_kv = getattr(config, "n_head_kv", None)
|
| 108 |
+
mixer_cls = partial(
|
| 109 |
+
mha_cls,
|
| 110 |
+
num_heads=config.num_attention_heads,
|
| 111 |
+
num_heads_kv=num_heads_kv,
|
| 112 |
+
qkv_proj_bias=qkv_proj_bias,
|
| 113 |
+
out_proj_bias=out_proj_bias,
|
| 114 |
+
dropout=config.attn_pdrop,
|
| 115 |
+
softmax_scale=softmax_scale,
|
| 116 |
+
causal=True,
|
| 117 |
+
layer_idx=layer_idx,
|
| 118 |
+
rotary_emb_dim=rotary_emb_dim,
|
| 119 |
+
rotary_emb_base=rotary_emb_base,
|
| 120 |
+
rotary_emb_scale_base=rotary_emb_scale_base,
|
| 121 |
+
rotary_emb_interleaved=rotary_emb_interleaved,
|
| 122 |
+
use_alibi=use_alibi,
|
| 123 |
+
window_size=window_size,
|
| 124 |
+
use_flash_attn=use_flash_attn,
|
| 125 |
+
**serial_kwargs,
|
| 126 |
+
**parallel_kwargs,
|
| 127 |
+
**factory_kwargs,
|
| 128 |
+
)
|
| 129 |
+
return mixer_cls
|
| 130 |
+
|
| 131 |
+
def create_mlp_cls_old(config, layer_idx=None, process_group=None, device=None, dtype=None):
|
| 132 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 133 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
| 134 |
+
fused_mlp = getattr(config, "fused_mlp", False)
|
| 135 |
+
if fused_mlp:
|
| 136 |
+
assert config.activation_function in [
|
| 137 |
+
"gelu_new",
|
| 138 |
+
"gelu_fast",
|
| 139 |
+
"gelu_approx",
|
| 140 |
+
"gelu_pytorch_tanh",
|
| 141 |
+
"relu",
|
| 142 |
+
"sqrelu",
|
| 143 |
+
]
|
| 144 |
+
assert fused_mlp == True, "Not supported not fused mlp for now"
|
| 145 |
+
|
| 146 |
+
mlp_sparse = getattr(config, "mlp_sparse", False)
|
| 147 |
+
use_heuristic = getattr(config, "use_heuristic", True)
|
| 148 |
+
|
| 149 |
+
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
| 150 |
+
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
| 151 |
+
if isinstance(mlp_checkpoint_lvl, Sequence):
|
| 152 |
+
assert layer_idx is not None
|
| 153 |
+
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
| 154 |
+
|
| 155 |
+
if fused_mlp:
|
| 156 |
+
if FusedMLP is None:
|
| 157 |
+
raise ImportError("fused_dense is not installed")
|
| 158 |
+
# activation = (
|
| 159 |
+
# "gelu_approx"
|
| 160 |
+
# if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]
|
| 161 |
+
# else "relu"
|
| 162 |
+
# )
|
| 163 |
+
|
| 164 |
+
if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]:
|
| 165 |
+
activation = "gelu_approx"
|
| 166 |
+
else:
|
| 167 |
+
activation = "relu" # config.activation_function
|
| 168 |
+
|
| 169 |
+
if process_group is None:
|
| 170 |
+
mlp_cls = SelectiveMLP if mlp_sparse else FusedMLP
|
| 171 |
+
else:
|
| 172 |
+
mlp_cls = ParallelSelectiveMLP if mlp_sparse else ParallelFusedMLP
|
| 173 |
+
|
| 174 |
+
parallel_kwargs = (
|
| 175 |
+
{
|
| 176 |
+
"process_group": process_group,
|
| 177 |
+
"sequence_parallel": getattr(config, "sequence_parallel", True),
|
| 178 |
+
}
|
| 179 |
+
if process_group is not None
|
| 180 |
+
else {}
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
sparsity_kwargs = (
|
| 184 |
+
{
|
| 185 |
+
"use_heuristic": use_heuristic,
|
| 186 |
+
}
|
| 187 |
+
if mlp_sparse
|
| 188 |
+
else {}
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
mlp_cls = partial(
|
| 192 |
+
mlp_cls,
|
| 193 |
+
hidden_features=inner_dim,
|
| 194 |
+
activation=activation,
|
| 195 |
+
checkpoint_lvl=mlp_checkpoint_lvl,
|
| 196 |
+
# layer_idx=layer_idx,
|
| 197 |
+
**parallel_kwargs,
|
| 198 |
+
**factory_kwargs,
|
| 199 |
+
**sparsity_kwargs,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
else:
|
| 203 |
+
raise RuntimeError("MLP type not supported")
|
| 204 |
+
return mlp_cls
|
| 205 |
+
|
| 206 |
+
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
|
| 207 |
+
"""
|
| 208 |
+
Create an MLP class that supports both sparse MLPs (via fused mlp) and GatedMLPs.
|
| 209 |
+
If the activation function is one of "glu", "swiglu", or "geglu", then GatedMlp is used
|
| 210 |
+
(and mlp_sparse is ignored). Otherwise, fused_mlp is used to decide between sparse and
|
| 211 |
+
dense implementations.
|
| 212 |
+
"""
|
| 213 |
+
from functools import partial
|
| 214 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 215 |
+
mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
|
| 216 |
+
mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# Check for gated activations
|
| 220 |
+
if config.activation_function in ["glu", "swiglu", "geglu"]:
|
| 221 |
+
# For gated activations we do not support sparsity yet.
|
| 222 |
+
activation = (
|
| 223 |
+
F.sigmoid if config.activation_function == "glu"
|
| 224 |
+
else (F.silu if config.activation_function == "swiglu" else F.gelu)
|
| 225 |
+
)
|
| 226 |
+
mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
|
| 227 |
+
parallel_kwargs = (
|
| 228 |
+
{"process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True)}
|
| 229 |
+
if process_group is not None else {}
|
| 230 |
+
)
|
| 231 |
+
mlp_multiple_of = getattr(config, "mlp_multiple_of", 128)
|
| 232 |
+
mlp_cls = partial(
|
| 233 |
+
mlp_cls,
|
| 234 |
+
hidden_features=config.n_inner,
|
| 235 |
+
activation=activation,
|
| 236 |
+
bias1=mlp_fc1_bias,
|
| 237 |
+
bias2=mlp_fc2_bias,
|
| 238 |
+
multiple_of=mlp_multiple_of,
|
| 239 |
+
**parallel_kwargs,
|
| 240 |
+
**factory_kwargs,
|
| 241 |
+
)
|
| 242 |
+
return mlp_cls
|
| 243 |
+
|
| 244 |
+
# For non-gated activations:
|
| 245 |
+
fused_mlp = getattr(config, "fused_mlp", False)
|
| 246 |
+
fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
|
| 247 |
+
if fused_dense_sqrelu_dense:
|
| 248 |
+
assert config.activation_function == "sqrelu", (
|
| 249 |
+
"fused_dense_sqrelu_dense only supports approximate activation_function sqrelu"
|
| 250 |
+
)
|
| 251 |
+
assert not (fused_dense_sqrelu_dense and fused_mlp)
|
| 252 |
+
|
| 253 |
+
if fused_mlp:
|
| 254 |
+
# Ensure valid activation function.
|
| 255 |
+
assert config.activation_function in [
|
| 256 |
+
"gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh", "relu", "sqrelu"
|
| 257 |
+
]
|
| 258 |
+
# Support checkpoint level (possibly a list)
|
| 259 |
+
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
| 260 |
+
if isinstance(mlp_checkpoint_lvl, (list, tuple)):
|
| 261 |
+
assert layer_idx is not None
|
| 262 |
+
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
| 263 |
+
# Choose activation string.
|
| 264 |
+
if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]:
|
| 265 |
+
activation = "gelu_approx"
|
| 266 |
+
else:
|
| 267 |
+
activation = "relu"
|
| 268 |
+
# Determine inner dim.
|
| 269 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
| 270 |
+
mlp_sparse = getattr(config, "mlp_sparse", False)
|
| 271 |
+
use_heuristic = getattr(config, "use_heuristic", True)
|
| 272 |
+
if process_group is None:
|
| 273 |
+
mlp_cls = SelectiveMLP if mlp_sparse else FusedMLP
|
| 274 |
+
else:
|
| 275 |
+
mlp_cls = ParallelSelectiveMLP if mlp_sparse else ParallelFusedMLP
|
| 276 |
+
parallel_kwargs = (
|
| 277 |
+
{"process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True)}
|
| 278 |
+
if process_group is not None else {}
|
| 279 |
+
)
|
| 280 |
+
sparsity_kwargs = {"use_heuristic": use_heuristic} if mlp_sparse else {}
|
| 281 |
+
mlp_cls = partial(
|
| 282 |
+
mlp_cls,
|
| 283 |
+
hidden_features=inner_dim,
|
| 284 |
+
activation=activation,
|
| 285 |
+
checkpoint_lvl=mlp_checkpoint_lvl,
|
| 286 |
+
bias1=mlp_fc1_bias,
|
| 287 |
+
bias2=mlp_fc2_bias,
|
| 288 |
+
**parallel_kwargs,
|
| 289 |
+
**factory_kwargs,
|
| 290 |
+
**sparsity_kwargs,
|
| 291 |
+
)
|
| 292 |
+
return mlp_cls
|
| 293 |
+
|
| 294 |
+
elif fused_dense_sqrelu_dense:
|
| 295 |
+
if process_group is not None:
|
| 296 |
+
assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense"
|
| 297 |
+
assert FusedDenseSqreluDense is not None
|
| 298 |
+
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
| 299 |
+
if isinstance(mlp_checkpoint_lvl, (list, tuple)):
|
| 300 |
+
assert layer_idx is not None
|
| 301 |
+
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
| 302 |
+
mlp_cls = partial(
|
| 303 |
+
FusedDenseSqreluDense,
|
| 304 |
+
hidden_features=config.n_inner,
|
| 305 |
+
checkpoint_lvl=mlp_checkpoint_lvl,
|
| 306 |
+
**factory_kwargs,
|
| 307 |
+
)
|
| 308 |
+
return mlp_cls
|
| 309 |
+
|
| 310 |
+
else:
|
| 311 |
+
# Non-fused, non-sparse branch.
|
| 312 |
+
assert config.activation_function in [
|
| 313 |
+
"gelu", "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh", "relu", "sqrelu"
|
| 314 |
+
]
|
| 315 |
+
if config.activation_function == "relu":
|
| 316 |
+
activation = partial(F.relu, inplace=True)
|
| 317 |
+
elif config.activation_function == "sqrelu":
|
| 318 |
+
activation = sqrelu_fwd
|
| 319 |
+
else:
|
| 320 |
+
approximate = "tanh" if config.activation_function in [
|
| 321 |
+
"gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"
|
| 322 |
+
] else "none"
|
| 323 |
+
activation = partial(F.gelu, approximate=approximate)
|
| 324 |
+
mlp_sparse = getattr(config, "mlp_sparse", False)
|
| 325 |
+
mlp_cls = Mlp if process_group is None else ParallelMLP
|
| 326 |
+
parallel_kwargs = (
|
| 327 |
+
{"process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True)}
|
| 328 |
+
if process_group is not None else {}
|
| 329 |
+
)
|
| 330 |
+
mlp_cls = partial(
|
| 331 |
+
mlp_cls,
|
| 332 |
+
hidden_features=config.n_inner,
|
| 333 |
+
activation=activation,
|
| 334 |
+
bias1=mlp_fc1_bias,
|
| 335 |
+
bias2=mlp_fc2_bias,
|
| 336 |
+
**parallel_kwargs,
|
| 337 |
+
**factory_kwargs,
|
| 338 |
+
)
|
| 339 |
+
return mlp_cls
|
| 340 |
+
|
| 341 |
+
def create_mlp_router_cls(config, sp_config = None, layer_idx=None, process_group=None, device=None, dtype=None):
|
| 342 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 343 |
+
num_neurons = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
| 344 |
+
|
| 345 |
+
# this can be made different per layer by adding mlp_low_rank_dim_{layer_idx} in the sp_config
|
| 346 |
+
low_rank_dim = getattr(sp_config, "mlp_low_rank_dim", 1024)
|
| 347 |
+
|
| 348 |
+
# per layer activation threshold
|
| 349 |
+
act_th = getattr(config, "mlp_act_th", 0.5)
|
| 350 |
+
|
| 351 |
+
if process_group is None:
|
| 352 |
+
mlp_router_cls = MLPRouter
|
| 353 |
+
else:
|
| 354 |
+
mlp_router_cls = ParallelMLPRouter
|
| 355 |
+
|
| 356 |
+
parallel_kwargs = (
|
| 357 |
+
{
|
| 358 |
+
"process_group": process_group,
|
| 359 |
+
"sequence_parallel": getattr(config, "sequence_parallel", True),
|
| 360 |
+
}
|
| 361 |
+
if process_group is not None
|
| 362 |
+
else {}
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
mlp_router_cls = partial(mlp_router_cls,
|
| 366 |
+
low_rank_dim = low_rank_dim,
|
| 367 |
+
out_dim = num_neurons,
|
| 368 |
+
act_th = act_th,
|
| 369 |
+
**parallel_kwargs,
|
| 370 |
+
**factory_kwargs)
|
| 371 |
+
|
| 372 |
+
return mlp_router_cls
|
| 373 |
+
|
| 374 |
+
def create_mha_router_cls(config, sp_config = None, layer_idx=None, process_group=None, device=None, dtype=None):
|
| 375 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 376 |
+
num_heads = config.num_attention_heads
|
| 377 |
+
n_head_kv = getattr(config, "n_head_kv", num_heads)
|
| 378 |
+
if num_heads != n_head_kv:
|
| 379 |
+
out_dim = n_head_kv
|
| 380 |
+
else:
|
| 381 |
+
out_dim = num_heads
|
| 382 |
+
|
| 383 |
+
low_rank_dim = getattr(sp_config, "attn_low_rank_dim", 128) # optional, default to 128
|
| 384 |
+
|
| 385 |
+
# per layer activation topk, to make this different per layer, add a different attn_topk_{layer_idx} in the sp_config
|
| 386 |
+
attn_topk = getattr(sp_config, "attn_topk", 0.5)
|
| 387 |
+
if process_group is None:
|
| 388 |
+
mha_router_cls = MHARouter
|
| 389 |
+
else:
|
| 390 |
+
mha_router_cls = ParallelMHARouter
|
| 391 |
+
|
| 392 |
+
parallel_kwargs = (
|
| 393 |
+
{
|
| 394 |
+
"process_group": process_group,
|
| 395 |
+
"sequence_parallel": getattr(config, "sequence_parallel", True),
|
| 396 |
+
}
|
| 397 |
+
if process_group is not None
|
| 398 |
+
else {}
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
mha_router_cls = partial(mha_router_cls,
|
| 403 |
+
low_rank_dim = low_rank_dim,
|
| 404 |
+
out_dim = out_dim,
|
| 405 |
+
top_k = attn_topk,
|
| 406 |
+
**parallel_kwargs,
|
| 407 |
+
**factory_kwargs)
|
| 408 |
+
|
| 409 |
+
return mha_router_cls
|
| 410 |
+
|
| 411 |
+
def create_block(config, sp_config, layer_idx=None, process_group=None, device=None, dtype=None):
|
| 412 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 413 |
+
sequence_parallel = getattr(config, "sequence_parallel", True)
|
| 414 |
+
mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
|
| 415 |
+
mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
|
| 416 |
+
|
| 417 |
+
use_rms_norm = getattr(config, "rms_norm", False)
|
| 418 |
+
norm_cls = partial(
|
| 419 |
+
nn.LayerNorm if not use_rms_norm else RMSNorm,
|
| 420 |
+
eps=config.layer_norm_epsilon,
|
| 421 |
+
**factory_kwargs,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
|
| 425 |
+
residual_in_fp32 = getattr(config, "residual_in_fp32", False)
|
| 426 |
+
resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
|
| 427 |
+
prenorm = getattr(config, "prenorm", True)
|
| 428 |
+
parallel_block = getattr(config, "parallel_block", False)
|
| 429 |
+
mlp_sparse = getattr(config, "mlp_sparse", False)
|
| 430 |
+
att_sparse = getattr(config, "att_sparse", False)
|
| 431 |
+
block_sparse = mlp_sparse or att_sparse
|
| 432 |
+
|
| 433 |
+
if not parallel_block:
|
| 434 |
+
if block_sparse:
|
| 435 |
+
mha_router_cls = create_mha_router_cls(config, sp_config, layer_idx, process_group=process_group, **factory_kwargs) if att_sparse else None
|
| 436 |
+
mlp_router_cls = create_mlp_router_cls(config, sp_config, layer_idx, process_group=process_group, **factory_kwargs) if mlp_sparse else None
|
| 437 |
+
|
| 438 |
+
block = SelectBlock(
|
| 439 |
+
config.hidden_size,
|
| 440 |
+
mixer_cls,
|
| 441 |
+
mlp_cls,
|
| 442 |
+
mlp_router = mlp_router_cls,
|
| 443 |
+
mha_router = mha_router_cls,
|
| 444 |
+
norm_cls=norm_cls,
|
| 445 |
+
prenorm=prenorm,
|
| 446 |
+
resid_dropout1=resid_dropout1,
|
| 447 |
+
resid_dropout2=config.resid_pdrop,
|
| 448 |
+
fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
|
| 449 |
+
residual_in_fp32=residual_in_fp32,
|
| 450 |
+
sequence_parallel=sequence_parallel and process_group is not None,
|
| 451 |
+
mark_shared_params=process_group is not None,
|
| 452 |
+
)
|
| 453 |
+
else:
|
| 454 |
+
block = Block(
|
| 455 |
+
config.hidden_size,
|
| 456 |
+
mixer_cls,
|
| 457 |
+
mlp_cls,
|
| 458 |
+
norm_cls=norm_cls,
|
| 459 |
+
prenorm=prenorm,
|
| 460 |
+
resid_dropout1=resid_dropout1,
|
| 461 |
+
resid_dropout2=config.resid_pdrop,
|
| 462 |
+
fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
|
| 463 |
+
residual_in_fp32=residual_in_fp32,
|
| 464 |
+
sequence_parallel=sequence_parallel and process_group is not None,
|
| 465 |
+
mark_shared_params=process_group is not None,
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
else:
|
| 469 |
+
# not implemented
|
| 470 |
+
raise RuntimeError("ParallelBlock not implemented")
|
| 471 |
+
block.layer_idx = layer_idx
|
| 472 |
+
return block
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
class GPTPreTrainedModel(nn.Module):
    """An abstract base class that handles weight initialization and provides
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Validate that ``config`` is a ``GPT2Config`` and store it on the instance.

        Raises:
            ValueError: if ``config`` is not a ``GPT2Config`` instance.
        """
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        sp_config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Only checkpoints whose name starts with "gpt2" or "facebook/opt" are
        supported; other names raise NotImplementedError. When ``world_size > 1``
        the state dict is sharded for tensor parallelism before loading.
        """
        # Instantiate model.
        model = cls(config, sp_config, *args, device=device, dtype=dtype, **kwargs)
        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
        # want extra stuff taking up more GPU memory
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
        if model_name.startswith("gpt2"):
            state_dict = remap_state_dict_gpt2(state_dict, config)
        elif model_name.startswith("facebook/opt"):
            state_dict = remap_state_dict_hf_opt(state_dict, config)
        else:
            raise NotImplementedError(f"Model {model_name} not supported")
        if world_size > 1:
            # Keep only this rank's shard of each tensor-parallel parameter.
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        # logger.info(load_return)
        return model
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
| 529 |
+
def _init_weights(
|
| 530 |
+
module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True
|
| 531 |
+
):
|
| 532 |
+
if isinstance(module, nn.Linear):
|
| 533 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 534 |
+
if module.bias is not None:
|
| 535 |
+
nn.init.zeros_(module.bias)
|
| 536 |
+
elif isinstance(module, nn.Embedding):
|
| 537 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 538 |
+
|
| 539 |
+
if rescale_prenorm_residual:
|
| 540 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
| 541 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
| 542 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
| 543 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
| 544 |
+
#
|
| 545 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
| 546 |
+
for name, p in module.named_parameters():
|
| 547 |
+
if name in ["out_proj.weight", "fc2.weight"]:
|
| 548 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
| 549 |
+
nn.init.normal_(
|
| 550 |
+
p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
class GPTModel(GPTPreTrainedModel):
    """GPT backbone (embeddings + transformer blocks + final norm), with
    optional tensor/sequence parallelism and optional sparse MLP/attention
    blocks selected per-layer via ``sp_config`` (see ``create_block``).
    """

    def __init__(self, config: GPT2Config, sp_config=None, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        # Round the vocab up so it divides evenly (e.g. for tensor parallelism).
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList(
            [
                create_block(
                    config, sp_config, layer_idx=i, process_group=process_group, **factory_kwargs
                )
                for i in range(config.num_hidden_layers)
            ]
        )

        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            # layer_norm_fn is the Triton fused kernel; None means Triton is missing.
            if layer_norm_fn is None:
                raise ImportError("Triton is not installed")
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            # self.ln_f = nn.LayerNorm(
            #     config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            # )
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )

            if process_group is not None:
                for p in self.ln_f.parameters():
                    # Mark the norm parameters as "shared_params" so that we sync their values at init.
                    p._shared_params = True
                    # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                    if self.sequence_parallel:
                        p._sequence_parallel = True

        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

        # True when any layer uses a sparse MLP or sparse attention variant.
        self.sparse = False
        if config.mlp_sparse or config.att_sparse:
            self.sparse = True

    def tie_weights(self):
        """Synchronize parameters marked as shared across the tensor-parallel group."""
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate per-layer KV caches for generation; returns {layer_idx: cache}."""
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
        """Run the backbone over ``input_ids`` (batch, seqlen) and return the
        final hidden states after the last norm (prenorm) or the last block.
        """
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = (
            {"combine_batch_seqlen_dim": True}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        hidden_states = self.embeddings(
            input_ids, position_ids=position_ids, **embedding_kwargs
        )
        residual = None
        mixer_kwargs = (
            {"seqlen": input_ids.shape[1]}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params
        else:
            mixer_kwargs["inference_params"] = None

        # else:
        for layer in self.layers:
            if self.prenorm:
                # Prenorm blocks return (main branch, residual branch); see the
                # dropout/add/LN reordering note in __init__.
                hidden_states, residual = layer(
                    hidden_states,
                    residual,
                    mixer_kwargs=mixer_kwargs,
                )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)

        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                if hidden_states.shape != residual.shape:
                    hidden_states = hidden_states.view(residual.shape)

                hidden_states = layer_norm_fn(
                    hidden_states,
                    self.ln_f.weight,
                    self.ln_f.bias,
                    residual=residual,
                    x1=None,
                    eps=self.ln_f.eps,
                    dropout_p=self.drop_f.p if self.training else 0.0,
                    prenorm=False,
                    is_rms_norm=isinstance(self.ln_f, RMSNorm)
                )
        return hidden_states
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    """GPT backbone plus a (possibly column-parallel) language-model head.

    Supports optional weight tying with the input embeddings, an OPT-350m
    style output projection, muP-style output scaling, and a normalized head
    (``config.norm_head``) that L2-normalizes the head weight before the
    final linear.
    """

    def __init__(self, config: GPT2Config, sp_config = None, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group

        self.transformer = GPTModel(
            config, sp_config, process_group=process_group, **factory_kwargs
        )
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = (
            config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        )
        if word_embed_proj_dim is not None:
            # Project hidden states down to the (smaller) embedding dim before the head.
            self.project_out = nn.Linear(
                config.n_embd, embed_dim, bias=False, **factory_kwargs
            )
        else:
            self.project_out = None
        mup_width_scale = getattr(config, "mup_width_scale", 1.0)
        mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
        self.output_scale = mup_output_multiplier * mup_width_scale

        if process_group is None:
            self.lm_head = nn.Linear(
                embed_dim, vocab_size, bias=False, **factory_kwargs
            )
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=False,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )

        self.norm_head = getattr(config, "norm_head", False)
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Tie the LM head to the input embeddings (if configured) and sync shared params."""
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight  # llama does not use tied weights
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate KV-cache allocation to the backbone."""
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        input_ids: (batch, seqlen) int tensor
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        num_last_tokens: if > 0, only return the logits for the last n tokens

        Returns a ``CausalLMOutput`` namedtuple with a single ``logits`` field.
        """
        assert (
            input_ids.ndim == 2
        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
        b, slen = input_ids.shape
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if inference_params is not None:
            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        if self.output_scale != 1.0:
            hidden_states = hidden_states * self.output_scale
        if not self.norm_head:
            lm_logits = self.lm_head(hidden_states)
        else:
            # Normalized head: L2-normalize the head weight rows before the matmul.
            lm_head_weight = F.normalize(self.lm_head.weight)
            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
                # We need to gather the hidden states before the matmul since the
                # sequence-parallel layout splits along the combined batch*seqlen dim.
                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        # Each layer's norm1/norm2 therefore shift down by one position; the old
        # ln_0 becomes layer 0's norm1 and the last norm2 becomes the final ln_f.
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(
                f"transformer.layers.{n_layers - 1}.norm2.weight"
            )
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            # Walk backwards so each layer's norms can be moved before being overwritten.
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(
                        f"transformer.layers.{l - 1}.norm2.weight"
                    )
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight
            state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)
|
HybridTensor/models/helper.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import re
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def remap_state_dict_gpt2(state_dict, config):
    """Remap a HuggingFace GPT-2 checkpoint to this repo's GPT parameter layout.

    Renames embeddings/norm/MLP/attention keys, transposes the Conv1D-style
    weights (HF GPT-2 stores them as (in, out)), pads the word embedding up to
    ``pad_vocab_size_multiple``, and ties ``lm_head.weight`` to it.

    Args:
        state_dict: mapping of HF GPT-2 parameter names to tensors.
        config: model config; reads ``vocab_size``, ``num_hidden_layers`` and
            optionally ``pad_vocab_size_multiple``.

    Returns:
        OrderedDict with remapped keys.
    """
    # Bug fix: this module never imported torch.nn.functional, so the F.pad
    # call below raised NameError. Import locally to keep the module header
    # unchanged for callers that vendored it.
    import torch.nn.functional as F

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        # Dots are escaped so e.g. a hypothetical "wpeX" key is not matched.
        return re.sub(r"^wpe\.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    # Tie the LM head to the (padded) word embedding.
    state_dict["lm_head.weight"] = state_dict[
        "transformer.embeddings.word_embeddings.weight"
    ]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^ln_f\.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^h\.(\d+)\.ln_(1|2)\.(weight|bias)", r"transformer.layers.\1.norm\2.\3", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP: HF stores Conv1D weights as (in, out); our Linear expects (out, in).
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
        W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(
            r"^h\.(\d+)\.mlp\.c_fc\.bias", r"transformer.layers.\1.mlp.fc1.bias", key
        )
        key = re.sub(
            r"^h\.(\d+)\.mlp\.c_proj\.bias", r"transformer.layers.\1.mlp.fc2.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f"h.{d}.attn.bias")  # We don't store this bias
        Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(
            r"^h\.(\d+)\.attn\.c_attn\.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key
        )
        key = re.sub(
            r"^h\.(\d+)\.attn\.c_proj\.bias",
            r"transformer.layers.\1.mixer.out_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.

    Each parameter is replaced in-place by the contiguous slice belonging to
    ``rank``; row-parallel output biases are kept only on rank 0.
    """
    multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / multiple) * multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = 4 * config.hidden_size if config.n_inner is None else config.n_inner
    assert inner_dim % world_size == 0

    def _keep_rows(key):
        # Column-parallel layers: shard along the output (first) dim.
        full = state_dict[key]
        step = full.shape[0] // world_size
        state_dict[key] = full[rank * step : (rank + 1) * step]

    def _keep_cols(key):
        # Row-parallel layers: shard along the input (last) dim.
        full = state_dict[key]
        step = full.shape[-1] // world_size
        state_dict[key] = full[..., rank * step : (rank + 1) * step]

    def _keep_qkv(key):
        # Wqkv packs q, k, v along dim 0; shard each of the three equally so
        # every rank keeps matching q/k/v head slices.
        packed = state_dict[key]
        stacked = packed.reshape(3, packed.shape[0] // 3, *packed.shape[1:])
        step = stacked.shape[1] // world_size
        kept = stacked[:, rank * step : (rank + 1) * step]
        state_dict[key] = kept.reshape(-1, *kept.shape[2:])

    _keep_rows("transformer.embeddings.word_embeddings.weight")
    if "lm_head.weight" in state_dict:
        _keep_rows("lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        _keep_cols("transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        prefix = f"transformer.layers.{i}"
        _keep_qkv(f"{prefix}.mixer.Wqkv.weight")
        _keep_qkv(f"{prefix}.mixer.Wqkv.bias")
        _keep_cols(f"{prefix}.mixer.out_proj.weight")
        if rank != 0:
            state_dict.pop(f"{prefix}.mixer.out_proj.bias")
        _keep_rows(f"{prefix}.mlp.fc1.weight")
        _keep_rows(f"{prefix}.mlp.fc1.bias")
        _keep_cols(f"{prefix}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"{prefix}.mlp.fc2.bias")
    return state_dict
|
HybridTensor/models/llama.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import LlamaConfig, LlamaTokenizer
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from HybridTensor.models.create_sparse_model import GPTLMHeadModel
|
| 6 |
+
from HybridTensor.modules.SelectiveRouters import create_mlp_router_state_dict, create_attn_router_state_dict
|
| 7 |
+
|
| 8 |
+
# from flash_attn.models.gpt import GPTLMHeadModel
|
| 9 |
+
from transformers import AutoConfig, AutoTokenizer
|
| 10 |
+
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
| 11 |
+
|
| 12 |
+
from flash_attn.models.llama import (
|
| 13 |
+
config_from_checkpoint,
|
| 14 |
+
inv_remap_state_dict_hf_llama,
|
| 15 |
+
llama_config_to_gpt2_config,
|
| 16 |
+
remap_state_dict_hf_llama,
|
| 17 |
+
remap_state_dict_meta_llama,
|
| 18 |
+
state_dicts_from_checkpoint,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
class SparseConfig:
    """Default hyper-parameters for the sparse MLP/attention routers.

    NOTE(review): attribute semantics inferred from names — presumably the
    *_low_rank_dim values size the router projections, mlp_act_th is an
    activation threshold and attn_topk a top-k fraction; confirm against
    SelectiveRouters.
    """

    def __init__(self):
        # Router projection widths.
        self.attn_low_rank_dim = 128
        self.mlp_low_rank_dim = 1024
        # Selection hyper-parameters.
        self.attn_topk = 0.3
        self.mlp_act_th = 0.5
|
| 27 |
+
|
| 28 |
+
def build_dense_llama(model_name: str, device = None, dtype=torch.float16, process_group = None, world_size = None, rank = None, **kwargs):
    """Build a dense (non-sparse) LLaMA model on top of the flash-attn GPT stack.

    Downloads the HF config and weights for ``model_name``, remaps them to the
    GPT2-style layout and returns the model in eval mode. ``process_group``,
    ``world_size`` and ``rank`` are accepted for interface symmetry but unused here.
    """
    hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config = llama_config_to_gpt2_config(hf_config)
    # Enable the fused/flash kernels this model definition supports.
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    config.prenorm = True

    raw_weights = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype)
    remapped_weights = remap_state_dict_hf_llama(raw_weights, config)

    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(remapped_weights, strict=True)
    model.eval()
    return model
|
| 45 |
+
|
| 46 |
+
def build_sparse_llama(args, model_name: str, attn_ckpt_dir: str, device = None, dtype=torch.float16, process_group = None, world_size = None, rank = None, **kwargs):
    """Build a LLaMA model with sparse (top-k) attention routers.

    Args:
        args: namespace providing ``attn_topk`` (fraction for the attention router).
        model_name: HF model id or path of the dense LLaMA checkpoint.
        attn_ckpt_dir: directory with trained attention-router weights, or None
            to load only the dense weights.
        device/dtype: placement of the instantiated model.
        process_group/world_size/rank: accepted for interface symmetry; tensor
            parallel sharding is not implemented here yet (see TODO below).

    Returns:
        The model in eval mode.
    """
    config = llama_config_to_gpt2_config(AutoConfig.from_pretrained(model_name, trust_remote_code=True))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    config.prenorm = True

    spconfig = SparseConfig()
    spconfig.attn_topk = args.attn_topk
    config.mlp_sparse = False
    config.att_sparse = True

    state_dict = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype)
    state_dict = remap_state_dict_hf_llama(state_dict, config)

    model = GPTLMHeadModel(config, sp_config=spconfig, device=device, dtype=dtype)

    # Bug fix: `merged_state_dict` was previously only bound inside the branch
    # below, so calling with attn_ckpt_dir=None raised UnboundLocalError at
    # load_state_dict. Default to the dense weights.
    merged_state_dict = state_dict
    if attn_ckpt_dir is not None:
        attn_router_state_dict = create_attn_router_state_dict(attn_ckpt_dir)
        # Router weights take precedence over any overlapping dense keys.
        merged_state_dict = {**state_dict, **attn_router_state_dict}

    # TODO: Add code for tensor parallel state dict sharding

    model.load_state_dict(merged_state_dict, strict=True)
    model.eval()
    return model
|
HybridTensor/models/opt.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from HybridTensor.utils.activations import OPT_MODELS
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
|
| 6 |
+
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
| 7 |
+
from flash_attn.models.opt import remap_state_dict_hf_opt
|
| 8 |
+
from HybridTensor.modules.SelectiveRouters import create_mlp_router_state_dict, create_attn_router_state_dict
|
| 9 |
+
from HybridTensor.models.create_sparse_model import GPTLMHeadModel as GPTLMHeadModelSparse
|
| 10 |
+
from flash_attn.models.gpt import GPTLMHeadModel
|
| 11 |
+
|
| 12 |
+
from transformers.models.opt import OPTConfig
|
| 13 |
+
from flash_attn.models.opt import opt_config_to_gpt2_config
|
| 14 |
+
|
| 15 |
+
class SparseConfig:
    """Default hyper-parameters for the sparse MLP/attention routers.

    NOTE(review): attribute semantics inferred from names — presumably the
    *_low_rank_dim values size the router projections, mlp_act_th is an
    activation threshold and attn_topk a top-k fraction; confirm against
    SelectiveRouters.
    """

    def __init__(self):
        # Router projection widths.
        self.attn_low_rank_dim = 128
        self.mlp_low_rank_dim = 1024
        # Selection hyper-parameters.
        self.attn_topk = 0.3
        self.mlp_act_th = 0.5
|
| 21 |
+
|
| 22 |
+
def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard (router-augmented) GPT model to
    the per-rank state_dict of a tensor-parallel GPT model.

    Returns a NEW dict containing only the shards owned by `rank` out of
    `world_size` ranks:

    - column-parallel weights (embeddings, lm_head, Wqkv, mlp.fc1 and the
      routers' output projections) are split along dim 0;
    - row-parallel weights (out_proj, mlp.fc2, position embeddings) are
      split along the last dim;
    - biases of row-parallel layers are kept only on rank 0 (they are added
      once after the all-reduce);
    - layer norms and the MLP router's input projection are replicated.

    The input `state_dict` is not modified.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    sharded_state_dict = {}

    def shard_first_dim(new, old, key):
        x = old[key]
        dim = x.shape[0] // world_size
        new[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(new, old, key):
        x = old[key]
        dim = x.shape[-1] // world_size
        new[key] = x[..., rank * dim : (rank + 1) * dim]

    def shard_qkv_headdim(new, old, key):
        # Wqkv packs (q, k, v) along dim 0; each of the three chunks must be
        # sharded independently. Equivalent to the einops pattern
        # "(three d) ... -> three d ..." followed by the inverse, but done
        # with plain reshapes so this function has no einops dependency.
        x = old[key]
        per_proj = x.shape[0] // 3
        x = x.reshape(3, per_proj, *x.shape[1:])
        dim = per_proj // world_size
        new[key] = x[:, rank * dim : (rank + 1) * dim].reshape(3 * dim, *x.shape[2:])

    shard_first_dim(sharded_state_dict, state_dict, "transformer.embeddings.word_embeddings.weight")

    if "lm_head.weight" in state_dict:
        shard_first_dim(sharded_state_dict, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        shard_last_dim(sharded_state_dict, state_dict, "transformer.embeddings.position_embeddings.weight")

    for i in range(config.num_hidden_layers):
        # attention
        shard_qkv_headdim(sharded_state_dict, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(sharded_state_dict, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(sharded_state_dict, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight")

        # mlp
        shard_first_dim(sharded_state_dict, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        shard_first_dim(sharded_state_dict, state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(sharded_state_dict, state_dict, f"transformer.layers.{i}.mlp.fc2.weight")

        # Row-parallel biases live only on rank 0 (applied once post all-reduce).
        if rank == 0:
            sharded_state_dict[f"transformer.layers.{i}.mlp.fc2.bias"] = state_dict[f"transformer.layers.{i}.mlp.fc2.bias"]
            sharded_state_dict[f"transformer.layers.{i}.mixer.out_proj.bias"] = state_dict[f"transformer.layers.{i}.mixer.out_proj.bias"]

        # Layer norms are replicated on every rank.
        sharded_state_dict[f"transformer.layers.{i}.norm1.weight"] = state_dict[f"transformer.layers.{i}.norm1.weight"]
        sharded_state_dict[f"transformer.layers.{i}.norm1.bias"] = state_dict[f"transformer.layers.{i}.norm1.bias"]
        sharded_state_dict[f"transformer.layers.{i}.norm2.weight"] = state_dict[f"transformer.layers.{i}.norm2.weight"]
        sharded_state_dict[f"transformer.layers.{i}.norm2.bias"] = state_dict[f"transformer.layers.{i}.norm2.bias"]

        # routers

        # mlp router: input projection replicated, output projection
        # column-parallel (one logit per local neuron).
        sharded_state_dict[f"transformer.layers.{i}.mlp_router.fc1.weight"] = state_dict[f"transformer.layers.{i}.mlp_router.fc1.weight"]
        shard_first_dim(sharded_state_dict, state_dict, f"transformer.layers.{i}.mlp_router.fc2.weight")

        # mha router: column-parallel (one logit per local head).
        shard_first_dim(sharded_state_dict, state_dict, f"transformer.layers.{i}.mha_router.linear1.weight")
        shard_first_dim(sharded_state_dict, state_dict, f"transformer.layers.{i}.mha_router.linear1.bias")

    sharded_state_dict["transformer.ln_f.weight"] = state_dict["transformer.ln_f.weight"]
    sharded_state_dict["transformer.ln_f.bias"] = state_dict["transformer.ln_f.bias"]

    return sharded_state_dict
|
| 98 |
+
|
| 99 |
+
'''
|
| 100 |
+
def shard_state_dict_tp(state_dict, config, world_size, rank):
|
| 101 |
+
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
|
| 102 |
+
with tensor parallel.
|
| 103 |
+
"""
|
| 104 |
+
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
| 105 |
+
vocab_size = (
|
| 106 |
+
math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
| 107 |
+
)
|
| 108 |
+
assert vocab_size % world_size == 0
|
| 109 |
+
assert config.hidden_size % world_size == 0
|
| 110 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
| 111 |
+
assert inner_dim % world_size == 0
|
| 112 |
+
|
| 113 |
+
def shard_first_dim(state_dict, key):
|
| 114 |
+
x = state_dict[key]
|
| 115 |
+
dim = x.shape[0] // world_size
|
| 116 |
+
state_dict[key] = x[rank * dim : (rank + 1) * dim]
|
| 117 |
+
|
| 118 |
+
def shard_last_dim(state_dict, key):
|
| 119 |
+
x = state_dict[key]
|
| 120 |
+
dim = x.shape[-1] // world_size
|
| 121 |
+
state_dict[key] = x[..., rank * dim : (rank + 1) * dim]
|
| 122 |
+
|
| 123 |
+
def shard_qkv_headdim(state_dict, key):
|
| 124 |
+
x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
|
| 125 |
+
dim = x.shape[1] // world_size
|
| 126 |
+
state_dict[key] = rearrange(
|
| 127 |
+
x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..."
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
|
| 131 |
+
if "lm_head.weight" in state_dict:
|
| 132 |
+
shard_first_dim(state_dict, "lm_head.weight")
|
| 133 |
+
if "transformer.embeddings.position_embeddings.weight" in state_dict:
|
| 134 |
+
shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
|
| 135 |
+
for i in range(config.num_hidden_layers):
|
| 136 |
+
shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
|
| 137 |
+
shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
|
| 138 |
+
shard_last_dim(state_dict, f"transformer.layers.{i}.mixer.out_proj.weight")
|
| 139 |
+
if rank != 0:
|
| 140 |
+
state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias")
|
| 141 |
+
shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
|
| 142 |
+
shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
|
| 143 |
+
shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
|
| 144 |
+
if rank != 0:
|
| 145 |
+
state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias")
|
| 146 |
+
return state_dict
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
'''
|
| 150 |
+
|
| 151 |
+
def build_sparse_opt(args, model_name, mlp_ckpt_dir, attn_ckpt_dir, device=None, dtype=torch.float16, process_group=None, world_size=None, rank=None):
    """Build a router-augmented (sparse) OPT model and load its weights.

    Args:
        args: namespace providing ``attn_topk`` (fraction of heads the
            attention routers keep per layer).
        model_name: HuggingFace model id used for both config and weights.
        mlp_ckpt_dir, attn_ckpt_dir: router checkpoint directories. If
            either is None the routers are left randomly initialized and
            the pretrained weights are loaded with ``strict=False``.
        device, dtype: placement/precision of the constructed model.
        process_group, world_size, rank: tensor-parallel settings; when
            ``process_group`` is given the state dict is sharded per rank.

    Returns:
        A ``GPTLMHeadModelSparse`` with pretrained (and, if available,
        router) weights loaded.
    """
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))

    # Fused / flash kernels are CUDA-only; disable them when building on CPU.
    if device in ('cpu', torch.device('cpu')):
        config.fused_mlp = False
        config.fused_dropout_add_ln = False
        config.use_flash_attn = False
        config.fused_bias_fc = False
    else:
        config.fused_mlp = True
        config.fused_dropout_add_ln = True
        config.use_flash_attn = True
        config.fused_bias_fc = True
    config.sequence_parallel = False

    # Only prenorm supports residual_in_fp32.
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.pad_vocab_size_multiple = 8
    config.mlp_sparse = True
    config.att_sparse = True

    config.use_heuristic = True
    if config.use_heuristic:
        print("Using pre-compiled heuristic")
    else:
        print("Compiling new heuristic during runtime")

    spconfig = SparseConfig()
    spconfig.mlp_act_th = 0.5  # sets the threshold for the MLP routers for all layers
    spconfig.attn_topk = args.attn_topk  # sets the topk for the attention routers for all layers

    # build model  (message typo "Bulding" fixed)
    print("Building Model with sparse routers")
    model_sparse = GPTLMHeadModelSparse(config=config, sp_config=spconfig, process_group=process_group, device=device, dtype=dtype)

    # load pretrained weights into the sparse model
    state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
    state_dict = remap_state_dict_hf_opt(state_dict, config)

    # load the routers into the model
    if mlp_ckpt_dir is not None and attn_ckpt_dir is not None:
        mlp_router_state_dict = create_mlp_router_state_dict(mlp_ckpt_dir)
        attn_router_state_dict = create_attn_router_state_dict(attn_ckpt_dir)

        # merge the state dict (router keys are disjoint from the base keys)
        merged_state_dict = {**state_dict, **mlp_router_state_dict, **attn_router_state_dict}

        if process_group is not None:
            merged_state_dict = shard_state_dict_tp(merged_state_dict, config, world_size, rank)

        model_sparse.load_state_dict(merged_state_dict, strict=True)
    else:
        # No router checkpoints: routers stay randomly initialized, so the
        # load must be non-strict.
        if process_group is not None:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        model_sparse.load_state_dict(state_dict, strict=False)

    return model_sparse
|
| 211 |
+
|
| 212 |
+
def build_dense_opt(model_name, device=None, dtype=torch.float16, process_group=None, world_size=None, rank=None):
    """Build a dense (non-sparse) OPT model with fused / flash kernels.

    Bug fix: the original body re-assigned ``dtype = torch.float16``,
    silently ignoring the caller's ``dtype`` argument. The parameter is now
    honored; the default (float16) is unchanged, so existing callers that
    relied on the default see identical behavior.

    Args:
        model_name: HuggingFace model id.
        device, dtype: placement/precision of the constructed model.
        process_group, world_size, rank: tensor-parallel settings forwarded
            to ``GPTLMHeadModel.from_pretrained``.

    Returns:
        A ``GPTLMHeadModel`` loaded with pretrained weights.
    """
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    # config.fused_dropout_add_ln = True
    config.sequence_parallel = False
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.pad_vocab_size_multiple = 8

    # build model  (message typo "Bulding" fixed)
    print("Building Dense Model")
    model = GPTLMHeadModel.from_pretrained(model_name, config, process_group=process_group, world_size=world_size, rank=rank, device=device, dtype=dtype)

    return model
|
HybridTensor/modules/SelectiveBlock.py
ADDED
|
@@ -0,0 +1,960 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torchvision.ops import StochasticDepth
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
|
| 12 |
+
except ImportError:
|
| 13 |
+
layer_norm_fn, RMSNorm = None, None
|
| 14 |
+
|
| 15 |
+
class SelectBlock(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
dim,
|
| 19 |
+
mixer_cls=None,
|
| 20 |
+
mlp_cls=None,
|
| 21 |
+
mlp_router=None,
|
| 22 |
+
mha_router=None,
|
| 23 |
+
norm_cls=nn.LayerNorm,
|
| 24 |
+
dropout_cls=nn.Dropout,
|
| 25 |
+
prenorm=True,
|
| 26 |
+
resid_dropout1=0.0,
|
| 27 |
+
resid_dropout2=0.0,
|
| 28 |
+
drop_path1=0.0,
|
| 29 |
+
drop_path2=0.0,
|
| 30 |
+
fused_dropout_add_ln=False,
|
| 31 |
+
return_residual=False,
|
| 32 |
+
residual_in_fp32=False,
|
| 33 |
+
sequence_parallel=False,
|
| 34 |
+
mark_shared_params=False,
|
| 35 |
+
):
|
| 36 |
+
"""
|
| 37 |
+
For prenorm=True, this Block has a slightly different structure compared to a regular
|
| 38 |
+
prenorm Transformer block.
|
| 39 |
+
|
| 40 |
+
The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
|
| 41 |
+
Here we do: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, etc.
|
| 42 |
+
|
| 43 |
+
If you want to do concurrency with CUDA graphs, your shapes must remain fixed
|
| 44 |
+
(batch_size, seq_len, etc.) across captures and replays. Also avoid any operations
|
| 45 |
+
that cause dynamic shape changes or memory allocations.
|
| 46 |
+
"""
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.prenorm = prenorm
|
| 49 |
+
self.fused_dropout_add_ln = fused_dropout_add_ln
|
| 50 |
+
self.return_residual = return_residual
|
| 51 |
+
self.residual_in_fp32 = residual_in_fp32
|
| 52 |
+
if self.residual_in_fp32:
|
| 53 |
+
assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
|
| 54 |
+
|
| 55 |
+
assert mixer_cls is not None and mlp_cls is not None, (
|
| 56 |
+
"mixer_cls and mlp_cls cannot be None in SelectBlock"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# MHA & MLP submodules
|
| 60 |
+
self.mixer = mixer_cls(dim)
|
| 61 |
+
self.dropout1 = dropout_cls(resid_dropout1)
|
| 62 |
+
self.drop_path1 = StochasticDepth(drop_path1, mode="row")
|
| 63 |
+
self.norm1 = norm_cls(dim)
|
| 64 |
+
self.mlp = mlp_cls(dim)
|
| 65 |
+
self.total_neurons = self.mlp.fc1.weight.shape[0]
|
| 66 |
+
|
| 67 |
+
# Routers
|
| 68 |
+
if mlp_router is not None:
|
| 69 |
+
self.mlp_router = mlp_router(dim)
|
| 70 |
+
self.skip_attn_router = False
|
| 71 |
+
else:
|
| 72 |
+
self.mlp_router = None
|
| 73 |
+
self.skip_attn_router = True
|
| 74 |
+
|
| 75 |
+
if mha_router is not None:
|
| 76 |
+
self.mha_router = mha_router(dim)
|
| 77 |
+
else:
|
| 78 |
+
self.mha_router = None
|
| 79 |
+
|
| 80 |
+
if not isinstance(self.mlp, nn.Identity):
|
| 81 |
+
self.dropout2 = dropout_cls(resid_dropout2)
|
| 82 |
+
self.drop_path2 = StochasticDepth(drop_path2, mode="row")
|
| 83 |
+
self.norm2 = norm_cls(dim)
|
| 84 |
+
|
| 85 |
+
if self.fused_dropout_add_ln:
|
| 86 |
+
assert layer_norm_fn is not None, "Triton layer_norm_fn not installed"
|
| 87 |
+
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(self.dropout1, nn.Dropout)
|
| 88 |
+
|
| 89 |
+
# Mark the norm parameters for sequence parallel / shared params if needed
|
| 90 |
+
if sequence_parallel:
|
| 91 |
+
for p in self.norm1.parameters():
|
| 92 |
+
p._sequence_parallel = True
|
| 93 |
+
if hasattr(self, "norm2"):
|
| 94 |
+
for p in self.norm2.parameters():
|
| 95 |
+
p._sequence_parallel = True
|
| 96 |
+
if mark_shared_params:
|
| 97 |
+
for p in self.norm1.parameters():
|
| 98 |
+
p._shared_params = True
|
| 99 |
+
if hasattr(self, "norm2"):
|
| 100 |
+
for p in self.norm2.parameters():
|
| 101 |
+
p._shared_params = True
|
| 102 |
+
|
| 103 |
+
self.mlp_topk = None
|
| 104 |
+
self.skip_mlp_router = False
|
| 105 |
+
self.skip_attn_router = False
|
| 106 |
+
|
| 107 |
+
# We'll use an extra stream for concurrency
|
| 108 |
+
self.sparse_stream = torch.cuda.Stream(device="cuda", priority=0)
|
| 109 |
+
self.main_stream = torch.cuda.Stream(device="cuda", priority=-5)
|
| 110 |
+
# We'll record events to coordinate concurrency
|
| 111 |
+
self.mha_event = torch.cuda.Event(enable_timing=False, blocking=False)
|
| 112 |
+
self.mlp_event = torch.cuda.Event(enable_timing=False, blocking=False)
|
| 113 |
+
|
| 114 |
+
self.use_tensor_parallel = mark_shared_params
|
| 115 |
+
|
| 116 |
+
if self.use_tensor_parallel:
|
| 117 |
+
# save the stream and events in the mixer and mlp classes
|
| 118 |
+
self.mlp.router = self.mlp_router
|
| 119 |
+
self.mixer.router = self.mha_router
|
| 120 |
+
|
| 121 |
+
self.mlp_topk_layers = None # this will be a dictionary of layer_idx -> topk value
|
| 122 |
+
self.attn_topk_layers = None # this will be a dictionary of layer_idx -> topk value
|
| 123 |
+
|
| 124 |
+
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate KV-cache allocation for decoding to the attention mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
| 126 |
+
|
| 127 |
+
def prefill_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_kwargs=None, mixer_subset=None):
|
| 128 |
+
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
|
| 129 |
+
|
| 130 |
+
if mixer_subset is not None:
|
| 131 |
+
residual = residual[:, mixer_subset]
|
| 132 |
+
|
| 133 |
+
if not isinstance(self.mlp, nn.Identity):
|
| 134 |
+
if not self.fused_dropout_add_ln:
|
| 135 |
+
dropped = self.drop_path2(self.dropout2(hidden_states))
|
| 136 |
+
if dropped.shape != residual.shape:
|
| 137 |
+
dropped = dropped.view(residual.shape)
|
| 138 |
+
residual = (dropped + residual) if residual is not None else dropped
|
| 139 |
+
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
| 140 |
+
if self.residual_in_fp32:
|
| 141 |
+
residual = residual.to(torch.float32)
|
| 142 |
+
else:
|
| 143 |
+
if self.drop_path2.p == 0 or not self.training:
|
| 144 |
+
rowscale2 = None
|
| 145 |
+
else:
|
| 146 |
+
rowscale2 = self.drop_path2(
|
| 147 |
+
torch.ones(
|
| 148 |
+
hidden_states.shape[:-1],
|
| 149 |
+
device=hidden_states.device,
|
| 150 |
+
dtype=hidden_states.dtype,
|
| 151 |
+
)
|
| 152 |
+
)
|
| 153 |
+
if hidden_states.shape != residual.shape:
|
| 154 |
+
hidden_states = hidden_states.view(residual.shape)
|
| 155 |
+
hidden_states, residual = layer_norm_fn(
|
| 156 |
+
hidden_states,
|
| 157 |
+
self.norm2.weight,
|
| 158 |
+
self.norm2.bias,
|
| 159 |
+
residual=residual,
|
| 160 |
+
eps=self.norm2.eps,
|
| 161 |
+
dropout_p=self.dropout2.p if self.training else 0.0,
|
| 162 |
+
rowscale=rowscale2,
|
| 163 |
+
prenorm=True,
|
| 164 |
+
residual_in_fp32=self.residual_in_fp32,
|
| 165 |
+
is_rms_norm=isinstance(self.norm2, RMSNorm),
|
| 166 |
+
)
|
| 167 |
+
hidden_states = self.mlp(hidden_states)
|
| 168 |
+
return hidden_states, residual
|
| 169 |
+
|
| 170 |
+
    def decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
        """Single-GPU decode step with concurrent MLP routing.

        Runs the MLP router on ``sparse_stream`` while MHA (head selection +
        attention) runs on ``main_stream``; the two are joined with CUDA
        events before the sparse MLP consumes ``index_vec``.

        Args:
            hidden_states (Tensor): decode-step activations, seq_len == 1
                (squeezed to (batch, dim) for the routers — assumed shape,
                confirm against caller).
            residual (Optional[Tensor]): running residual from the previous block.
            mixer_subset: optional residual subset indices (cross-attn style).
            mixer_kwargs: forwarded to the attention mixer.
        """
        curr_stream = torch.cuda.current_stream()

        # We want to run MHA & mlp_router in parallel on different streams
        router_inputs = hidden_states.squeeze(1)  # shape (batch_size, dim)
        # Both side streams must observe all work already queued on the caller's stream.
        self.main_stream.wait_stream(curr_stream)
        self.sparse_stream.wait_stream(curr_stream)
        main_stream = self.main_stream

        # if mlp_topk > th * total_neurons, skip mlp router

        # if self.mlp_topk > 0.8 * self.total_neurons:
        #     self.skip_mlp_router = True
        # else:
        #     self.skip_mlp_router = False

        # [Sparse stream] mlp_router: pick top-k neuron indices for this batch.
        if not self.skip_mlp_router:
            with torch.cuda.stream(self.sparse_stream):
                index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk=self.mlp_topk)
                self.sparse_stream.record_event(self.mlp_event)

        # [Main stream] MHA with router-selected heads.
        with torch.cuda.stream(main_stream):
            batch_head_idx = self.mha_router._select_heads(router_inputs)
            hidden_states = self.mixer(
                hidden_states,
                batch_head_idx=batch_head_idx,
                **mixer_kwargs
            )

            main_stream.record_event(self.mha_event)

        # Now we unify after both are done, then do the next steps
        with torch.cuda.stream(main_stream):
            # Wait on router & MHA
            curr_stream.wait_stream(main_stream)
            main_stream.wait_event(self.mha_event)

            # normal residual / layernorm
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]

            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    # Unfused path: dropout -> add residual -> norm2.
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(
                        residual.to(dtype=self.norm2.weight.dtype)
                    )
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    # Fused Triton path: dropout+add+norm in one kernel.
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    if hidden_states.shape != residual.shape:
                        hidden_states = hidden_states.view(residual.shape)
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )

                # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size)
                if self.skip_mlp_router:
                    # Routing skipped: dense MLP.
                    hidden_states = self.mlp(hidden_states, index_vec=None)
                else:
                    # Join with the router stream before consuming index_vec.
                    curr_stream.wait_stream(self.sparse_stream)
                    main_stream.wait_event(self.mlp_event)
                    hidden_states = self.mlp(hidden_states, index_vec=index_vec)
        # Caller's stream must not proceed until both side streams are drained.
        curr_stream.wait_stream(main_stream)
        curr_stream.wait_stream(self.sparse_stream)

        return hidden_states, residual
|
| 266 |
+
|
| 267 |
+
    def tp_decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
        """Tensor-parallel decode step.

        Head selection is done synchronously on the current stream (TP ranks
        must agree before the all-reduce inside the mixer); only the MLP
        router runs asynchronously on ``sparse_stream``.
        """
        curr_stream = torch.cuda.current_stream()
        self.sparse_stream.wait_stream(curr_stream)
        # self.main_stream.wait_stream(curr_stream)

        router_inputs = hidden_states.squeeze(1)  # shape (batch_size, dim)

        # Routing more than 80% of neurons costs more than it saves -> run dense.
        if self.mlp_topk > 0.8 * self.total_neurons:
            self.skip_mlp_router = True
        else:
            self.skip_mlp_router = False

        # attention router is synchronous
        batch_head_idx = self.mha_router._select_heads(router_inputs)

        # mlp router is asynchronous
        if not self.skip_mlp_router:
            with torch.cuda.stream(self.sparse_stream):
                index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk=self.mlp_topk)
                self.sparse_stream.record_event(self.mlp_event)

        hidden_states = self.mixer(hidden_states, **mixer_kwargs, batch_head_idx=batch_head_idx)

        if mixer_subset is not None:
            residual = residual[:, mixer_subset]

        if not isinstance(self.mlp, nn.Identity):
            if not self.fused_dropout_add_ln:
                # Unfused path: dropout -> add residual -> norm2.
                dropped = self.drop_path2(self.dropout2(hidden_states))
                if dropped.shape != residual.shape:
                    dropped = dropped.view(residual.shape)
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                # Fused Triton path: dropout+add+norm in one kernel.
                if self.drop_path2.p == 0 or not self.training:
                    rowscale2 = None
                else:
                    rowscale2 = self.drop_path2(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                if hidden_states.shape != residual.shape:
                    hidden_states = hidden_states.view(residual.shape)
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    eps=self.norm2.eps,
                    dropout_p=self.dropout2.p if self.training else 0.0,
                    rowscale=rowscale2,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )

            # curr_stream.wait_stream(self.sparse_stream)
            if self.skip_mlp_router:
                hidden_states = self.mlp(hidden_states, index_vec=None)
            else:
                # Block the current stream until the router result is ready.
                curr_stream.wait_event(self.mlp_event)
                hidden_states = self.mlp(hidden_states, index_vec=index_vec)

        return hidden_states, residual
|
| 341 |
+
|
| 342 |
+
    def attn_sparse_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
        """Decode step with a sparse attention router only; the MLP runs dense.

        No side streams are used — everything runs on the current stream.
        """
        # Routers operate on (batch, dim); decode step has seq_len == 1.
        router_inputs = hidden_states.squeeze(1)  # shape (batch_size, dim)

        batch_head_idx = self.mha_router._select_heads(router_inputs)

        # print(f"hidden_states shape: {hidden_states.shape}")
        # print(f"hidden states: {hidden_states}")
        hidden_states = self.mixer(hidden_states, batch_head_idx=batch_head_idx, **mixer_kwargs)

        # normal residual / layernorm
        if mixer_subset is not None:
            residual = residual[:, mixer_subset]

        if not isinstance(self.mlp, nn.Identity):
            if not self.fused_dropout_add_ln:
                # Unfused path: dropout -> add residual -> norm2.
                dropped = self.drop_path2(self.dropout2(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm2(
                    residual.to(dtype=self.norm2.weight.dtype)
                )
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                # Fused Triton path: dropout+add+norm in one kernel.
                if self.drop_path2.p == 0 or not self.training:
                    rowscale2 = None
                else:
                    rowscale2 = self.drop_path2(
                        torch.ones(hidden_states.shape[:-1], device=hidden_states.device, dtype=hidden_states.dtype,)
                    )
                if hidden_states.shape != residual.shape:
                    hidden_states = hidden_states.view(residual.shape)
                hidden_states, residual = layer_norm_fn(hidden_states, self.norm2.weight, self.norm2.bias, residual=residual,
                                                        eps=self.norm2.eps, dropout_p=self.dropout2.p if self.training else 0.0,
                                                        rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32,
                                                        is_rms_norm=isinstance(self.norm2, RMSNorm),)

            # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size)
            # Dense MLP — no neuron routing in this path.
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual
|
| 387 |
+
|
| 388 |
+
def mlp_sparse_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
    """Single-GPU decode forward that overlaps the MLP neuron router with MHA.

    The MLP top-k router runs on ``self.sparse_stream`` while full (non-routed)
    attention runs on ``self.main_stream``; CUDA events (``self.mlp_event`` /
    ``self.mha_event``) order the two so the sparse MLP only consumes the
    router's index vector after both streams are done.

    Args:
        hidden_states (Tensor): decode-step activations; squeezed on dim 1 to
            build the router input, so assumes shape (batch, 1, dim) — TODO confirm.
        residual (Optional[Tensor], optional): running residual stream from the
            previous layer. Defaults to None.
        mixer_subset (_type_, optional): optional index subset applied to the
            residual before the MLP half of the block. Defaults to None.
        mixer_kwargs: keyword arguments forwarded to ``self.mixer``; must be a
            dict (unpacked with **), caller guarantees it is not None.

    Returns:
        tuple: updated ``(hidden_states, residual)``.
    """
    curr_stream = torch.cuda.current_stream()

    # We want to run MHA & mlp_router in parallel on different streams
    router_inputs = hidden_states.squeeze(1)  # shape (batch_size, dim)
    # Both side streams must see all work already queued on the caller's stream.
    self.main_stream.wait_stream(curr_stream)
    self.sparse_stream.wait_stream(curr_stream)
    main_stream = self.main_stream

    # if mlp_topk > th * total_neurons, skip mlp router
    # (routing overhead isn't worth it when nearly all neurons fire anyway;
    #  threshold is hard-coded at 80% of total neurons)
    if self.mlp_topk > 0.8 * self.total_neurons:
        self.skip_mlp_router = True
    else:
        self.skip_mlp_router = False

    # [Sparse stream] mlp_router — select the top-k MLP neurons for this batch.
    if not self.skip_mlp_router:
        with torch.cuda.stream(self.sparse_stream):
            index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk)
            self.sparse_stream.record_event(self.mlp_event)

    # [Main stream] MHA — note batch_head_idx=None: no attention-head routing
    # on this path (the mha_router call is intentionally commented out).
    with torch.cuda.stream(main_stream):
        # batch_head_idx = self.mha_router._select_heads(router_inputs)
        hidden_states = self.mixer(
            hidden_states,
            batch_head_idx=None,
            **mixer_kwargs
        )

        main_stream.record_event(self.mha_event)

    # Now we unify after both are done, then do the next steps
    with torch.cuda.stream(main_stream):
        # Wait on router & MHA
        curr_stream.wait_stream(main_stream)
        main_stream.wait_event(self.mha_event)

        # normal residual / layernorm
        if mixer_subset is not None:
            residual = residual[:, mixer_subset]

        if not isinstance(self.mlp, nn.Identity):
            if not self.fused_dropout_add_ln:
                # Unfused path: dropout -> add residual -> norm2 (prenorm order).
                dropped = self.drop_path2(self.dropout2(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm2(
                    residual.to(dtype=self.norm2.weight.dtype)
                )
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                # Fused Triton path: dropout+add+layernorm in one kernel.
                # rowscale2 carries the stochastic-depth per-row scaling; it is
                # only needed when drop_path is active during training.
                if self.drop_path2.p == 0 or not self.training:
                    rowscale2 = None
                else:
                    rowscale2 = self.drop_path2(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                if hidden_states.shape != residual.shape:
                    hidden_states = hidden_states.view(residual.shape)
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    eps=self.norm2.eps,
                    dropout_p=self.dropout2.p if self.training else 0.0,
                    rowscale=rowscale2,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )

        # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size)
        if self.skip_mlp_router:
            # Dense fallback: index_vec=None runs the full (unrouted) MLP.
            hidden_states = self.mlp(hidden_states, index_vec=None)
        else:
            # Block until the router's index vector is ready before the sparse MLP.
            curr_stream.wait_stream(self.sparse_stream)
            main_stream.wait_event(self.mlp_event)
            hidden_states = self.mlp(hidden_states, index_vec=index_vec)
    # Re-join the caller's stream so subsequent ops see the block's output.
    curr_stream.wait_stream(main_stream)
    curr_stream.wait_stream(self.sparse_stream)

    return hidden_states, residual
|
| 484 |
+
|
| 485 |
+
def forward(
    self,
    hidden_states: Tensor,
    residual: Optional[Tensor] = None,
    mixer_subset=None,
    mixer_kwargs=None,
    mlp_topk=None,
    attn_topk=None,
):
    """
    This forward pass includes concurrency logic in the decode branch.
    If you're capturing with a CUDA graph, the concurrency (two-stream usage) must be
    inside the captured region so that the replay reproduces the parallel streams.

    Args:
        hidden_states (Tensor): input activations for this block.
        residual (Optional[Tensor], optional): running residual stream; None on
            the first block. Defaults to None.
        mixer_subset: optional token subset forwarded to the mixer via
            ``mixer_kwargs["mixer_subset"]``. Defaults to None.
        mixer_kwargs: extra keyword arguments for the mixer; an
            ``"inference_params"`` entry is guaranteed to exist after this
            method normalizes it. Defaults to None.
        mlp_topk: when given, overrides ``self.mlp_topk`` (simulation knob).
        attn_topk: when given, overrides ``self.mha_router.topk`` (simulation knob).

    Returns:
        tuple: updated ``(hidden_states, residual)``.

    Raises:
        NotImplementedError: if ``self.prenorm`` is False (post-norm unsupported).
    """

    # simulation values — NOTE these mutate persistent instance state, so they
    # affect subsequent calls that omit the overrides.
    if mlp_topk is not None:
        self.mlp_topk = mlp_topk

    if attn_topk is not None:
        self.mha_router.topk = attn_topk

    if mixer_kwargs is None:
        mixer_kwargs = {"inference_params": None}
    else:
        # Ensure 'inference_params' key exists
        if "inference_params" not in mixer_kwargs:
            mixer_kwargs["inference_params"] = None

    if self.prenorm:
        # --- 1) Prenorm's dropout/add/layernorm
        if not self.fused_dropout_add_ln:
            dropped = self.drop_path1(self.dropout1(hidden_states))
            residual = (dropped + residual) if residual is not None else dropped
            hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            # fused dropout + add + layernorm
            # rowscale1 carries stochastic-depth per-row scaling; only needed
            # when drop_path is active during training.
            if self.drop_path1.p == 0 or not self.training:
                rowscale1 = None
            else:
                rowscale1 = self.drop_path1(
                    torch.ones(
                        hidden_states.shape[:-1],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                )
            if residual is not None and hidden_states.shape != residual.shape:
                hidden_states = hidden_states.view(residual.shape)
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                rowscale=rowscale1,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )

        if mixer_subset is not None:
            mixer_kwargs["mixer_subset"] = mixer_subset

        # Check if we are in the prefill or decode stage
        # (decode = an inference cache exists AND tokens were already consumed)
        prefill_stage = (
            mixer_kwargs["inference_params"] is None
            or mixer_kwargs["inference_params"].seqlen_offset == 0
        )

        if prefill_stage:
            # --- 2) Prefill stage (no concurrency): just do normal forward
            hidden_states, residual = self.prefill_forward(hidden_states, residual, mixer_kwargs, mixer_subset)

        else:
            # --- 3) Decode stage: dispatch on which routers exist and on
            # tensor parallelism; each branch keeps the same (hs, residual) contract.
            if self.mlp_router is None:
                # decode stage with only attention router, works with both single gpu and tensor parallel
                hidden_states, residual = self.attn_sparse_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
            else:
                if not self.use_tensor_parallel:
                    if self.mha_router is None:
                        # decode stage with mlp routers (opt models and single gpu)
                        hidden_states, residual = self.mlp_sparse_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
                    else:
                        # decode stage with mlp and attention routers (opt models and single gpu)
                        hidden_states, residual = self.decode_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
                else:
                    # uses both mlp and attention routers in tensor parallel
                    hidden_states, residual = self.tp_decode_forward(hidden_states, residual, mixer_subset, mixer_kwargs)

        return hidden_states, residual

    else:
        # post-norm architecture not implemented here
        raise NotImplementedError
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
# class SelectBlock(nn.Module):
|
| 587 |
+
# def __init__(
|
| 588 |
+
# self,
|
| 589 |
+
# dim,
|
| 590 |
+
# mixer_cls=None,
|
| 591 |
+
# mlp_cls=None,
|
| 592 |
+
# mlp_router=None,
|
| 593 |
+
# mha_router=None,
|
| 594 |
+
# norm_cls=nn.LayerNorm,
|
| 595 |
+
# dropout_cls=nn.Dropout,
|
| 596 |
+
# prenorm=True,
|
| 597 |
+
# resid_dropout1=0.0,
|
| 598 |
+
# resid_dropout2=0.0,
|
| 599 |
+
# drop_path1=0.0,
|
| 600 |
+
# drop_path2=0.0,
|
| 601 |
+
# fused_dropout_add_ln=False,
|
| 602 |
+
# return_residual=False,
|
| 603 |
+
# residual_in_fp32=False,
|
| 604 |
+
# sequence_parallel=False,
|
| 605 |
+
# mark_shared_params=False,
|
| 606 |
+
# ):
|
| 607 |
+
# """
|
| 608 |
+
# For prenorm=True, this Block has a slightly different structure compared to a regular
|
| 609 |
+
# prenorm Transformer block.
|
| 610 |
+
|
| 611 |
+
# The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
|
| 612 |
+
# Here we do: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, etc.
|
| 613 |
+
|
| 614 |
+
# If you want to do concurrency with CUDA graphs, your shapes must remain fixed
|
| 615 |
+
# (batch_size, seq_len, etc.) across captures and replays. Also avoid any operations
|
| 616 |
+
# that cause dynamic shape changes or memory allocations.
|
| 617 |
+
# """
|
| 618 |
+
# super().__init__()
|
| 619 |
+
# self.prenorm = prenorm
|
| 620 |
+
# self.fused_dropout_add_ln = fused_dropout_add_ln
|
| 621 |
+
# self.return_residual = return_residual
|
| 622 |
+
# self.residual_in_fp32 = residual_in_fp32
|
| 623 |
+
# if self.residual_in_fp32:
|
| 624 |
+
# assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
|
| 625 |
+
|
| 626 |
+
# assert mixer_cls is not None and mlp_cls is not None, (
|
| 627 |
+
# "mixer_cls and mlp_cls cannot be None in SelectBlock"
|
| 628 |
+
# )
|
| 629 |
+
|
| 630 |
+
# # MHA & MLP submodules
|
| 631 |
+
# self.mixer = mixer_cls(dim)
|
| 632 |
+
# self.dropout1 = dropout_cls(resid_dropout1)
|
| 633 |
+
# self.drop_path1 = StochasticDepth(drop_path1, mode="row")
|
| 634 |
+
# self.norm1 = norm_cls(dim)
|
| 635 |
+
# self.mlp = mlp_cls(dim)
|
| 636 |
+
|
| 637 |
+
# # Routers
|
| 638 |
+
# self.mlp_router = mlp_router(dim)
|
| 639 |
+
# self.mha_router = mha_router(dim)
|
| 640 |
+
|
| 641 |
+
# if not isinstance(self.mlp, nn.Identity):
|
| 642 |
+
# self.dropout2 = dropout_cls(resid_dropout2)
|
| 643 |
+
# self.drop_path2 = StochasticDepth(drop_path2, mode="row")
|
| 644 |
+
# self.norm2 = norm_cls(dim)
|
| 645 |
+
|
| 646 |
+
# if self.fused_dropout_add_ln:
|
| 647 |
+
# assert layer_norm_fn is not None, "Triton layer_norm_fn not installed"
|
| 648 |
+
# assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(self.dropout1, nn.Dropout)
|
| 649 |
+
|
| 650 |
+
# # Mark the norm parameters for sequence parallel / shared params if needed
|
| 651 |
+
# if sequence_parallel:
|
| 652 |
+
# for p in self.norm1.parameters():
|
| 653 |
+
# p._sequence_parallel = True
|
| 654 |
+
# if hasattr(self, "norm2"):
|
| 655 |
+
# for p in self.norm2.parameters():
|
| 656 |
+
# p._sequence_parallel = True
|
| 657 |
+
# if mark_shared_params:
|
| 658 |
+
# for p in self.norm1.parameters():
|
| 659 |
+
# p._shared_params = True
|
| 660 |
+
# if hasattr(self, "norm2"):
|
| 661 |
+
# for p in self.norm2.parameters():
|
| 662 |
+
# p._shared_params = True
|
| 663 |
+
|
| 664 |
+
# self.mlp_topk = None
|
| 665 |
+
# self.skip_mlp_router = False
|
| 666 |
+
# self.skip_attn_router = False
|
| 667 |
+
|
| 668 |
+
# # We'll use an extra stream for concurrency
|
| 669 |
+
# self.sparse_stream = torch.cuda.Stream(device="cuda", priority=0)
|
| 670 |
+
# self.main_stream = torch.cuda.Stream(device="cuda", priority=-5)
|
| 671 |
+
# # We'll record events to coordinate concurrency
|
| 672 |
+
# self.mha_event = torch.cuda.Event(enable_timing=False, blocking=False)
|
| 673 |
+
# self.mlp_event = torch.cuda.Event(enable_timing=False, blocking=False)
|
| 674 |
+
|
| 675 |
+
# self.use_tensor_parallel = mark_shared_params
|
| 676 |
+
|
| 677 |
+
# if self.use_tensor_parallel:
|
| 678 |
+
# # TODO: save the routers in the mixer and mlp classes
|
| 679 |
+
# # save the stream and events in the mixer and mlp classes
|
| 680 |
+
# pass
|
| 681 |
+
|
| 682 |
+
# def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 683 |
+
# return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
| 684 |
+
|
| 685 |
+
# def prefill_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_kwargs=None, mixer_subset=None):
|
| 686 |
+
# hidden_states = self.mixer(hidden_states, **mixer_kwargs)
|
| 687 |
+
|
| 688 |
+
# if mixer_subset is not None:
|
| 689 |
+
# residual = residual[:, mixer_subset]
|
| 690 |
+
|
| 691 |
+
# if not isinstance(self.mlp, nn.Identity):
|
| 692 |
+
# if not self.fused_dropout_add_ln:
|
| 693 |
+
# dropped = self.drop_path2(self.dropout2(hidden_states))
|
| 694 |
+
# if dropped.shape != residual.shape:
|
| 695 |
+
# dropped = dropped.view(residual.shape)
|
| 696 |
+
# residual = (dropped + residual) if residual is not None else dropped
|
| 697 |
+
# hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
| 698 |
+
# if self.residual_in_fp32:
|
| 699 |
+
# residual = residual.to(torch.float32)
|
| 700 |
+
# else:
|
| 701 |
+
# if self.drop_path2.p == 0 or not self.training:
|
| 702 |
+
# rowscale2 = None
|
| 703 |
+
# else:
|
| 704 |
+
# rowscale2 = self.drop_path2(
|
| 705 |
+
# torch.ones(
|
| 706 |
+
# hidden_states.shape[:-1],
|
| 707 |
+
# device=hidden_states.device,
|
| 708 |
+
# dtype=hidden_states.dtype,
|
| 709 |
+
# )
|
| 710 |
+
# )
|
| 711 |
+
# if hidden_states.shape != residual.shape:
|
| 712 |
+
# hidden_states = hidden_states.view(residual.shape)
|
| 713 |
+
# hidden_states, residual = layer_norm_fn(
|
| 714 |
+
# hidden_states,
|
| 715 |
+
# self.norm2.weight,
|
| 716 |
+
# self.norm2.bias,
|
| 717 |
+
# residual=residual,
|
| 718 |
+
# eps=self.norm2.eps,
|
| 719 |
+
# dropout_p=self.dropout2.p if self.training else 0.0,
|
| 720 |
+
# rowscale=rowscale2,
|
| 721 |
+
# prenorm=True,
|
| 722 |
+
# residual_in_fp32=self.residual_in_fp32,
|
| 723 |
+
# is_rms_norm=isinstance(self.norm2, RMSNorm),
|
| 724 |
+
# )
|
| 725 |
+
# hidden_states = self.mlp(hidden_states)
|
| 726 |
+
# return hidden_states, residual
|
| 727 |
+
|
| 728 |
+
# def decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
|
| 729 |
+
# """ Single GPU Decode Forward
|
| 730 |
+
|
| 731 |
+
# Args:
|
| 732 |
+
# hidden_states (Tensor): _description_
|
| 733 |
+
# residual (Optional[Tensor], optional): _description_. Defaults to None.
|
| 734 |
+
# mixer_subset (_type_, optional): _description_. Defaults to None.
|
| 735 |
+
# """
|
| 736 |
+
# curr_stream = torch.cuda.current_stream()
|
| 737 |
+
|
| 738 |
+
# # We want to run MHA & mlp_router in parallel on different streams
|
| 739 |
+
# router_inputs = hidden_states.squeeze(1) # shape (batch_size, dim)
|
| 740 |
+
# self.main_stream.wait_stream(curr_stream)
|
| 741 |
+
# self.sparse_stream.wait_stream(curr_stream)
|
| 742 |
+
|
| 743 |
+
# # We'll do MHA on the "main_stream" and the router on "sparse_stream"
|
| 744 |
+
# main_stream = self.main_stream
|
| 745 |
+
# # In a captured region, each 'with torch.cuda.stream(...)' block
|
| 746 |
+
# # is replayed in concurrency. The shape must remain consistent.
|
| 747 |
+
|
| 748 |
+
# # [Sparse stream] mlp_router
|
| 749 |
+
# if not self.skip_mlp_router:
|
| 750 |
+
# with torch.cuda.stream(self.sparse_stream): # <-- CHANGED
|
| 751 |
+
# # index_size, index_vec = self.mlp_router._select_neurons_cuda_safe(router_inputs) # need to fix this; make CUDA Graph safe
|
| 752 |
+
# # vec = self.mlp_router(router_inputs)
|
| 753 |
+
# index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk)
|
| 754 |
+
# self.sparse_stream.record_event(self.mlp_event)
|
| 755 |
+
|
| 756 |
+
# # [Main stream] MHA
|
| 757 |
+
# with torch.cuda.stream(main_stream): # <-- CHANGED
|
| 758 |
+
# batch_head_idx = self.mha_router._select_heads(router_inputs)
|
| 759 |
+
# hidden_states = self.mixer(
|
| 760 |
+
# hidden_states,
|
| 761 |
+
# batch_head_idx=batch_head_idx,
|
| 762 |
+
# # batch_head_idx=None,
|
| 763 |
+
# **mixer_kwargs
|
| 764 |
+
# )
|
| 765 |
+
# main_stream.record_event(self.mha_event)
|
| 766 |
+
|
| 767 |
+
# # Now we unify after both are done, then do the next steps
|
| 768 |
+
# with torch.cuda.stream(main_stream): # <-- CHANGED
|
| 769 |
+
# # Wait on router & MHA
|
| 770 |
+
# curr_stream.wait_stream(main_stream)
|
| 771 |
+
# main_stream.wait_event(self.mha_event)
|
| 772 |
+
|
| 773 |
+
# # normal residual / layernorm
|
| 774 |
+
# if mixer_subset is not None:
|
| 775 |
+
# residual = residual[:, mixer_subset]
|
| 776 |
+
|
| 777 |
+
# if not isinstance(self.mlp, nn.Identity):
|
| 778 |
+
# if not self.fused_dropout_add_ln:
|
| 779 |
+
# dropped = self.drop_path2(self.dropout2(hidden_states))
|
| 780 |
+
# residual = (dropped + residual) if residual is not None else dropped
|
| 781 |
+
# hidden_states = self.norm2(
|
| 782 |
+
# residual.to(dtype=self.norm2.weight.dtype)
|
| 783 |
+
# )
|
| 784 |
+
# if self.residual_in_fp32:
|
| 785 |
+
# residual = residual.to(torch.float32)
|
| 786 |
+
# else:
|
| 787 |
+
# if self.drop_path2.p == 0 or not self.training:
|
| 788 |
+
# rowscale2 = None
|
| 789 |
+
# else:
|
| 790 |
+
# rowscale2 = self.drop_path2(
|
| 791 |
+
# torch.ones(
|
| 792 |
+
# hidden_states.shape[:-1],
|
| 793 |
+
# device=hidden_states.device,
|
| 794 |
+
# dtype=hidden_states.dtype,
|
| 795 |
+
# )
|
| 796 |
+
# )
|
| 797 |
+
# if hidden_states.shape != residual.shape:
|
| 798 |
+
# hidden_states = hidden_states.view(residual.shape)
|
| 799 |
+
# hidden_states, residual = layer_norm_fn(
|
| 800 |
+
# hidden_states,
|
| 801 |
+
# self.norm2.weight,
|
| 802 |
+
# self.norm2.bias,
|
| 803 |
+
# residual=residual,
|
| 804 |
+
# eps=self.norm2.eps,
|
| 805 |
+
# dropout_p=self.dropout2.p if self.training else 0.0,
|
| 806 |
+
# rowscale=rowscale2,
|
| 807 |
+
# prenorm=True,
|
| 808 |
+
# residual_in_fp32=self.residual_in_fp32,
|
| 809 |
+
# is_rms_norm=isinstance(self.norm2, RMSNorm),
|
| 810 |
+
# )
|
| 811 |
+
|
| 812 |
+
# # Finally do MLP with the router's index vector
|
| 813 |
+
# curr_stream.wait_stream(self.sparse_stream)
|
| 814 |
+
# main_stream.wait_event(self.mlp_event)
|
| 815 |
+
|
| 816 |
+
# # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size)
|
| 817 |
+
# if self.skip_mlp_router:
|
| 818 |
+
# hidden_states = self.mlp(hidden_states, index_vec=None)
|
| 819 |
+
# else:
|
| 820 |
+
# hidden_states = self.mlp(hidden_states, index_vec=index_vec)
|
| 821 |
+
# curr_stream.wait_stream(main_stream)
|
| 822 |
+
# curr_stream.wait_stream(self.sparse_stream)
|
| 823 |
+
|
| 824 |
+
# return hidden_states, residual
|
| 825 |
+
|
| 826 |
+
# def tp_decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
|
| 827 |
+
# """
|
| 828 |
+
# Tensor Parallel Decode Forward
|
| 829 |
+
|
| 830 |
+
# Args:
|
| 831 |
+
# hidden_states (Tensor): _description_
|
| 832 |
+
# residual (Optional[Tensor], optional): _description_. Defaults to None.
|
| 833 |
+
# mixer_subset (_type_, optional): _description_. Defaults to None.
|
| 834 |
+
# """
|
| 835 |
+
# # TODO: need to add routing
|
| 836 |
+
|
| 837 |
+
# hidden_states = self.mixer(hidden_states, **mixer_kwargs)
|
| 838 |
+
|
| 839 |
+
# if mixer_subset is not None:
|
| 840 |
+
# residual = residual[:, mixer_subset]
|
| 841 |
+
|
| 842 |
+
# if not isinstance(self.mlp, nn.Identity):
|
| 843 |
+
# if not self.fused_dropout_add_ln:
|
| 844 |
+
# dropped = self.drop_path2(self.dropout2(hidden_states))
|
| 845 |
+
# if dropped.shape != residual.shape:
|
| 846 |
+
# dropped = dropped.view(residual.shape)
|
| 847 |
+
# residual = (dropped + residual) if residual is not None else dropped
|
| 848 |
+
# hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
| 849 |
+
# if self.residual_in_fp32:
|
| 850 |
+
# residual = residual.to(torch.float32)
|
| 851 |
+
# else:
|
| 852 |
+
# if self.drop_path2.p == 0 or not self.training:
|
| 853 |
+
# rowscale2 = None
|
| 854 |
+
# else:
|
| 855 |
+
# rowscale2 = self.drop_path2(
|
| 856 |
+
# torch.ones(
|
| 857 |
+
# hidden_states.shape[:-1],
|
| 858 |
+
# device=hidden_states.device,
|
| 859 |
+
# dtype=hidden_states.dtype,
|
| 860 |
+
# )
|
| 861 |
+
# )
|
| 862 |
+
# if hidden_states.shape != residual.shape:
|
| 863 |
+
# hidden_states = hidden_states.view(residual.shape)
|
| 864 |
+
# hidden_states, residual = layer_norm_fn(
|
| 865 |
+
# hidden_states,
|
| 866 |
+
# self.norm2.weight,
|
| 867 |
+
# self.norm2.bias,
|
| 868 |
+
# residual=residual,
|
| 869 |
+
# eps=self.norm2.eps,
|
| 870 |
+
# dropout_p=self.dropout2.p if self.training else 0.0,
|
| 871 |
+
# rowscale=rowscale2,
|
| 872 |
+
# prenorm=True,
|
| 873 |
+
# residual_in_fp32=self.residual_in_fp32,
|
| 874 |
+
# is_rms_norm=isinstance(self.norm2, RMSNorm),
|
| 875 |
+
# )
|
| 876 |
+
# hidden_states = self.mlp(hidden_states)
|
| 877 |
+
# return hidden_states, residual
|
| 878 |
+
|
| 879 |
+
# def forward(
|
| 880 |
+
# self,
|
| 881 |
+
# hidden_states: Tensor,
|
| 882 |
+
# residual: Optional[Tensor] = None,
|
| 883 |
+
# mixer_subset=None,
|
| 884 |
+
# mixer_kwargs=None,
|
| 885 |
+
# ):
|
| 886 |
+
# """
|
| 887 |
+
# This forward pass includes concurrency logic in the decode branch.
|
| 888 |
+
# If you're capturing with a CUDA graph, the concurrency (two-stream usage) must be
|
| 889 |
+
# inside the captured region so that the replay reproduces the parallel streams.
|
| 890 |
+
# """
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
# if mixer_kwargs is None:
|
| 894 |
+
# mixer_kwargs = {"inference_params": None}
|
| 895 |
+
# else:
|
| 896 |
+
# # Ensure 'inference_params' key exists
|
| 897 |
+
# if "inference_params" not in mixer_kwargs:
|
| 898 |
+
# mixer_kwargs["inference_params"] = None
|
| 899 |
+
|
| 900 |
+
# if self.prenorm:
|
| 901 |
+
# # --- 1) Prenorm’s dropout/add/layernorm
|
| 902 |
+
# if not self.fused_dropout_add_ln:
|
| 903 |
+
# dropped = self.drop_path1(self.dropout1(hidden_states))
|
| 904 |
+
# residual = (dropped + residual) if residual is not None else dropped
|
| 905 |
+
# hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
|
| 906 |
+
# if self.residual_in_fp32:
|
| 907 |
+
# residual = residual.to(torch.float32)
|
| 908 |
+
# else:
|
| 909 |
+
# # fused dropout + add + layernorm
|
| 910 |
+
# if self.drop_path1.p == 0 or not self.training:
|
| 911 |
+
# rowscale1 = None
|
| 912 |
+
# else:
|
| 913 |
+
# rowscale1 = self.drop_path1(
|
| 914 |
+
# torch.ones(
|
| 915 |
+
# hidden_states.shape[:-1],
|
| 916 |
+
# device=hidden_states.device,
|
| 917 |
+
# dtype=hidden_states.dtype,
|
| 918 |
+
# )
|
| 919 |
+
# )
|
| 920 |
+
# if residual is not None and hidden_states.shape != residual.shape:
|
| 921 |
+
# hidden_states = hidden_states.view(residual.shape)
|
| 922 |
+
# hidden_states, residual = layer_norm_fn(
|
| 923 |
+
# hidden_states,
|
| 924 |
+
# self.norm1.weight,
|
| 925 |
+
# self.norm1.bias,
|
| 926 |
+
# residual=residual,
|
| 927 |
+
# eps=self.norm1.eps,
|
| 928 |
+
# dropout_p=self.dropout1.p if self.training else 0.0,
|
| 929 |
+
# rowscale=rowscale1,
|
| 930 |
+
# prenorm=True,
|
| 931 |
+
# residual_in_fp32=self.residual_in_fp32,
|
| 932 |
+
# is_rms_norm=isinstance(self.norm1, RMSNorm),
|
| 933 |
+
# )
|
| 934 |
+
|
| 935 |
+
# if mixer_subset is not None:
|
| 936 |
+
# mixer_kwargs["mixer_subset"] = mixer_subset
|
| 937 |
+
|
| 938 |
+
# # Check if we are in the prefill or decode stage
|
| 939 |
+
# prefill_stage = (
|
| 940 |
+
# mixer_kwargs["inference_params"] is None
|
| 941 |
+
# or mixer_kwargs["inference_params"].seqlen_offset == 0
|
| 942 |
+
# )
|
| 943 |
+
|
| 944 |
+
# if prefill_stage:
|
| 945 |
+
# # --- 2) Prefill stage (no concurrency): just do normal forward
|
| 946 |
+
# hidden_states, residual = self.prefill_forward(hidden_states, residual, mixer_kwargs, mixer_subset)
|
| 947 |
+
|
| 948 |
+
# else:
|
| 949 |
+
# # # --- 3) Decode stage:
|
| 950 |
+
# if not self.use_tensor_parallel:
|
| 951 |
+
# hidden_states, residual = self.decode_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
|
| 952 |
+
# else:
|
| 953 |
+
# # routing is slightly different in tensor parallel; we overlap the router with allreduce
|
| 954 |
+
# hidden_states, residual = self.tp_decode_forward(hidden_states, residual, mixer_subset)
|
| 955 |
+
|
| 956 |
+
# return hidden_states, residual
|
| 957 |
+
|
| 958 |
+
# else:
|
| 959 |
+
# # post-norm architecture not implemented here
|
| 960 |
+
# raise NotImplementedError
|
HybridTensor/modules/SelectiveMHA.py
ADDED
|
@@ -0,0 +1,1579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from einops import rearrange, repeat
|
| 7 |
+
|
| 8 |
+
from flash_attn.utils.distributed import get_dim_for_local_rank
|
| 9 |
+
from flash_attn.utils.distributed import all_reduce
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from flash_attn import (
|
| 13 |
+
flash_attn_kvpacked_func,
|
| 14 |
+
flash_attn_qkvpacked_func,
|
| 15 |
+
flash_attn_varlen_kvpacked_func,
|
| 16 |
+
flash_attn_varlen_qkvpacked_func,
|
| 17 |
+
flash_attn_with_kvcache,
|
| 18 |
+
)
|
| 19 |
+
except ImportError:
|
| 20 |
+
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
|
| 21 |
+
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
|
| 22 |
+
flash_attn_with_kvcache = None
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear, fused_dense_func
|
| 26 |
+
except ImportError:
|
| 27 |
+
FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
from flash_attn.layers.rotary import RotaryEmbedding
|
| 31 |
+
except ImportError:
|
| 32 |
+
RotaryEmbedding = None
|
| 33 |
+
|
| 34 |
+
from flash_attn.modules.mha import SelfAttention, FlashSelfAttention, LinearResidual, FlashCrossAttention, CrossAttention
|
| 35 |
+
from flash_attn.modules.mha import get_alibi_slopes #, _update_kv_cache
|
| 36 |
+
from flash_attn.utils.generation import InferenceParams
|
| 37 |
+
|
| 38 |
+
# from HybridTensor.modules.references.mha_dejavu import ParallelTracker # use this in the full implementation
|
| 39 |
+
# from HybridTensor.modules.references.mha_dejavu import ParallelMHASparseAttMlp
|
| 40 |
+
# from HybridTensor.triton.references.attention_proj_sparse import qkv_proj_sparse
|
| 41 |
+
# from HybridTensor.triton.select_attn import select_attn
|
| 42 |
+
# from HybridTensor.triton.select_attn_64b_kernel import select_attn
|
| 43 |
+
from HybridTensor.triton.attn_interface import flash_attn_with_kvcache_triton
|
| 44 |
+
from HybridTensor.triton.select_attn_v1 import select_attn
|
| 45 |
+
from HybridTensor.utils.utils import arg_parser, generate_BH_index, generate_random_BH_index
|
| 46 |
+
from HybridTensor.utils.profiling import cuda_profiler
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class MHARouter(torch.nn.Module):
    """Scores attention heads for an input and selects the top-k most relevant ones.

    A single linear layer maps the hidden state to one score per attention head;
    `_select_heads` then keeps the highest-scoring fraction of heads.
    `low_rank_dim` is accepted for interface compatibility but currently unused.
    """

    def __init__(self, embed_dim, low_rank_dim=None, out_dim=None, top_k=0.5, device=None, dtype=None):
        super().__init__()
        self.model_dim = embed_dim
        self.num_heads = out_dim
        self.topk = top_k
        # One score per head: (.., embed_dim) -> (.., num_heads).
        self.linear1 = torch.nn.Linear(embed_dim, out_dim, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        """Return raw (unnormalized) per-head scores, shape (..., num_heads)."""
        return self.linear1(x)

    def _select_heads(self, x, topk=None):
        """Return indices of the top-k scoring heads along dim 1.

        `topk` is interpreted as a fraction of `num_heads`; when None, the
        fraction fixed at construction time (`self.topk`) is used.
        """
        fraction = self.topk if topk is None else topk
        k = int(fraction * self.num_heads)
        scores = self.forward(x)
        return torch.topk(scores, k, dim=1).indices
class ParallelMHARouter(torch.nn.Module):
    """Tensor-parallel head router.

    Like `MHARouter`, but the scoring projection is a `ColumnParallelLinear`,
    so each rank only produces scores for its local shard of the heads and
    `_select_heads` indexes into that local shard.
    `low_rank_dim` is accepted for interface compatibility but currently unused.
    """

    def __init__(self, embed_dim, low_rank_dim, out_dim, top_k, process_group, sequence_parallel=False, device=None, dtype=None):
        super().__init__()
        self.model_dim = embed_dim
        self.num_heads = out_dim
        self.topk = top_k
        # Heads are split evenly across the ranks of the process group.
        self.local_heads = out_dim // torch.distributed.get_world_size(process_group)

        self.linear1 = ColumnParallelLinear(
            embed_dim,
            out_dim,
            process_group,
            bias=True,
            sequence_parallel=sequence_parallel,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        """Return this rank's local head scores, shape (..., local_heads)."""
        return self.linear1(x)

    def _select_heads(self, x, topk=None):
        """Return indices (into the local head shard) of the top-k scoring heads.

        `topk` is interpreted as a fraction of `local_heads`; when None, the
        fraction fixed at construction time (`self.topk`) is used.
        """
        fraction = self.topk if topk is None else topk
        k = int(fraction * self.local_heads)
        scores = self.forward(x)
        return torch.topk(scores, k, dim=1).indices
def _update_kv_cache(kv, inference_params, layer_idx):
|
| 109 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
| 110 |
+
# Pre-allocate memory for key-values for inference.
|
| 111 |
+
num_heads, head_dim = kv.shape[-2:]
|
| 112 |
+
if layer_idx not in inference_params.key_value_memory_dict:
|
| 113 |
+
kv_cache = torch.empty(
|
| 114 |
+
inference_params.max_batch_size,
|
| 115 |
+
inference_params.max_seqlen,
|
| 116 |
+
2,
|
| 117 |
+
num_heads,
|
| 118 |
+
head_dim,
|
| 119 |
+
dtype=kv.dtype,
|
| 120 |
+
device=kv.device,
|
| 121 |
+
)
|
| 122 |
+
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
| 123 |
+
else:
|
| 124 |
+
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
| 125 |
+
# Adjust key and value for inference
|
| 126 |
+
batch_start = inference_params.batch_size_offset
|
| 127 |
+
batch_end = batch_start + kv.shape[0]
|
| 128 |
+
sequence_start = inference_params.seqlen_offset
|
| 129 |
+
sequence_end = sequence_start + kv.shape[1]
|
| 130 |
+
assert batch_end <= kv_cache.shape[0]
|
| 131 |
+
assert sequence_end <= kv_cache.shape[1]
|
| 132 |
+
assert kv_cache is not None
|
| 133 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
| 134 |
+
return kv_cache[batch_start:batch_end, :sequence_end, ...]
|
| 135 |
+
|
| 136 |
+
class SMHA(nn.Module):
    """Multi-head self-attention and cross-attention with Triton decode kernels
    + Selective Attention.

    Based on flash-attn's MHA module; during decode it can dispatch to a Triton
    kernel (`flash_attn_with_kvcache_triton`) that additionally accepts a
    `batch_head_idx` tensor selecting which heads to compute per batch element.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        use_triton=True,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        use_triton: default for whether decode-time attention uses the Triton
            kvcache kernel (can also be overridden per forward() call).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        self.use_triton = use_triton
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        # Packed projection sizes: q takes num_heads, k and v take num_heads_kv each.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        # When return_residual is requested, the QKV projection also returns its input.
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            # Optional depthwise conv over the sequence dimension (not supported
            # together with generation; see _update_kv_cache).
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate an empty KV cache tensor of shape
        (batch_size, max_seqlen, 2, num_heads_kv, head_dim) on the module's device."""
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        # Delegates to the module-level helper, keyed by this layer's index.
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)

        Only valid during decode (seqlen_offset > 0) with flash-attn available;
        the CUDA kernel itself writes kv into the cache and applies rotary.
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        # Per-sample lengths when available, otherwise a single shared offset.
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)

        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context


    def _update_kvcache_attention_triton(self, q, kv, inference_params, batch_head_idx=None):
        """
        The rotary embeddings have to be applied before calling this function. The KV cache is updated here.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        batch_head_idx: optional per-batch head selection passed through to the
            Triton kernel (Selective Head/Group Attention).
        """
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # Prefill (or no flash-attn): write kv to cache, then regular attention.
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            # Unlike _update_kvcache_attention, the cache is updated eagerly in
            # Python here (the Triton kernel receives None for the new k/v), and
            # it receives the cache sliced up to the current length.
            # NOTE(review): confirm the Triton kernel expects the already-updated,
            # length-sliced cache together with cache_seqlens = old offset.
            kv_cache = self._update_kv_cache(kv, inference_params)

            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)

            context = flash_attn_with_kvcache_triton(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                None,  # kv[:, :, 0] — cache already updated above
                None,  # kv[:, :, 1]
                rotary_cos=None,
                rotary_sin=None,
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                rotary_interleaved=False,  # self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
                alibi_slopes=alibi_slopes,
                batch_head_idx=batch_head_idx,
            )
            return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # Prefill (or no flash-attn): write kv to cache, then regular attention.
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            # Decode: the flash-attn kernel appends the new kv and attends in one call.
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        batch_head_idx=None,
        use_triton=True,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            batch_head_idx: (batch, num_heads). The index of the heads to be selected. Only applicable for Selective Head/Group Attention.
            use_triton: whether to use triton kernels for attention in decode. If False, use the original flash attention implementation.
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        # Argument-combination sanity checks.
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv
            # use_triton = self.use_triton if use_triton is None else use_triton
        # Route the sequence-length info to whichever inner attention is in use.
        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        # During generation, positions start at the current cache offset
        # (per-sample lengths when available).
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            # --- Standard self-attention (MHA) path: packed QKV projection. ---
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                # prefill stage (or no fused-rotary decode path available)
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    # Training / no-cache inference.
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    if use_triton:
                        context = self._update_kvcache_attention_triton(
                            qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx
                        )
                    else:
                        context = self._update_kvcache_attention(
                            qkv[:, :, 0], qkv[:, :, 1:], inference_params
                        )
            else:
                # decode stage
                if use_triton:
                    # Rotary must be applied in Python before the Triton kernel.
                    if self.rotary_emb_dim > 0:
                        qkv = self.rotary_emb(
                            qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                        )
                    context = self._update_kvcache_attention_triton(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx
                    )
                else:
                    # Fused rotary + cache update + attention in flash-attn CUDA.
                    context = self._apply_rotary_update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )

        else:  # cross-attention or MQA/GQA
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                # MQA/GQA: split the packed projection into q and kv parts.
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                # NOTE(review): kv is 5-D here (b s two hkv d); the "b s d -> b d s"
                # rearrange looks like it assumes the pre-split 3-D layout — confirm
                # this dwconv + GQA combination is ever exercised.
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                # prefill
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    if use_triton:
                        context = self._update_kvcache_attention_triton(
                            q, kv, inference_params, batch_head_idx
                        )
                    else:
                        context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                # decode
                if use_triton:
                    if self.rotary_emb_dim > 0:
                        q, kv = self.rotary_emb(
                            q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                        )
                    context = self._update_kvcache_attention_triton(
                        q, kv, inference_params, batch_head_idx
                    )
                else:
                    context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        # Merge heads and project back to embed_dim.
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
class ParallelSMHA(nn.Module):
|
| 581 |
+
"""Multi-head self-attention and cross-attention"""
|
| 582 |
+
|
| 583 |
+
def __init__(
|
| 584 |
+
self,
|
| 585 |
+
embed_dim,
|
| 586 |
+
num_heads,
|
| 587 |
+
process_group,
|
| 588 |
+
num_heads_kv=None,
|
| 589 |
+
qkv_proj_bias=True,
|
| 590 |
+
out_proj_bias=True,
|
| 591 |
+
dropout=0.0,
|
| 592 |
+
softmax_scale=None,
|
| 593 |
+
causal=False,
|
| 594 |
+
layer_idx=None,
|
| 595 |
+
rotary_emb_dim=0,
|
| 596 |
+
rotary_emb_base=10000.0,
|
| 597 |
+
rotary_emb_scale_base=None,
|
| 598 |
+
rotary_emb_interleaved=False,
|
| 599 |
+
use_alibi=False,
|
| 600 |
+
window_size=(-1, -1),
|
| 601 |
+
use_flash_attn=False,
|
| 602 |
+
checkpointing=False,
|
| 603 |
+
sequence_parallel=True,
|
| 604 |
+
device=None,
|
| 605 |
+
dtype=None,
|
| 606 |
+
) -> None:
|
| 607 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 608 |
+
super().__init__()
|
| 609 |
+
self.embed_dim = embed_dim
|
| 610 |
+
self.causal = causal
|
| 611 |
+
self.layer_idx = layer_idx
|
| 612 |
+
self.rotary_emb_dim = rotary_emb_dim
|
| 613 |
+
self.use_flash_attn = use_flash_attn
|
| 614 |
+
self.checkpointing = checkpointing
|
| 615 |
+
self.process_group = process_group
|
| 616 |
+
self.world_size = process_group.size()
|
| 617 |
+
self.local_rank = torch.distributed.get_rank(process_group)
|
| 618 |
+
|
| 619 |
+
self.num_heads = num_heads
|
| 620 |
+
assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
|
| 621 |
+
|
| 622 |
+
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
|
| 623 |
+
assert (
|
| 624 |
+
self.num_heads % self.num_heads_kv == 0
|
| 625 |
+
), "num_heads must be divisible by num_heads_kv"
|
| 626 |
+
|
| 627 |
+
self.num_heads_per_rank = get_dim_for_local_rank(
|
| 628 |
+
self.num_heads, self.world_size, self.local_rank
|
| 629 |
+
)
|
| 630 |
+
self.num_heads_kv_per_rank = get_dim_for_local_rank(
|
| 631 |
+
self.num_heads_kv, self.world_size, self.local_rank
|
| 632 |
+
)
|
| 633 |
+
self.head_dim = self.embed_dim // num_heads
|
| 634 |
+
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
| 635 |
+
|
| 636 |
+
if use_alibi:
|
| 637 |
+
assert use_flash_attn, "ALiBi code path requires flash_attn"
|
| 638 |
+
num_heads_local = math.ceil(self.num_heads / self.world_size)
|
| 639 |
+
alibi_slopes = torch.tensor(
|
| 640 |
+
get_alibi_slopes(num_heads)[
|
| 641 |
+
self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
|
| 642 |
+
],
|
| 643 |
+
device=device,
|
| 644 |
+
)
|
| 645 |
+
else:
|
| 646 |
+
alibi_slopes = None
|
| 647 |
+
if window_size != (-1, -1):
|
| 648 |
+
assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
|
| 649 |
+
|
| 650 |
+
if self.rotary_emb_dim > 0:
|
| 651 |
+
assert RotaryEmbedding is not None, "rotary_emb is not installed"
|
| 652 |
+
self.rotary_emb = RotaryEmbedding(
|
| 653 |
+
self.rotary_emb_dim,
|
| 654 |
+
base=rotary_emb_base,
|
| 655 |
+
scale_base=rotary_emb_scale_base,
|
| 656 |
+
interleaved=rotary_emb_interleaved,
|
| 657 |
+
device=device,
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
if ColumnParallelLinear is None or RowParallelLinear is None:
|
| 661 |
+
raise ImportError("fused_dense is not installed")
|
| 662 |
+
self.Wqkv = ColumnParallelLinear(
|
| 663 |
+
embed_dim,
|
| 664 |
+
qkv_dim,
|
| 665 |
+
process_group,
|
| 666 |
+
bias=qkv_proj_bias,
|
| 667 |
+
sequence_parallel=sequence_parallel,
|
| 668 |
+
multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
|
| 669 |
+
**factory_kwargs,
|
| 670 |
+
)
|
| 671 |
+
inner_attn_cls = (
|
| 672 |
+
partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
|
| 673 |
+
if use_flash_attn
|
| 674 |
+
else SelfAttention
|
| 675 |
+
)
|
| 676 |
+
inner_cross_attn_cls = (
|
| 677 |
+
partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
|
| 678 |
+
if use_flash_attn
|
| 679 |
+
else CrossAttention
|
| 680 |
+
)
|
| 681 |
+
self.inner_attn = inner_attn_cls(
|
| 682 |
+
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
|
| 683 |
+
)
|
| 684 |
+
self.inner_cross_attn = inner_cross_attn_cls(
|
| 685 |
+
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
|
| 686 |
+
)
|
| 687 |
+
self.out_proj = RowParallelLinear(
|
| 688 |
+
embed_dim,
|
| 689 |
+
embed_dim,
|
| 690 |
+
process_group,
|
| 691 |
+
bias=out_proj_bias,
|
| 692 |
+
sequence_parallel=sequence_parallel,
|
| 693 |
+
multiple_of=self.head_dim,
|
| 694 |
+
**factory_kwargs,
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
    """Allocate an empty per-layer KV cache for generation.

    Returns a tensor of shape (batch_size, max_seqlen, 2,
    num_heads_kv_per_rank, head_dim); dtype and device default to those of
    the output-projection weight.
    """
    proj_weight = self.out_proj.weight
    if dtype is None:
        dtype = proj_weight.dtype
    cache_shape = (
        batch_size,
        max_seqlen,
        2,  # one slot for K, one for V
        self.num_heads_kv_per_rank,
        self.head_dim,
    )
    return torch.empty(cache_shape, dtype=dtype, device=proj_weight.device)
|
| 709 |
+
|
| 710 |
+
def _update_kv_cache(self, kv, inference_params):
    """Append *kv* to this layer's slot of the inference-params cache.

    kv: (batch_size, seqlen, 2, nheads, head_dim); seqlen may be 1 at decode.
    """
    assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
    # Delegate to the module-level helper, keyed by this layer's index.
    return _update_kv_cache(kv, inference_params, self.layer_idx)
|
| 714 |
+
|
| 715 |
+
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
    """Fused decode path: rotary on Q/K, KV-cache update, and attention in a
    single flash_attn_with_kvcache call.

    q: (batch_size, seqlen_q, nheads, head_dim)
    kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
    """
    assert inference_params is not None and inference_params.seqlen_offset > 0
    assert self.use_flash_attn
    rotary_cos = rotary_sin = None
    rotary_interleaved = False
    if self.rotary_emb_dim > 0:
        assert self.rotary_emb.scale is None, "This code path does not support xPos"
        # Make sure the cos/sin tables cover the full generation length.
        self.rotary_emb._update_cos_sin_cache(
            inference_params.max_seqlen, device=q.device, dtype=q.dtype
        )
        rotary_cos = self.rotary_emb._cos_cached
        rotary_sin = self.rotary_emb._sin_cached
        rotary_interleaved = self.rotary_emb.interleaved
    batch = q.shape[0]
    kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
    lengths = inference_params.lengths_per_sample
    # Per-sample lengths when available, otherwise the shared scalar offset.
    cache_seqlens = lengths[:batch] if lengths is not None else inference_params.seqlen_offset
    return flash_attn_with_kvcache(
        q,
        kv_cache[:, :, 0],
        kv_cache[:, :, 1],
        kv[:, :, 0],
        kv[:, :, 1],
        rotary_cos=rotary_cos,
        rotary_sin=rotary_sin,
        cache_seqlens=cache_seqlens,
        softmax_scale=self.inner_cross_attn.softmax_scale,
        causal=self.inner_cross_attn.causal,
        rotary_interleaved=rotary_interleaved,
        alibi_slopes=getattr(self.inner_cross_attn, "alibi_slopes", None),
    )
|
| 754 |
+
|
| 755 |
+
def _update_kvcache_attention(self, q, kv, inference_params):
    """Write kv into the inference cache, then attend against the cache."""
    can_fuse = inference_params.seqlen_offset > 0 and self.use_flash_attn
    if not can_fuse:
        # Prefill (or flash-attn disabled): materialize the cache and run the
        # regular cross-attention.
        # TODO: this only uses seqlen_offset and not lengths_per_sample.
        updated_kv = self._update_kv_cache(kv, inference_params)
        return self.inner_cross_attn(q, updated_kv)
    batch = q.shape[0]
    kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
    lengths = inference_params.lengths_per_sample
    cache_seqlens = lengths[:batch] if lengths is not None else inference_params.seqlen_offset
    return flash_attn_with_kvcache(
        q,
        kv_cache[:, :, 0],
        kv_cache[:, :, 1],
        kv[:, :, 0],
        kv[:, :, 1],
        cache_seqlens=cache_seqlens,
        softmax_scale=self.inner_cross_attn.softmax_scale,
        causal=self.inner_cross_attn.causal,
        alibi_slopes=getattr(self.inner_cross_attn, "alibi_slopes", None),
    )
|
| 782 |
+
|
| 783 |
+
def forward(self, x, seqlen=None, inference_params=None, **kwargs):
    """
    Arguments:
        x: (batch, seqlen, hidden_dim) (hidden_dim = num heads * head dim) when
            seqlen is None. Otherwise x is (batch * seqlen, hidden_dim), so that
            when we split x under sequence parallelism we split the
            batch * seqlen dimension (in case batch is small).
    """
    qkv = self.Wqkv(x)
    if seqlen is not None:
        qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)

    if inference_params is None:
        seqlen_offset = 0
        rotary_max_seqlen = None
    elif inference_params.lengths_per_sample is not None:
        seqlen_offset = inference_params.lengths_per_sample
        rotary_max_seqlen = inference_params.max_seqlen
    else:
        seqlen_offset = inference_params.seqlen_offset
        rotary_max_seqlen = inference_params.max_seqlen

    # The fused rotary+kvcache+attention kernel is usable only at decode time
    # (seqlen_offset > 0), with flash-attn, and with a rotary dim that is a
    # nonzero multiple of 16.
    fused_decode = not (
        inference_params is None
        or inference_params.seqlen_offset == 0
        or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
        or not self.use_flash_attn
    )

    if self.num_heads_kv == self.num_heads:
        # Plain MHA: Q, K and V all carry num_heads heads.
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
        if fused_decode:
            context = self._apply_rotary_update_kvcache_attention(
                qkv[:, :, 0], qkv[:, :, 1:], inference_params
            )
        else:
            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(
                    qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )
            if inference_params is not None:
                context = self._update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
            elif self.checkpointing:
                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
            else:
                context = self.inner_attn(qkv, **kwargs)
    else:
        # GQA/MQA: the projection packs Q first, then the (smaller) K/V heads.
        q_width = self.num_heads_per_rank * self.head_dim
        q = rearrange(qkv[..., :q_width], "... (h d) -> ... h d", d=self.head_dim)
        kv = rearrange(
            qkv[..., q_width:], "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim
        )
        if fused_decode:
            context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        else:
            if self.rotary_emb_dim > 0:
                q, kv = self.rotary_emb(
                    q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )
            if inference_params is not None:
                context = self._update_kvcache_attention(q, kv, inference_params)
            elif self.checkpointing:
                context = torch.utils.checkpoint.checkpoint(
                    self.inner_cross_attn, q, kv, **kwargs
                )
            else:
                context = self.inner_cross_attn(q, kv, **kwargs)

    context = rearrange(context, "b s h d -> b s (h d)")
    if seqlen is not None:
        context = rearrange(context, "b s d -> (b s) d")
    return self.out_proj(context)
|
| 867 |
+
|
| 868 |
+
class SelectMHA(nn.Module):
    """Multi-head self-attention with head-selective ("select") decode.

    Prefill runs the regular (flash) self-attention path and fills the KV
    cache. During generation, ``batch_head_idx`` can restrict attention to a
    subset of heads per batch element via the ``select_attn`` kernel; when it
    is None the standard flash-attn KV-cache path is used instead.
    Only self-attention without MQA/GQA is currently supported.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=True,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        # NOTE: flash-attn is forced on regardless of the use_flash_attn
        # argument; the select_attn decode path assumes the flash KV layout.
        self.use_flash_attn = True  # use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        # BUGFIX: _update_kvcache_attention falls back to self.inner_cross_attn
        # (prefill, or flash_attn_with_kvcache unavailable), but it was never
        # constructed, raising AttributeError. Build it exactly as the sibling
        # MHA classes in this file do.
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )

        self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)

        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        # Kept separately for the select_attn kernel; may be None (kernel /
        # caller presumably supplies the 1/sqrt(head_dim) default — TODO confirm).
        self.softmax_scale = softmax_scale
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate an empty KV cache of shape
        (batch_size, max_seqlen, 2, num_heads_kv, head_dim); dtype/device
        default to the output-projection weight's."""
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """Update kv cache in inference_params."""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        batch_head_idx=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim)
            batch_head_idx: Tensor of indices specifying which batch and head indices to select.
                Shape: (batch_size, top_k)
            inference_params: for generation.
        """
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        batch, seqlen = x.shape[:2]

        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            # Self-attention, no MQA/GQA
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)

            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(
                    qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )

            if inference_params is None or inference_params.seqlen_offset == 0:
                # Prefill (or plain inference without inference_params).
                if inference_params is not None:
                    # Fill the kv cache during prefill.
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                context = self.inner_attn(qkv, **kwargs)
            else:
                # Generation stage
                if batch_head_idx is None:
                    # No router output: dense attention over the kv cache.
                    context = self._update_kvcache_attention(
                        q=qkv[:, :, 0], kv=qkv[:, :, 1:], inference_params=inference_params
                    )
                else:
                    # Select attention over the router-chosen heads.
                    context = self._update_kvcache_select_attn(
                        q=qkv[:, :, 0],
                        kv=qkv[:, :, 1:],
                        inference_params=inference_params,
                        batch_head_idx=batch_head_idx,
                    )
        else:
            raise NotImplementedError("SelectMHA currently supports only self-attention without MQA/GQA.")

        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)

    def _update_kvcache_select_attn(self, q, kv, inference_params, batch_head_idx):
        """
        Apply select attention during generation stage.

        q: (batch_size, seqlen=1, n_heads, head_dim)
        kv: (batch_size, seqlen=1, 2, n_heads, head_dim)
        batch_head_idx: Tensor of indices specifying which batch and head indices to select.
            Shape: (batch_size, top_k)

        Currently only supports batches with the same seqlen; different
        seqlens need a small update in the select_attn kernel (future work).
        """
        assert batch_head_idx is not None, "batch_head_idx must not be None"

        # Write the current token's K/V into the cache first.
        kv_cache = self._update_kv_cache(kv, inference_params)

        batch = q.shape[0]
        # seqlen_offset has not been advanced yet, so +1 accounts for the
        # token just written into the cache.
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset + 1
        )

        # select_attn expects a group axis: (batch, seqlen, G=1, n_heads, head_dim).
        q = q.unsqueeze(2)
        k_cache = kv_cache[:, :, 0].unsqueeze(2)
        v_cache = kv_cache[:, :, 1].unsqueeze(2)

        context = select_attn(
            q,
            k_cache,
            v_cache,
            self.softmax_scale,
            batch_head_idx,
            cache_seqlens)

        # context: (batch_size, seqlen_q=1, G=1, H, head_dim) -> drop G.
        batch_size = batch_head_idx.shape[0]
        context = context.view(batch_size, 1, self.num_heads, self.head_dim)

        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            # ALiBi is not threaded through this decode path.
            alibi_slopes = None
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_attn.softmax_scale,
                causal=self.inner_attn.causal,
                alibi_slopes=alibi_slopes,
            )

    # dummy function for testing
    def _select_attn(self, q, kv, inference_params, batch_head_idx):
        """
        Apply select attention during generation stage (no cache update).

        q: (batch_size, seqlen=1, n_heads, head_dim)
        kv: (batch_size, seqlen=1, 2, n_heads, head_dim)
        batch_head_idx: Tensor of indices specifying which batch and head indices to select.
            Shape: (N_selected, 2)

        Currently only supports batches with the same seqlen; different
        seqlens need a small update in the select_attn kernel (future work).
        """
        assert batch_head_idx.shape[1] == 2, "batch_head_idx must have shape (N_selected, 2)"

        batch = q.shape[0]
        # Read the cache directly — this variant does NOT append the current token.
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]

        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )

        # select_attn expects a group axis: (batch, seqlen, G=1, n_heads, head_dim).
        q = q.unsqueeze(2)
        k_cache = kv_cache[:, :, 0].unsqueeze(2)
        v_cache = kv_cache[:, :, 1].unsqueeze(2)

        context = select_attn(
            q,
            k_cache,
            v_cache,
            self.softmax_scale,
            batch_head_idx,
            cache_seqlens)

        # context: (batch_size, seqlen_q=1, G=1, H, head_dim) -> drop G.
        context = context.squeeze(2)

        return context
|
| 1184 |
+
|
| 1185 |
+
|
| 1186 |
+
# SelectiveGQA: Future work
|
| 1187 |
+
|
| 1188 |
+
class ParallelSelectMHA(nn.Module):
|
| 1189 |
+
def __init__(
    self,
    embed_dim,
    num_heads,
    process_group,
    num_heads_kv=None,
    qkv_proj_bias=True,
    out_proj_bias=True,
    dropout=0.0,
    softmax_scale=None,
    causal=True,
    layer_idx=None,
    dwconv=False,
    rotary_emb_dim=0,
    rotary_emb_base=10000.0,
    rotary_emb_scale_base=None,
    rotary_emb_interleaved=False,
    use_alibi=False,
    window_size=(-1, -1),
    fused_bias_fc=True,
    use_flash_attn=True,
    return_residual=False,
    checkpointing=False,
    sequence_parallel=False,
    device=None,
    dtype=None,
) -> None:
    """
    num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
    return_residual: whether to return the input x along with the output. This is for
        performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    super().__init__()
    # Plain configuration copies.
    self.embed_dim = embed_dim
    self.causal = causal
    self.layer_idx = layer_idx
    self.dwconv = dwconv
    self.rotary_emb_dim = rotary_emb_dim
    self.use_flash_attn = use_flash_attn
    self.return_residual = return_residual
    self.checkpointing = checkpointing
    self.process_group = process_group
    self.world_size = process_group.size()
    self.local_rank = torch.distributed.get_rank(process_group)

    self.num_heads = num_heads
    assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

    self.num_heads_kv = num_heads if num_heads_kv is None else num_heads_kv
    assert (
        self.num_heads % self.num_heads_kv == 0
    ), "num_heads must be divisible by num_heads_kv"

    # Head counts owned by this tensor-parallel rank.
    self.num_heads_per_rank = get_dim_for_local_rank(
        self.num_heads, self.world_size, self.local_rank
    )
    self.num_heads_kv_per_rank = get_dim_for_local_rank(
        self.num_heads_kv, self.world_size, self.local_rank
    )
    self.head_dim = self.embed_dim // num_heads
    qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

    if use_alibi:
        assert use_flash_attn, "ALiBi code path requires flash_attn"
        # Each rank keeps only its slice of the per-head ALiBi slopes.
        num_heads_local = math.ceil(self.num_heads / self.world_size)
        start = self.local_rank * num_heads_local
        alibi_slopes = torch.tensor(
            get_alibi_slopes(num_heads)[start : start + num_heads_local],
            device=device,
        )
    else:
        alibi_slopes = None
    if window_size != (-1, -1):
        assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

    if self.rotary_emb_dim > 0:
        assert RotaryEmbedding is not None, "rotary_emb is not installed"
        self.rotary_emb = RotaryEmbedding(
            self.rotary_emb_dim,
            base=rotary_emb_base,
            scale_base=rotary_emb_scale_base,
            interleaved=rotary_emb_interleaved,
            device=device,
        )

    if ColumnParallelLinear is None or RowParallelLinear is None:
        raise ImportError("fused_dense is not installed")
    self.Wqkv = ColumnParallelLinear(
        embed_dim,
        qkv_dim,
        process_group,
        bias=qkv_proj_bias,
        sequence_parallel=sequence_parallel,
        multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
        **factory_kwargs,
    )

    inner_attn_cls = (
        partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
        if use_flash_attn
        else SelfAttention
    )
    self.inner_attn = inner_attn_cls(
        causal=causal,
        softmax_scale=softmax_scale,
        attention_dropout=dropout,
    )
    self.softmax_scale = softmax_scale

    # Output projection; the reduction across ranks happens explicitly in
    # forward() rather than inside RowParallelLinear.
    self.out_proj = RowParallelLinear(
        embed_dim,
        embed_dim,
        process_group,
        bias=out_proj_bias,
        sequence_parallel=sequence_parallel,
        multiple_of=self.head_dim,
        **factory_kwargs,
    )

    # Routers are attached externally; the extra CUDA streams/events allow
    # router work to overlap with the attention compute.
    self.mha_router = None
    self.mlp_router = None
    self.current_stream = None
    self.sparse_stream = torch.cuda.Stream(device="cuda", priority=0)
    self.main_stream = torch.cuda.Stream(device="cuda", priority=-5)
    self.mha_router_event = torch.cuda.Event(enable_timing=False, blocking=False)
    self.mlp_router_event = torch.cuda.Event(enable_timing=False, blocking=False)
    self.main_event = torch.cuda.Event(enable_timing=False, blocking=False)
|
| 1324 |
+
|
| 1325 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
    """Create this layer's KV cache:
    (batch_size, max_seqlen, 2, num_heads_kv_per_rank, head_dim), on the
    output-projection weight's device (and dtype unless overridden)."""
    weight = self.out_proj.weight
    return torch.empty(
        (batch_size, max_seqlen, 2, self.num_heads_kv_per_rank, self.head_dim),
        dtype=weight.dtype if dtype is None else dtype,
        device=weight.device,
    )
|
| 1337 |
+
|
| 1338 |
+
def _update_kv_cache(self, kv, inference_params):
    """Write *kv* into this layer's slot of the inference-params KV cache."""
    assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
    # Module-level helper does the actual in-place cache write.
    return _update_kv_cache(kv, inference_params, self.layer_idx)
|
| 1342 |
+
|
| 1343 |
+
def forward(
    self,
    x,
    seqlen=None,
    inference_params=None,
    batch_head_idx=None,
    **kwargs,
):
    """
    Arguments:
        x: (batch, seqlen, hidden_dim)
        batch_head_idx: Tensor of indices specifying which batch and head indices to select.
            Shape: (N_selected,)
        inference_params: for generation.
    """

    router_inputs = x.squeeze(1)
    # Fork the helper streams off the current stream so router / all-reduce
    # work can overlap with the attention compute.
    self.current_stream = torch.cuda.current_stream()
    self.main_stream.wait_stream(self.current_stream)
    self.sparse_stream.wait_stream(self.current_stream)

    is_decode = inference_params is not None and inference_params.seqlen_offset > 0

    qkv = self.Wqkv(x)
    if seqlen is not None:
        qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)

    seqlen_offset = (
        0
        if inference_params is None
        else (
            inference_params.lengths_per_sample
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
    )
    rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
    batch, seqlen = x.shape[:2]

    if self.num_heads_kv == self.num_heads:
        # Self-attention, no MQA/GQA
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)

        # Apply rotary embeddings exactly once, for both prefill and decode.
        # BUGFIX: rotary was previously applied a second time inside the
        # decode branch below, double-rotating Q/K during generation.
        if self.rotary_emb_dim > 0:
            qkv = self.rotary_emb(
                qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
            )
        if inference_params is None or inference_params.seqlen_offset == 0:
            # Prefill stage (or plain inference without inference_params).
            if inference_params is not None:
                # Fill the kv cache during prefill.
                kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
            context = self.inner_attn(qkv, **kwargs)
        else:
            # Generation stage: select attention with kv cache update.
            context = self._update_kvcache_select_attn(
                qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx
            )
    else:  # cross-attention, MQA/GQA
        raise NotImplementedError("SelectMHA currently supports only self-attention without MQA/GQA.")

    context = rearrange(context, "b s h d -> b s (h d)")
    if seqlen is not None:
        context = rearrange(context, "b s d -> (b s) d")

    # Row-parallel output projection without the built-in reduction; the
    # cross-rank all-reduce is issued explicitly below.
    out = fused_dense_func(context, self.out_proj.weight, self.out_proj.bias)
    out = all_reduce(out, self.process_group)
    return out if not self.return_residual else (out, x)
|
| 1447 |
+
|
| 1448 |
+
def _update_kvcache_select_attn(self, q, kv, inference_params, batch_head_idx = None):
    """
    Apply select attention during the generation (decode) stage.

    Updates the layer's kv cache with the new token's k/v, then runs the
    sparse `select_attn` kernel over only the (batch, head) pairs named in
    `batch_head_idx`.

    Args:
        q: (batch_size, seqlen=1, n_heads, head_dim) query for the new token.
        kv: (batch_size, seqlen=1, 2, n_heads, head_dim) key/value for the new token.
        inference_params: decoding state; provides per-sample lengths or a
            global `seqlen_offset`, and the kv-cache storage.
        batch_head_idx: Tensor of indices specifying which batch and head
            indices to select. Shape: (batch_size, top_k). Must not be None.

    Returns:
        context: (batch_size, 1, num_heads_kv_per_rank, head_dim) attention output.
    """
    # check batch_head_idx shape
    # assert batch_head_idx.shape[1] == 2, "batch_head_idx must have shape (N_selected, 2)"

    # if batch_head_idx is None:
    #     batch_head_idx = self.local_head_idx
    #     print("Using local_head_idx, router not used.")
    # batch_head_idx = self.local_head_idx
    # print("Using local_head_idx, router not used.")

    # Append this token's k/v to the cache; returns the full cache view.
    kv_cache = self._update_kv_cache(kv, inference_params)

    batch = q.shape[0]
    # kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
    # Effective sequence length per sample: explicit lengths if tracked,
    # otherwise the shared offset (+1 for the token just appended).
    cache_seqlens = (
        inference_params.lengths_per_sample[:batch]
        if inference_params.lengths_per_sample is not None
        else inference_params.seqlen_offset + 1  # +1 for the current token
    )
    # need to reshape or view keys and value with shape (batch_size, seqlen, 1, n_heads, head_dim)
    # The inserted singleton axis is the group (G=1) dimension expected by select_attn.
    q = q.unsqueeze(2)
    k_cache = kv_cache[:, :, 0].unsqueeze(2)
    v_cache = kv_cache[:, :, 1].unsqueeze(2)

    # Block until the router stream has produced batch_head_idx
    # (event recorded by the router on its own CUDA stream).
    self.current_stream.wait_event(self.mha_router_event)

    assert batch_head_idx is not None, "batch_head_idx must not be None"
    # Call select_attn
    context = select_attn(
        q,
        k_cache,
        v_cache,
        self.softmax_scale,
        batch_head_idx,
        cache_seqlens
    )

    # context: (batch_size, seqlen_q=1, G=1, H, head_dim)
    # context = context.squeeze(2) # Remove G dimension
    # NOTE(review): reshapes to num_heads_kv_per_rank heads — presumably equal
    # to the per-rank query head count in this self-attention-only module; confirm.
    context = context.view(batch, 1, self.num_heads_kv_per_rank, self.head_dim)
    return context
|
| 1498 |
+
|
| 1499 |
+
'''
|
| 1500 |
+
PYTHONWARNINGS="ignore" python -m HybridTensor.modules.SelectiveMHA --batch_size 8 --in_features 8192 --seq_len 512 --head_density 0.25
|
| 1501 |
+
'''
|
| 1502 |
+
|
| 1503 |
+
|
| 1504 |
+
if __name__ == "__main__":
    args = arg_parser()

    # Generation-stage micro-benchmark: SelectMHA vs. a standard MHA module.
    max_seqlen = args.seq_len + 128
    max_batch_size = args.batch_size
    device = torch.device(f"cuda:{args.device}")

    # simulates SelectiveMHA inference generation stage
    inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size)
    nheads = args.in_features // 128
    softmax_scale = 1 / (128 ** 0.5)
    rotary_emb_dim = 0

    # Both modules share an identical configuration so the timing comparison is fair.
    mha_kwargs = dict(
        embed_dim=args.in_features,
        num_heads=nheads,
        num_heads_kv=None,
        causal=True,
        layer_idx=0,
        use_flash_attn=True,
        softmax_scale=softmax_scale,
        return_residual=False,
        rotary_emb_dim=rotary_emb_dim,
        device=device,
        dtype=torch.float16,
    )
    select_mha = SelectMHA(**mha_kwargs)
    standard_mha = SMHA(**mha_kwargs)

    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()

    with torch.no_grad():
        # Prefill-stage input (the cache below is filled synthetically instead).
        og_x = torch.randn(args.batch_size, args.seq_len, args.in_features, device=device, dtype=torch.float16, requires_grad=False)

        # out, time_ms = cuda_profiler(select_mha, og_x, inference_params=inference_params)
        # print(f"MHA Prefill time: {time_ms:.3f} ms")
        # out = select_mha(og_x, inference_params=inference_params)

        # Populate the kv cache directly; running the real prefill hits a
        # flash_attn issue at larger batch sizes.
        kv = torch.randn(args.batch_size, args.seq_len, 2, nheads, 128, device=device, dtype=torch.float16, requires_grad=False)
        _ = _update_kv_cache(kv, inference_params, 0)

        # Advance the offset so both modules run in token-generation mode.
        inference_params.seqlen_offset += args.seq_len

        input_x = torch.randn(args.batch_size, 1, args.in_features, device=device, dtype=torch.float16, requires_grad=False)
        selected_heads = math.ceil(nheads * args.head_density)

        # Random per-batch head selection for SelectiveMHA.
        # batch_head_index = generate_BH_index(args.batch_size, nheads, selected_heads, device=device)
        batch_head_index = generate_random_BH_index(args.batch_size, nheads, selected_heads, device=device)

        # Generation-stage timing: standard MHA baseline.
        out, standard_time_ms = cuda_profiler(standard_mha, input_x, inference_params=inference_params)
        print(f"Standard MHA time: {standard_time_ms:.3f} ms")

        # Generation-stage timing: SelectiveMHA with sparse head selection.
        out, select_time_ms = cuda_profiler(select_mha, input_x, inference_params=inference_params, batch_head_idx=batch_head_index)
        print(f"SelectMHA time: {select_time_ms:.3f} ms")

        speedup = standard_time_ms / select_time_ms
        print(f"Speedup: {speedup:.3f}")
|
| 1579 |
+
|
HybridTensor/modules/SelectiveMLP.py
ADDED
|
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# python -m HybridTensor.modules.SelectiveMLP --batch_size 8 --index_size 512
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.distributed import ProcessGroup
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
|
| 12 |
+
# import fused_dense_cuda # from apex
|
| 13 |
+
|
| 14 |
+
import fused_dense_lib as fused_dense_cuda
|
| 15 |
+
|
| 16 |
+
from flash_attn.utils.distributed import reduce_scatter, all_reduce
|
| 17 |
+
from einops import rearrange
|
| 18 |
+
|
| 19 |
+
# from HybridTensor.modules.MLP import SelectiveMLPFunc
|
| 20 |
+
from HybridTensor.modules.references.fused_dense import ColumnParallelLinear, RowParallelLinear, fused_mlp_func
|
| 21 |
+
from HybridTensor.modules.references.MLP import SelectiveMLPTriton
|
| 22 |
+
from HybridTensor.utils.utils import arg_parser, sparse_index
|
| 23 |
+
from HybridTensor.utils.profiling import cuda_profiler
|
| 24 |
+
|
| 25 |
+
# compiles the kernels for the first time, takes time
|
| 26 |
+
from HybridTensor.triton.gather_gemm_col import gather_matmul_col
|
| 27 |
+
from HybridTensor.triton.gather_gemm_row import gather_matmul_row
|
| 28 |
+
|
| 29 |
+
# needs to be compiled before running
|
| 30 |
+
from HybridTensor.triton.heuristics.gather_gemm_col_h import gather_matmul_col as gather_matmul_col_h
|
| 31 |
+
from HybridTensor.triton.heuristics.gather_gemm_row_h import gather_matmul_row as gather_matmul_row_h
|
| 32 |
+
|
| 33 |
+
# from HybridTensor.triton.cg_safe.gather_gemm_col_cg import gather_matmul_col
|
| 34 |
+
# from HybridTensor.triton.cg_safe.gather_gemm_row_cg import gather_matmul_row
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, bias1 = None, bias2 = None, activation='relu', use_heuristic=True):
    """Sparse two-layer MLP: gather-GEMM over fc1 columns, then over fc2 rows.

    Only the hidden neurons named in `index_vec` are computed. `use_heuristic`
    selects the heuristic-tuned Triton kernels over the plain variants; both
    pairs share the same calling convention.
    """
    if use_heuristic:
        col_mm, row_mm = gather_matmul_col_h, gather_matmul_row_h
    else:
        col_mm, row_mm = gather_matmul_col, gather_matmul_row
    hidden = col_mm(x, fc1_w, index_vec, bias = bias1, activations=activation)
    return row_mm(hidden, fc2_w, index_vec, bias = bias2)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# cg safe version
|
| 48 |
+
# def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, index_size, bias1 = None, bias2 = None, activation='relu', use_heuristic=True):
|
| 49 |
+
# out = gather_matmul_col(x, fc1_w, index_vec, index_size, bias = bias1, activations=activation)
|
| 50 |
+
# out = gather_matmul_row(out, fc2_w, index_vec, index_size, bias = bias2)
|
| 51 |
+
# return out
|
| 52 |
+
|
| 53 |
+
class MLPRouter(nn.Module):
    """Low-rank router that predicts which MLP neurons to activate.

    A two-layer linear bottleneck (embed_dim -> low_rank_dim -> out_dim)
    scores every hidden neuron; the selection helpers turn those scores into
    an index vector of neurons worth computing.
    """

    def __init__(self, embed_dim, low_rank_dim, out_dim, act_th, device=None, dtype=None):
        """
        Args:
            embed_dim (int): Dimensionality of the input embeddings.
            low_rank_dim (int): Dimensionality of the bottleneck layer.
            out_dim (int): Number of neurons being routed.
            act_th: Default activation threshold for threshold-based selection.
            device / dtype: Forwarded to the linear layers.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.fc1 = nn.Linear(embed_dim, low_rank_dim, bias=False, **factory_kwargs)
        self.fc2 = nn.Linear(low_rank_dim, out_dim, bias=False, **factory_kwargs)
        self.act_th = act_th
        self.num_neurons = out_dim
        # Sentinel index that sorts past every valid neuron id (CG-safe path).
        self.largest = self.num_neurons + 1

    def forward(self, x):
        """Score all neurons for `x` of shape (batch_size, embed_dim).

        Returns:
            torch.Tensor: Scores of shape (batch_size, num_neurons).
        """
        return self.fc2(self.fc1(x))

    def _select_neurons_topk(self, x, topk=None):
        """Return indices of the `topk` highest-scoring neurons (unsorted)."""
        scores = F.relu(self.forward(x))
        _, index_vec = scores.sum(dim=0).topk(topk, dim=0, sorted=False)
        # index_vec, _ = index_vec.sort()
        return index_vec

    def _select_neurons(self, x, th=None):
        '''
        Threshold based selection of neurons; output size is data-dependent,
        so this is not CUDA-graph safe.
        '''
        th = self.act_th if th is None else th
        hits = (self.forward(x) > th).sum(dim=0)
        return hits.nonzero().flatten()

    def _select_neurons_cuda_safe(self, x, th=None):
        '''
        CUDA-graph-safe threshold selection: returns a fixed-size, sorted
        index vector where unselected slots hold the `self.largest` sentinel,
        plus the count of valid entries.
        '''
        th = self.act_th if th is None else th
        hits = (self.forward(x) > th).sum(dim=0)

        candidates = torch.arange(self.num_neurons, device=hits.device)
        # NOTE(review): `hits` is a per-neuron *count* of activated batch rows,
        # yet it is compared against the activation threshold `th` here (the
        # non-CG path uses count > 0 instead) — mirrors the original; confirm intended.
        masked = torch.where(hits > th, candidates, torch.full_like(candidates, self.largest))

        index_vec, _ = torch.sort(masked)
        index_size = ((index_vec < self.largest).sum()).to(torch.int32)

        return index_size, index_vec
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class ParallelMLPRouter(nn.Module):
    """Tensor-parallel sparse predictor for the MLP layer.

    fc1 is small and replicated on every rank; fc2 is column-parallel, so each
    rank scores only its local shard of neurons.
    """

    def __init__(
        self,
        embed_dim,
        low_rank_dim,
        out_dim,
        act_th,
        process_group,
        sequence_parallel=False,
        device=None,
        dtype=None,
    ):
        """
        Args:
            embed_dim (int): Dimensionality of the input embeddings.
            low_rank_dim (int): Dimensionality of the bottleneck layer.
            out_dim (int): Total number of neurons (sharded across ranks by fc2).
            act_th: Default activation threshold for threshold-based selection.
            process_group (torch.distributed.ProcessGroup): Required parallel group.
            sequence_parallel (bool, optional): Whether fc2 uses sequence
                parallelism. Defaults to False.
            device / dtype: Forwarded to the linear layers.
        """
        super().__init__()
        assert process_group is not None, "ParallelMHARouter requires a process group."

        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.embed_dim = embed_dim
        self.act_th = act_th

        # Replicated bottleneck projection followed by a column-sharded scorer.
        self.fc1 = nn.Linear(
            embed_dim, low_rank_dim, bias=False, **factory_kwargs
        )
        self.fc2 = ColumnParallelLinear(
            low_rank_dim,
            out_dim,
            process_group,
            bias=False,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        """Score this rank's neuron shard for input `x`.

        Args:
            x (torch.Tensor): Input of shape (..., embed_dim).

        Returns:
            torch.Tensor: Per-rank scores of shape (..., out_dim / world_size).
        """
        return self.fc2(self.fc1(x))

    def _select_neurons(self, x, th=None):
        """Threshold-based selection over the local shard; not CUDA-graph safe."""
        th = self.act_th if th is None else th
        hits = (self.forward(x) > th).sum(dim=0)
        return hits.nonzero().flatten()

    def _select_neurons_topk(self, x, topk=None):
        """Indices of the `topk` highest-scoring local neurons (unsorted)."""
        scores = F.relu(self.forward(x))
        _, index_vec = scores.sum(dim=0).topk(topk, dim=0, sorted=False)
        return index_vec
|
| 211 |
+
|
| 212 |
+
class SelectiveMLP(nn.Module):
    """Two-layer MLP with an optional sparse (neuron-selective) forward path.

    When `index_vec` is given, only the selected hidden neurons are computed
    via the gather-GEMM kernels in `SelectiveMLPFunc`; otherwise a standard
    dense forward runs.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation='relu',
        layer_idx=None,
        bias1=True,
        bias2=True,
        return_residual=False,
        checkpoint_lvl=0,
        use_heuristic=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.activation_fn = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
        # Transposed (row-major) copy of fc2.weight, built lazily on the first
        # sparse forward; the gather-row kernel selects rows from it.
        self.fc2_weight_t = None
        self.use_heuristic = use_heuristic

    def _init_weights(self):
        # If fc2's weights are updated, the transposed copy must be refreshed.
        self.fc2_weight_t = self.fc2.weight.t().contiguous()

    def forward(self, x, index_vec=None, index_size=None):
        """Run the MLP.

        Args:
            x: Input of shape (..., in_features).
            index_vec: Optional 1D tensor of hidden-neuron indices; enables the
                sparse path (x is flattened to 2D).
            index_size: Unused here; kept for interface compatibility with the
                CUDA-graph-safe kernel variants.

        Returns:
            Output of shape (..., out_features), or (output, x) when
            `return_residual` is set.
        """
        if index_vec is not None:
            # Sparse forward: lazily build the transposed fc2 weight, then free
            # the original parameter to avoid holding two copies in memory.
            if self.fc2_weight_t is None:
                self.fc2_weight_t = self.fc2.weight.t().contiguous()
                self.fc2.weight = None
                del self.fc2._parameters['weight']

            x = x.view(-1, x.size(-1))
            y = SelectiveMLPFunc(x=x, fc1_w=self.fc1.weight,
                                 fc2_w=self.fc2_weight_t, index_vec=index_vec,
                                 bias1=self.fc1.bias, bias2=self.fc2.bias,
                                 activation=self.activation, use_heuristic=self.use_heuristic)
        else:
            # Dense forward.
            y = self.fc1(x)
            y = self.activation_fn(y)

            if self.fc2_weight_t is not None:
                # fc2.weight may already have been freed by a sparse forward, so
                # use the transposed copy. BUG FIX: fc2's bias was previously
                # dropped on this path; add it so the result matches self.fc2(y).
                y = torch.matmul(y, self.fc2_weight_t)
                if self.fc2.bias is not None:
                    y = y + self.fc2.bias
            else:
                y = self.fc2(y)

        return y if not self.return_residual else (y, x)
|
| 277 |
+
|
| 278 |
+
class ParallelSelectiveMLP(nn.Module):
    # Tensor-parallel MLP with an optional sparse (neuron-selective) decode path.
    # Dense prefill uses flash_attn's fused_mlp_func; sparse decode uses the
    # gather-GEMM SelectiveMLPFunc over the locally sharded fc1/fc2 weights.
    def __init__(
        self,
        in_features,
        hidden_features,
        out_features=None,
        activation="relu",
        layer_idx=None,
        process_group: ProcessGroup = None,
        bias1=True,
        bias2=True,
        return_residual=False,
        sequence_parallel=False,
        use_heuristic=True,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        """
        process_group is required. We're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu"]
        assert process_group is not None
        # assert sp_kwargs != None, "sparse predictor parameters are not passed in."
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.activation = activation
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        # fc1 sharded over hidden columns; fc2 sharded over hidden rows.
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, process_group, bias=bias1, **factory_kwargs
        )
        self.fc2 = RowParallelLinear(
            hidden_features, out_features, process_group, bias=bias2, **factory_kwargs
        )
        self.layer_idx = layer_idx

        # NOTE(review): register_buffer returns None (so this assignment is a
        # no-op) and the buffer name "fc2_weigth_t" is misspelled; the real
        # transposed-weight cache is the plain attribute set two lines below.
        self.fc2_weight_t = self.register_buffer("fc2_weigth_t", None)
        self.return_residual = return_residual
        self.fc2_weight_t = None
        self.use_heuristic = use_heuristic
        self.reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce

        # self._init_weights()

    def _init_weights(self):
        # ffn2 weights needs to be in row major format to select from rows
        self.fc2_weight_t = self.fc2.weight.t().contiguous()

    def forward(self, x, residual = None, index_vec = None):
        # Runs sparse gather-GEMM when index_vec is given (decode), otherwise the
        # fused dense MLP; always all-reduces across the tensor-parallel group.

        # do_token_generation = x.size(1) == 1
        # index_vec = None
        # with torch.cuda.stream(self.curr_stream):
        if index_vec is not None:
            # assert x.size(1) == 1
            # Lazily cache the transposed fc2 weight for row gathering.
            if self.fc2_weight_t is None:
                self.fc2_weight_t = self.fc2.weight.t().contiguous()

            # Flatten (b, 1, d) -> (b, d) for the 2D gather kernels.
            x = x.view(-1, x.size(-1))
            # x = rearrange(x, "b 1 d -> b d") # slightly more expensive to use rearrange

            out = SelectiveMLPFunc(x = x, fc1_w = self.fc1.weight,
                                   fc2_w = self.fc2_weight_t, index_vec = index_vec,
                                   bias1 = self.fc1.bias, bias2 = self.fc2.bias,
                                   activation=self.activation, use_heuristic=self.use_heuristic)
            # out = rearrange(out, "b d -> b 1 d")
            # out = out.view(-1, 1, out.size(-1))

        else: # normal mlp
            # Resolve the fused-kernel heuristic (see __init__ docstring).
            if self.heuristic == "auto":
                dtype = (
                    x.dtype
                    if not torch.is_autocast_enabled()
                    else torch.get_autocast_gpu_dtype()
                )
                if self.activation == "gelu_approx":
                    cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                    heuristic = (
                        0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
                    )
                else:
                    heuristic = 0
            else:
                heuristic = self.heuristic
            out = fused_mlp_func(
                x,
                self.fc1.weight,
                self.fc2.weight,
                self.fc1.bias,
                self.fc2.bias,
                activation=self.activation,
                save_pre_act=self.training,
                checkpoint_lvl=self.checkpoint_lvl,
                heuristic=heuristic,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
            )

        if self.process_group.size() > 1:
            # out = self.reduce_fn(out, self.process_group) # has some overhead,
            dist.all_reduce(out, op=dist.ReduceOp.SUM, group=self.process_group)

        return out if not self.return_residual else (out, x)
        # return out

    def sp_forward(self, x, residual = None, index_vec = None):
        # Variant of forward that overlaps the next layer's sparse-predictor
        # (router) run on a side stream with the all-reduce of this layer's output.
        # NOTE(review): self.sp_router, self.sp, self.sp_stream, self.event_mlp
        # and self.event_mlp_sp are not set in __init__ — presumably attached
        # externally before this path is used; confirm.
        if self.heuristic == "auto":
            dtype = (
                x.dtype
                if not torch.is_autocast_enabled()
                else torch.get_autocast_gpu_dtype()
            )
            if self.activation == "gelu_approx":
                cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                heuristic = (
                    0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
                )
            else:
                heuristic = 0
        else:
            heuristic = self.heuristic
        curr_stream = torch.cuda.current_stream()
        # Single-token input marks the decode (generation) stage.
        do_token_generation = x.size(1) == 1
        # mlp_logit = None

        # with torch.cuda.stream(self.curr_stream):
        if index_vec != None:
            assert x.size(1) == 1

            if self.fc2_weight_t is None:
                self.fc2_weight_t = self.fc2.weight.t().contiguous()

            out = SelectiveMLPFunc(
                rearrange(x, "b 1 d -> b d"),
                self.fc1.weight,
                self.fc2_weight_t,
                index_vec,
                self.fc1.bias,
                self.fc2.bias,
                activation=self.activation,
            )
            out = rearrange(out, "b d -> b 1 d")
        else:
            out = fused_mlp_func(
                x,
                self.fc1.weight,
                self.fc2.weight,
                self.fc1.bias,
                self.fc2.bias,
                activation=self.activation,
                save_pre_act=self.training,
                checkpoint_lvl=self.checkpoint_lvl,
                heuristic=heuristic,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
            )


        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        if self.sp_router:
            # Mark the point where the MLP output is ready so the side stream
            # can start the router without waiting for the all-reduce.
            curr_stream.record_event(self.event_mlp)

        # handle = torch.distributed.all_reduce(out, op=torch.distributed.ReduceOp.SUM, group=self.process_group, async_op=True)
        out = reduce_fn(out, self.process_group)


        if self.sp_router:
            with torch.cuda.stream(self.sp_stream):
                self.sp_stream.wait_event(self.event_mlp)
                if do_token_generation:
                    # Router for the next layer runs concurrently with the reduce.
                    mlp_logit = self.sp(rearrange(residual, "b 1 d -> b d"))
                self.sp_stream.record_event(self.event_mlp_sp)

            # check this again, we might not have to synchronize here, we can synchronize in the next layer
            curr_stream.wait_event(self.event_mlp_sp)

        return out
|
| 475 |
+
|
| 476 |
+
class SimpleMLP(nn.Module):
    """Minimal dense two-layer MLP used as a reference/baseline module.

    BUG FIX: the `activation` argument was previously stored but ignored —
    forward always applied ReLU. It is now resolved to the matching
    functional activation; unknown names still fall back to ReLU, so the
    default behavior is unchanged.
    """

    # Supported activation names -> functional implementations.
    _ACTIVATIONS = {
        "relu": F.relu,
        "gelu": F.gelu,
        "gelu_approx": partial(F.gelu, approximate="tanh"),
    }

    def __init__(self, in_features, hidden_features, out_features, bias=False, activation="relu"):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.activation = activation

    def forward(self, x):
        """Apply fc1 -> activation -> fc2."""
        act = self._ACTIVATIONS.get(self.activation, F.relu)
        x = act(self.fc1(x))
        return self.fc2(x)
|
| 487 |
+
|
| 488 |
+
if __name__ == "__main__":
    # Benchmark driver: times the routed/selective MLP forward path against the
    # dense forward path, then exercises CUDA Graph capture/replay of the
    # selective path. Requires a CUDA device; `arg_parser`, `sparse_index`,
    # `cuda_profiler`, `SelectiveMLP`, and `MLPRouter` are project-local
    # helpers imported elsewhere in this file.
    args = arg_parser()

    # args.bias is an integer flag; any positive value enables bias.
    bias = True if args.bias > 0 else False
    # Random fp16 input batch on the GPU.
    x = torch.randn(args.batch_size, args.in_features, device="cuda", dtype=torch.float16)
    # Select `index_size` active neuron indices out of the 4x-expanded hidden dim.
    index_vec, _ = sparse_index(args.index_size, args.in_features*4)

    # Disabled Triton-based variant of this benchmark (kept for reference).
    '''
    selective_mlp = SelectiveMLPTriton(args.in_features, args.hidden_features, bias=bias, device="cuda", dtype=torch.float16, activation="relu")

    out, mlp_time = cuda_profiler(selective_mlp, x, index_vec)

    out_col, col_time = cuda_profiler(gather_matmul_col, x, selective_mlp.fc1_w, index_vec, activations=selective_mlp.activation)
    out_row, row_time = cuda_profiler(gather_matmul_row, out_col, selective_mlp.fc2_w, index_vec)
    sum_time = col_time + row_time

    print(f"Index size {args.index_size}, Activated {args.index_size/(args.in_features * 4)*100}% neurons")

    print(f"Gather Col Time: {col_time} ms")
    print(f"Gather Row Time: {row_time} ms")
    # print(f"Sum Time: {sum_time} ms")

    print(f"SelectiveMLP Time: {mlp_time} ms")
    '''

    # Standard transformer MLP sizing: hidden = 4 * d_model, output = d_model.
    in_features = args.in_features
    hidden_features = in_features * 4
    out_features = in_features
    device = torch.device("cuda")

    model = SelectiveMLP(
        in_features, hidden_features, out_features, device=device, dtype=torch.float16, activation="relu", use_heuristic=True
    ).to(device)

    # Router predicting which hidden neurons to activate; act_th is presumably
    # an activation threshold — semantics defined in MLPRouter (verify there).
    router = MLPRouter(in_features, 1024, hidden_features, act_th = 0.5, device=device, dtype=torch.float16).to(device)

    # Warm-up GPU
    def warmup():
        # Run every code path a few times so kernel compilation and allocator
        # effects do not pollute the timings below.
        for _ in range(10):
            _ = model(x, index_vec)   # selective (sparse) path
            _ = model(x, None)        # dense path
            _ = router._select_neurons_topk(x, args.index_size)

    warmup()

    # Measure SelectiveMLPFunc speed
    _, router_time = cuda_profiler(router._select_neurons_topk, x, args.index_size)
    _, selective_time = cuda_profiler(model, x, index_vec)
    # Measure dense forward speed
    _, dense_time = cuda_profiler(model, x, None)

    print(f"Router time per run: {router_time:.6f} ms")
    print(f"SelectiveMLPFunc time per run: {selective_time:.6f} ms")
    print(f"Dense forward time per run: {dense_time:.6f} ms")
    print(f"Speedup: {dense_time / selective_time:.2f}x")
    # End-to-end sparse cost = routing + selective matmul.
    router_selective_time = router_time + selective_time
    print(f"Router + SelectiveMLPFunc time per run: {router_selective_time:.6f} ms")
    print(f"Speedup: {dense_time / router_selective_time:.2f}x")
    ############################################
    # CUDA Graph capture tests for the MLP model
    ############################################
    print("\n=== CUDA Graph Tests ===")
    # --- Selective forward (sparse mode) ---
    print("Testing CUDA Graph for Selective forward (with index_vec)...")
    # Graph replay reads/writes fixed device addresses, so inputs must live in
    # dedicated "static" tensors that are reused (not reallocated) per replay.
    static_x = x.clone()
    static_index_vec = index_vec.clone()
    # Warm-up run to allocate memory
    static_out_sel = model(static_x, index_vec=static_index_vec)
    torch.cuda.synchronize()

    # Capture on a non-default stream
    capture_stream = torch.cuda.Stream()
    with torch.cuda.stream(capture_stream):
        g_sel = torch.cuda.CUDAGraph()
        g_sel.capture_begin()
        # This call is recorded into the graph; its output tensor becomes the
        # graph's static output buffer.
        static_out_sel = model(static_x, index_vec=static_index_vec)
        g_sel.capture_end()
    torch.cuda.synchronize()

    # Replay and check accuracy
    g_sel.replay()
    torch.cuda.synchronize()
    # Clone before the eager-mode reference run, since replays overwrite
    # static_out_sel in place.
    cuda_sel_out = static_out_sel.clone()
    regular_sel_out = model(x, index_vec=index_vec)
    if torch.allclose(cuda_sel_out, regular_sel_out, atol=1e-3):
        print("Selective forward CUDA Graph output matches regular output")
    else:
        print("Selective forward CUDA Graph output does NOT match regular output")

    def replay_sel():
        # Zero-argument wrapper so cuda_profiler can time graph replays.
        g_sel.replay()
    _, selective_time_cuda = cuda_profiler(replay_sel)
    print(f"Selective forward CUDA Graph time per run: {selective_time_cuda:.6f} ms")
|
HybridTensor/modules/SelectiveRouters.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import torch
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
def create_mlp_router_state_dict(router_files_dir):
    """
    Build a model-ready state dict from the mlp_router weight files in a directory.

    Every file named like ``mlp_router_<layer>-<a>-<b>-<c>.pt`` is loaded and
    each of its parameters is re-keyed as
    ``transformer.layers.<layer>.mlp_router.<param_name>``.

    Args:
        router_files_dir (str): Path to the directory containing mlp_router_*.pt files.

    Returns:
        OrderedDict: A state dictionary suitable for loading into a transformer
        model, or None when the directory is missing or has no matching files.
    """
    # Filename pattern: mlp_router_{layer}-{val}-{val}-{val}.pt
    pattern = re.compile(r'mlp_router_(\d+)-[\d.]+-[\d.]+-[\d.]+\.pt$')

    try:
        entries = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None

    matched = [name for name in entries if pattern.match(name)]
    if not matched:
        print(f"No router files found in directory '{router_files_dir}'.")
        return None

    router_state_dict = OrderedDict()

    # Process files in ascending layer order for a deterministic key order.
    for file_name in sorted(matched, key=lambda n: int(pattern.match(n).group(1))):
        m = pattern.match(file_name)
        if m is None:
            print(f"Skipping file '{file_name}' as it does not match the pattern.")
            continue

        layer_num = int(m.group(1))
        file_path = os.path.join(router_files_dir, file_name)

        try:
            loaded = torch.load(file_path, map_location='cpu')
        except Exception as e:
            print(f"Error loading '{file_path}': {e}")
            continue
        if not isinstance(loaded, dict):
            print(f"Warning: The file '{file_path}' does not contain a state dictionary. Skipping.")
            continue

        # Re-key every parameter under the transformer layer namespace.
        for param_name, param_tensor in loaded.items():
            router_state_dict[f"transformer.layers.{layer_num}.mlp_router.{param_name}"] = param_tensor

    print(f"Total routers loaded: {len(router_state_dict) // 2}")  # Assuming 4 params per router (weight & bias for 2 layers)
    return router_state_dict
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def create_attn_router_state_dict(router_files_dir):
    """
    Build a model-ready state dict from the attn_router weight files in a directory.

    Every file named like ``attn_router_<layer>-<a>-<b>.pt`` is loaded and each
    of its parameters is re-keyed as
    ``transformer.layers.<layer>.mha_router.<param_name>``. Only the first file
    found for a given layer is used; later duplicates are skipped.

    Args:
        router_files_dir (str): Path to the directory containing attn_router_*.pt files.

    Returns:
        OrderedDict: A state dictionary suitable for loading into a transformer
        model, or None when the directory is missing or has no matching files.
    """
    # Filename pattern: attn_router_{layer}-{value1}-{value2}.pt
    pattern = re.compile(r'attn_router_(\d+)-[\d.]+-[\d.]+\.pt$')

    try:
        entries = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None

    matched = [name for name in entries if pattern.match(name)]
    if not matched:
        print(f"No attn_router files found in directory '{router_files_dir}'.")
        return None

    router_state_dict = OrderedDict()
    # Layers already loaded successfully — used to skip duplicate files.
    loaded_layers = set()

    for file_name in sorted(matched, key=lambda n: int(pattern.match(n).group(1))):
        m = pattern.match(file_name)
        if m is None:
            print(f"Skipping file '{file_name}' as it does not match the pattern.")
            continue

        layer_num = int(m.group(1))
        if layer_num in loaded_layers:
            print(f"Warning: Multiple router files found for layer {layer_num}. Skipping '{file_name}'.")
            continue

        file_path = os.path.join(router_files_dir, file_name)
        try:
            loaded = torch.load(file_path, map_location='cpu')
        except Exception as e:
            print(f"Error loading '{file_path}': {e}")
            continue
        if not isinstance(loaded, dict):
            print(f"Warning: The file '{file_path}' does not contain a state dictionary. Skipping.")
            continue

        # Re-key every parameter under the transformer layer namespace.
        for param_name, param_tensor in loaded.items():
            router_state_dict[f"transformer.layers.{layer_num}.mha_router.{param_name}"] = param_tensor

        # Mark the layer only after a successful load, so a later file for the
        # same layer can still fill in if this one had failed.
        loaded_layers.add(layer_num)

    print(f"Total MHA routers loaded: {len(loaded_layers)}")
    return router_state_dict
|
| 135 |
+
|
| 136 |
+
|
HybridTensor/modules/__init__.py
ADDED
|
File without changes
|
HybridTensor/modules/__pycache__/MLP.cpython-39.pyc
ADDED
|
Binary file (5.3 kB). View file
|
|
|
HybridTensor/modules/__pycache__/ParallelMLP.cpython-39.pyc
ADDED
|
Binary file (5.61 kB). View file
|
|
|
HybridTensor/modules/__pycache__/SelectiveBlock.cpython-39.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-310.pyc
ADDED
|
Binary file (6.56 kB). View file
|
|
|
HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-39.pyc
ADDED
|
Binary file (6.77 kB). View file
|
|
|
HybridTensor/modules/__pycache__/SelectiveMHA.cpython-310.pyc
ADDED
|
Binary file (23.3 kB). View file
|
|
|
HybridTensor/modules/__pycache__/SelectiveMHA.cpython-39.pyc
ADDED
|
Binary file (30.1 kB). View file
|
|
|
HybridTensor/modules/__pycache__/SelectiveMLP.cpython-310.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
HybridTensor/modules/__pycache__/SelectiveMLP.cpython-39.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
HybridTensor/modules/__pycache__/SelectiveRouters.cpython-310.pyc
ADDED
|
Binary file (3.96 kB). View file
|
|
|
HybridTensor/modules/__pycache__/SelectiveRouters.cpython-39.pyc
ADDED
|
Binary file (4 kB). View file
|
|
|
HybridTensor/modules/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|