diff --git a/HybridTensor/__init__.py b/HybridTensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/HybridTensor/__pycache__/__init__.cpython-310.pyc b/HybridTensor/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e57c6d984f6760e85dd4d65e1a63ff8bb3f13045 Binary files /dev/null and b/HybridTensor/__pycache__/__init__.cpython-310.pyc differ diff --git a/HybridTensor/__pycache__/__init__.cpython-39.pyc b/HybridTensor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..607f030119b25da41ed50024fddc0c976b429ace Binary files /dev/null and b/HybridTensor/__pycache__/__init__.cpython-39.pyc differ diff --git a/HybridTensor/benchmarks/generation/__pycache__/gen_util.cpython-39.pyc b/HybridTensor/benchmarks/generation/__pycache__/gen_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d070d856ac67e6c0c7792a3d81f9508b61b60491 Binary files /dev/null and b/HybridTensor/benchmarks/generation/__pycache__/gen_util.cpython-39.pyc differ diff --git a/HybridTensor/benchmarks/generation/__pycache__/llama_sparse_generation.cpython-39.pyc b/HybridTensor/benchmarks/generation/__pycache__/llama_sparse_generation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ec2d6f5412cb817da4360c3467795da43209219 Binary files /dev/null and b/HybridTensor/benchmarks/generation/__pycache__/llama_sparse_generation.cpython-39.pyc differ diff --git a/HybridTensor/benchmarks/generation/__pycache__/model_sparse_generation.cpython-39.pyc b/HybridTensor/benchmarks/generation/__pycache__/model_sparse_generation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c00514682cf77b9fa5f2db3d22ce5cd24215b8dd Binary files /dev/null and b/HybridTensor/benchmarks/generation/__pycache__/model_sparse_generation.cpython-39.pyc differ diff --git a/HybridTensor/benchmarks/generation/__pycache__/opt_gen_tp.cpython-39.pyc b/HybridTensor/benchmarks/generation/__pycache__/opt_gen_tp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1df251731e84d64dc0d612f73596bae824a25ea Binary files /dev/null and b/HybridTensor/benchmarks/generation/__pycache__/opt_gen_tp.cpython-39.pyc differ diff --git a/HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-310.pyc b/HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3938243c96fa51d0fb9e5c94181165379a0c5895 Binary files /dev/null and b/HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-310.pyc differ diff --git a/HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-39.pyc b/HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..945f2a86641f2d4800e383dc7eb503b2fd73d6d9 Binary files /dev/null and b/HybridTensor/benchmarks/generation/__pycache__/opt_generation.cpython-39.pyc differ diff --git a/HybridTensor/benchmarks/generation/__pycache__/opt_sparse_gen_tp.cpython-39.pyc b/HybridTensor/benchmarks/generation/__pycache__/opt_sparse_gen_tp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4616a4e4e9129f5b31a33a8b9a5825a73462414 Binary files /dev/null and b/HybridTensor/benchmarks/generation/__pycache__/opt_sparse_gen_tp.cpython-39.pyc differ diff --git a/HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-310.pyc b/HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..113622f4bc42a68c3527dd2a3fd512bc7be9bba7 Binary files /dev/null and b/HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-310.pyc differ diff --git a/HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-39.pyc b/HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57d229ab1fa7f529b532a17b439169f4db0ac982 Binary files /dev/null and b/HybridTensor/benchmarks/generation/__pycache__/opt_sparse_generation.cpython-39.pyc differ diff --git a/HybridTensor/benchmarks/generation/gen_util.py b/HybridTensor/benchmarks/generation/gen_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e3640d3d981829b0515917e9886b4225b7e01912 --- /dev/null +++ b/HybridTensor/benchmarks/generation/gen_util.py @@ -0,0 +1,38 @@ +import random +import torch +from datasets import load_dataset +from transformers import AutoTokenizer + +def tokenize_dataset(dataset, tokenizer): + # Tokenize and concatenate all texts (without adding special tokens) + all_tokens = [] + for example in dataset: + tokens = tokenizer(example["text"], add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + return all_tokens + +def get_random_batch(tokens, batch_size, seq_length): + total = len(tokens) + batch = [] + for _ in range(batch_size): + start = random.randint(0, total - seq_length) + batch.append(tokens[start : start + seq_length]) + return torch.tensor(batch) + +''' +# Load dataset and tokenizer +dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") +model_name = "facebook/opt-6.7b" +tokenizer = AutoTokenizer.from_pretrained(model_name) + +tokens = tokenize_dataset(dataset, tokenizer) + +# Define parameters +batch_size = 8 +seq_length = 2000 + +random_batch = get_random_batch(tokens, batch_size, seq_length) +print("Batch shape:", random_batch.shape) # Expected: (8, 128) + +''' + diff --git a/HybridTensor/benchmarks/generation/model_sparse_generation.py b/HybridTensor/benchmarks/generation/model_sparse_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..a3402e0a0732942870006ff8acee66f77da23377 --- /dev/null +++ b/HybridTensor/benchmarks/generation/model_sparse_generation.py @@ -0,0 +1,71 @@ +import torch +import argparse + +from HybridTensor.utils.utils import _get_device +from HybridTensor.utils.activations import MODELS +from HybridTensor.models.opt import build_sparse_opt +from HybridTensor.models.llama import build_sparse_llama +from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv +from transformers import AutoTokenizer + + +def update_router_config(model, num_layers, mlp_topk_lookup, attn_topk): + for i in range(num_layers): + if mlp_topk_lookup is not None: + model.transformer.layers[i].mlp_topk = mlp_topk_lookup[i] + # model.transformer.layers[i].mlp_topk = 512 + model.transformer.layers[i].mha_router.topk = attn_topk + + # dense attention in layer 0 + model.transformer.layers[0].mha_router.topk = 1.0 + +def arg_parser(): + parser = argparse.ArgumentParser(description='Inference benchmarking') + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--model_index', type=int, default=5) + parser.add_argument('--print_results', type=bool, default=True) + parser.add_argument('--iterations', type=int, default=1) + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model') + parser.add_argument('--mlp_ckpt_dir', type=str, default='/home/grads/s//nvme/HybridTensor/checkpoint/opt-6.7b-routers/mlp') + parser.add_argument('--attn_ckpt_dir', type=str, default='/home/grads/s//nvme/HybridTensor/checkpoint/opt-6.7b-routers/mha_linear') + parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router/opt-6.7b') + parser.add_argument('--delta', type=int, default=256, help='Delta value for MLP topk calculation') + parser.add_argument('--use_cuda_graph', type=bool, default=False, help='Use CUDA graph for inference') + + return parser.parse_args() + +if __name__ == "__main__": + args = arg_parser() + model_name = MODELS[args.model_index-1] + print(f"Model name: {model_name}") + dtype = torch.float16 + device= _get_device(args.gpu) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + if "llama" in model_name: + model = build_sparse_llama(args, model_name, + args.attn_ckpt_dir, + device = device, dtype=dtype) + update_router_config(model, model.config.n_layer, None, args.attn_topk) # this sets the router config for all layers using a single config + + else: + mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, args.batch_size) + model = build_sparse_opt(args, model_name, + args.mlp_ckpt_dir, + args.attn_ckpt_dir, + device = device, dtype=dtype) + update_router_config(model, model.config.n_layer, mlp_topk_lookup, args.attn_topk) # this sets the router config for all layers using a single config + + + model.eval() + print(model) + + # test input + input_text = "Once upon a time in a land far, far away, there lived a" + input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device) + + # Generate output + with torch.no_grad(): + output = model.generate(input_ids, max_length=50) + print(tokenizer.decode(output[0], skip_special_tokens=True)) diff --git a/HybridTensor/benchmarks/generation/opt_gen_tp.py b/HybridTensor/benchmarks/generation/opt_gen_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..dd953c6aca0204843b39ef71671a8fd4280b6578 --- /dev/null +++ b/HybridTensor/benchmarks/generation/opt_gen_tp.py @@ -0,0 +1,133 @@ +from transformers.models.opt import OPTConfig +from transformers import AutoTokenizer +from flash_attn.models.opt import opt_config_to_gpt2_config + +import os +import torch +import argparse +from apex.transformer import parallel_state + +from HybridTensor.utils.utils import arg_parser, _get_device +from HybridTensor.utils.activations import OPT_MODELS +from HybridTensor.models.opt import SparseConfig, build_sparse_opt, build_dense_opt + + +def initialize_distributed_environment(): + # Set environment variables for NCCL + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" + os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0" + + # Initialize the distributed process group + torch.distributed.init_process_group(backend="nccl", init_method="env://") + + # Set the device based on the rank of the current process + device = f"cuda:{torch.distributed.get_rank()}" + world_size = torch.distributed.get_world_size() + + # Set the current CUDA device to avoid operations being executed on the wrong GPU + torch.cuda.set_device(device) + + # You can return device, world_size, and any other relevant information + return device, world_size + +def _turn_bias_off(model, num_layers): + for i in range(num_layers): + model.transformer.layers[i].mlp.fc1.bias = None + model.transformer.layers[i].mlp.fc2.bias = None + +def arg_parser(): + parser = argparse.ArgumentParser(description='Inference benchmarking') + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--model_index', type=int, default=5) + parser.add_argument('--seq_len', type=int, default=25) + parser.add_argument('--index_size', type=int, default=8192) + parser.add_argument('--head_density', type=float, default=0.25) + parser.add_argument('--print_results', type=bool, default=True) + parser.add_argument('--iterations', type=int, default=2) + parser.add_argument('--check_results', type=bool, default=False) + parser.add_argument('--results_dir', type=str, default='results') + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--bias', type=bool, default=False) + parser.add_argument('--mlp_ckpt_dir', type=str, default='/home/grads/s//nvme/HybridTensor/checkpoint/opt-6.7b-routers/mlp') + parser.add_argument('--attn_ckpt_dir', type=str, default='/home/grads/s//nvme/HybridTensor/checkpoint/opt-6.7b-routers/mha_linear') + + return parser.parse_args() + +if __name__ == "__main__": + + args = arg_parser() + model_name = OPT_MODELS[args.model_index-1] + + device, world_size = initialize_distributed_environment() + dtype = torch.float16 + + parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) + rank = parallel_state.get_tensor_model_parallel_rank() + process_group = parallel_state.get_tensor_model_parallel_group() + + tokenizer = AutoTokenizer.from_pretrained(model_name) + # model = build_sparse_opt(model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device = device, dtype=dtype, process_group = process_group, world_size = world_size, rank = rank) + model = build_dense_opt(model_name, process_group = process_group, world_size = world_size, rank = rank, device = device, dtype=dtype) + model.eval() + # if rank == 0: + # print(model) + + # input_texts = ["Hello, my dog is cute and", "The future of AI is", "In a distant galaxy, a spaceship", "The cat is sleeping on the "] + input_texts = ["In a distant galaxy, a spaceship"] + tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(device) + input_ids=tokenized_inputs["input_ids"] + # input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to(device=device) + + max_length = args.seq_len + position_ids = None + eos_token_id = tokenizer.eos_token_id + num_layers = model.config.n_layer + + # turn bias off for mlp layers + if not args.bias: + _turn_bias_off(model, num_layers) + + _ = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + ) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + for i in range(args.iterations): + out = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + ) + + end_event.record() + + torch.cuda.synchronize() + + # print(tokenizer.batch_decode(out.sequences.tolist())) + + if rank == 0: + elapsed_time = start_event.elapsed_time(end_event) / args.iterations + print(f"Average time per genearation : {elapsed_time} ms") + + # Compute throughput and latency per token + num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1] + throughput = num_tokens_generated / (elapsed_time / 1000) # tokens per second + latency_per_token = elapsed_time / num_tokens_generated # ms per token + + print(f"Number of tokens generated: {num_tokens_generated}") + print(f"Throughput: {throughput} tokens/second") + print(f"Latency per token: {latency_per_token} ms") + print(tokenizer.batch_decode(out.sequences.tolist())) + diff --git a/HybridTensor/benchmarks/generation/opt_generation.py b/HybridTensor/benchmarks/generation/opt_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..55eb9afbef08ce4efb4a3fa6c70ab08f4f307704 --- /dev/null +++ b/HybridTensor/benchmarks/generation/opt_generation.py @@ -0,0 +1,289 @@ +import re +import time + +import pytest +import torch +import argparse + +from einops import rearrange + +from HybridTensor.benchmarks.generation.gen_util import tokenize_dataset, get_random_batch +from HybridTensor.utils.activations import OPT_MODELS +from datasets import load_dataset + +from flash_attn.models.gpt import GPTLMHeadModel +from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt +from flash_attn.utils.generation import update_graph_cache +from flash_attn.utils.pretrained import state_dict_from_pretrained +from transformers import AutoTokenizer, OPTConfig +from transformers.models.opt.modeling_opt import OPTForCausalLM + +def test_opt_generation(model_name): + """Check that our implementation of OPT generation matches the HF implementation: + the scores in fp16 should be around the same as the HF scores in fp16, when compared to + the HF scores in fp32. + """ + print(f"\nMODEL: {model_name}") + verbose = False + dtype = torch.float16 + device = "cuda" + rtol, atol = 3e-3, 3e-1 + config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name)) + # Only prenorm supports residual_in_fp32 + config.residual_in_fp32 = getattr(config, "prenorm", True) + config.use_flash_attn = True + config.fused_bias_fc = True + config.fused_mlp = True + config.fused_dropout_add_ln = True + + model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) + model.eval() + + torch.manual_seed(0) + # OPT tokenizer requires use_fast=False + # https://huggingface.co/docs/transformers/model_doc/opt + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) + eos_token_id = tokenizer.eos_token_id + + input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to( + device=device + ) + max_length = 25 + # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda') + # max_length = input_ids.shape[1] + 40 + + # Slow generation for reference + sequences = [] + scores = [] + cur_input_ids = input_ids + with torch.inference_mode(): + scores.append(model(cur_input_ids).logits[:, -1]) + sequences.append(scores[-1].argmax(dim=-1)) + for _ in range(input_ids.shape[1] + 1, max_length): + cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1) + scores.append(model(cur_input_ids).logits[:, -1]) + sequences.append(scores[-1].argmax(dim=-1)) + if eos_token_id is not None and (sequences[-1] == eos_token_id).all(): + break + sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1) + scores = tuple(scores) + + print("Without CUDA graph") + torch.cuda.synchronize() + start = time.time() + out = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + return_dict_in_generate=True, + output_scores=True, + enable_timing=True, + ) + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + if verbose: + print(out.sequences) + print(tokenizer.batch_decode(out.sequences.tolist())) + if getattr(config, "use_flash_attn", False): + # Capture graph outside the timing loop + batch_size, seqlen_og = input_ids.shape + model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) + print("With CUDA graph") + torch.cuda.synchronize() + start = time.time() + out_cg = model.generate( + input_ids=input_ids, + max_length=max_length, + cg=True, + return_dict_in_generate=True, + output_scores=True, + enable_timing=True, + ) + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + if verbose: + print(out_cg.sequences) + print(tokenizer.batch_decode(out_cg.sequences.tolist())) + + del model + + model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device) + model_hf.eval() + print("HF fp16") + torch.cuda.synchronize() + start = time.time() + out_hf = model_hf.generate( + input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True + ) + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + del model_hf + + model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device) + model_ref.eval() + print("HF fp32") + torch.cuda.synchronize() + start = time.time() + out_ref = model_ref.generate( + input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True + ) + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + del model_ref + print(tokenizer.batch_decode(out_ref.sequences.tolist())) + + if verbose: + print( + f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}" + ) + print( + f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}" + ) + print( + f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}" + ) + print( + f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}" + ) + + assert torch.all(out.sequences == sequences) + assert torch.allclose( + torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol + ) + assert torch.all(out.sequences == out_ref.sequences) + assert torch.all(out.sequences == out_hf.sequences) + + assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * ( + torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1) + ).abs().max().item() + + +def arg_parser(): + parser = argparse.ArgumentParser(description='Inference benchmarking') + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--model_index', type=int, default=5) + parser.add_argument('--seq_len', type=int, default=1024) + parser.add_argument('--index_size', type=int, default=8192) + parser.add_argument('--head_density', type=float, default=0.25) + parser.add_argument('--print_results', type=bool, default=False) + parser.add_argument('--iterations', type=int, default=1) + parser.add_argument('--check_results', type=bool, default=False) + parser.add_argument('--results_dir', type=str, default='results') + parser.add_argument('--gpu', type=int, default=0) + + return parser.parse_args() + +if __name__ == "__main__": + + args = arg_parser() + model_name = OPT_MODELS[args.model_index-1] + # test_opt_generation(model_name) + + print(f"\nMODEL: {model_name}\n") + verbose = False + dtype = torch.float16 + device = "cuda" + rtol, atol = 3e-3, 3e-1 + config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name)) + # Only prenorm supports residual_in_fp32 + config.residual_in_fp32 = getattr(config, "prenorm", True) + config.use_flash_attn = True + config.fused_bias_fc = True + config.fused_mlp = True + config.fused_dropout_add_ln = True + + model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) + model.eval() + + torch.manual_seed(0) + # OPT tokenizer requires use_fast=False + # https://huggingface.co/docs/transformers/model_doc/opt + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) + eos_token_id = tokenizer.eos_token_id + + # input_ids = tokenizer("In a distant galaxy, a spaceship", return_tensors="pt").input_ids.to( + # device=device + # ) + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + tokens = tokenize_dataset(dataset, tokenizer) + input_ids = get_random_batch(tokens, args.batch_size, args.seq_len) + input_ids = input_ids.to(device=device) + max_length = args.seq_len + 20 + + # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda') + # max_length = input_ids.shape[1] + 40 + + # warm up + _ = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + ) + + print("Without CUDA graph") + torch.cuda.synchronize() + start = time.time() + out = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + ) + torch.cuda.synchronize() + elapsed_time = (time.time() - start) * 1000 + print(f"Prompt processing + decoding time: {elapsed_time:.0f} ms") + + # Compute throughput and latency per token + num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1] + throughput = (args.batch_size * num_tokens_generated) / (elapsed_time / 1000) + latency_per_token = elapsed_time / num_tokens_generated # ms per token + + # print(f"Number of tokens generated: {num_tokens_generated}") + print(f"Throughput: {throughput:.1f} tokens/second") + print(f"Latency per token: {latency_per_token:.1f} ms") + + + if args.print_results: + # print(out.sequences) + print(tokenizer.batch_decode(out.sequences.tolist())) + + # ============================================================================= # + + print("\n") + if getattr(config, "use_flash_attn", False): + # Capture graph outside the timing loop + batch_size, seqlen_og = input_ids.shape + model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) + print("With CUDA graph") + torch.cuda.synchronize() + start = time.time() + out_cg = model.generate( + input_ids=input_ids, + max_length=max_length, + cg=True, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + ) + torch.cuda.synchronize() + elapsed_time = (time.time() - start) * 1000 + print(f"Prompt processing + decoding time: {elapsed_time:.0f} ms") + + # Compute throughput and latency per token + num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1] + latency_per_token = elapsed_time / num_tokens_generated # ms per token + throughput = (args.batch_size * num_tokens_generated) / (elapsed_time / 1000) + + # print(f"Number of tokens generated: {num_tokens_generated}") + print(f"Throughput: {throughput:.1f} tokens/second") + print(f"Latency per token: {latency_per_token:.1f} ms") + + if args.print_results: + # print(out_cg.sequences) + print(tokenizer.batch_decode(out_cg.sequences.tolist())) \ No newline at end of file diff --git a/HybridTensor/benchmarks/generation/opt_sparse_gen_tp.py b/HybridTensor/benchmarks/generation/opt_sparse_gen_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..c8efad9949202e31dc42f889097037834a177f5a --- /dev/null +++ b/HybridTensor/benchmarks/generation/opt_sparse_gen_tp.py @@ -0,0 +1,112 @@ +from transformers.models.opt import OPTConfig +from transformers import AutoTokenizer +from flash_attn.models.opt import opt_config_to_gpt2_config + +import os +import torch +import argparse +from apex.transformer import parallel_state + +from HybridTensor.utils.utils import arg_parser, _get_device +from HybridTensor.utils.activations import OPT_MODELS +from HybridTensor.models.opt import SparseConfig, build_sparse_opt + +def update_router_config(model, num_layers, mlp_act_th, attn_topk, layer_config = None): + for i in range(num_layers): + model.transformer.layers[i].mlp_router.act_th = mlp_act_th + model.transformer.layers[i].mha_router.topk = attn_topk + +def initialize_distributed_environment(): + # Set environment variables for NCCL + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" + os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0" + + # Initialize the distributed process group + torch.distributed.init_process_group(backend="nccl", init_method="env://") + + # Set the device based on the rank of the current process + device = f"cuda:{torch.distributed.get_rank()}" + world_size = torch.distributed.get_world_size() + + # Set the current CUDA device to avoid operations being executed on the wrong GPU + torch.cuda.set_device(device) + + # You can return device, world_size, and any other relevant information + return device, world_size + + +def arg_parser(): + parser = argparse.ArgumentParser(description='Inference benchmarking') + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--model_index', type=int, default=5) + parser.add_argument('--seq_len', type=int, default=28) + parser.add_argument('--index_size', type=int, default=8192) + parser.add_argument('--head_density', type=float, default=0.25) + parser.add_argument('--print_results', type=bool, default=True) + parser.add_argument('--iterations', type=int, default=100) + parser.add_argument('--check_results', type=bool, default=False) + parser.add_argument('--results_dir', type=str, default='results') + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--mlp_ckpt_dir', type=str, default='') + parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model') + parser.add_argument('--attn_ckpt_dir', type=str, default='') + + return parser.parse_args() + +if __name__ == "__main__": + + args = arg_parser() + model_name = OPT_MODELS[args.model_index-1] + + device, world_size = initialize_distributed_environment() + dtype = torch.float16 + + parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) + rank = parallel_state.get_tensor_model_parallel_rank() + process_group = parallel_state.get_tensor_model_parallel_group() + + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = build_sparse_opt(model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device = device, dtype=dtype, process_group = process_group, world_size = world_size, rank = rank) + model.eval() + print("Model loaded with sparse routers") + + mlp_act_th = 0.5 + attn_topk = 0.5 + + update_router_config(model, model.config.n_layer, mlp_act_th, attn_topk) + print("Router config updated") + + # print router configs from all layers + # for i in range(model.config.n_layer): + # print(f"Layer {i}: mlp_act_th = {model.transformer.layers[i].mlp_router.act_th}, attn_topk = {model.transformer.layers[i].mha_router.topk}") + + input_texts = ["Hello, my dog is cute and", "The future of AI is", "In a distant galaxy, a spaceship", "The cat is sleeping on the "] + # input_texts = ["Hello, my dog is cute and", "Hello, my rat is cute and"] + + tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(device) + input_ids=tokenized_inputs["input_ids"] + + # input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to(device=device) + max_length = args.seq_len + position_ids = None + eos_token_id = tokenizer.eos_token_id + num_layers = model.config.n_layer + + # print all the model weights and check the accuracy + # if rank == 0: + # print(model.state_dict()) + + # out = model(input_ids) + # print(out) + + out = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + return_dict_in_generate=True, + output_scores=True, + enable_timing=True, + ) + if rank == 0: + print(tokenizer.batch_decode(out.sequences.tolist())) + diff --git a/HybridTensor/benchmarks/generation/opt_sparse_generation.py b/HybridTensor/benchmarks/generation/opt_sparse_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..92fdbd519299329959fb65d96ebbc90b651bec03 --- /dev/null +++ b/HybridTensor/benchmarks/generation/opt_sparse_generation.py @@ -0,0 +1,182 @@ +import torch +import argparse + +from HybridTensor.utils.utils import _get_device +from HybridTensor.utils.activations import OPT_MODELS +from HybridTensor.models.opt import SparseConfig, build_sparse_opt +from HybridTensor.benchmarks.generation.gen_util import tokenize_dataset, get_random_batch +from HybridTensor.utils.activations import build_mlp_topk_lookup +from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv + +from datasets import load_dataset + +from transformers.models.opt import OPTConfig +from transformers import AutoTokenizer +from flash_attn.models.opt import opt_config_to_gpt2_config +from flash_attn.utils.generation import update_graph_cache + +def arg_parser(): + parser = argparse.ArgumentParser(description='Inference benchmarking') + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--model_index', type=int, default=5) + parser.add_argument('--seq_len', type=int, default=1024) + parser.add_argument('--index_size', type=int, default=8192) + parser.add_argument('--head_density', type=float, default=0.5) + parser.add_argument('--print_results', type=bool, default=True) + parser.add_argument('--iterations', type=int, default=1) + parser.add_argument('--check_results', type=bool, default=False) + parser.add_argument('--results_dir', type=str, default='results') + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model') + parser.add_argument('--mlp_ckpt_dir', type=str, default='') + parser.add_argument('--attn_ckpt_dir', type=str, default='') + parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router/opt-6.7b') + parser.add_argument('--delta', type=int, default=256, help='Delta value for MLP topk calculation') + parser.add_argument('--use_cuda_graph', type=bool, default=False, help='Use CUDA graph for inference') + + return parser.parse_args() + +def update_router_config(model, num_layers, mlp_topk_lookup, attn_topk): + for i in range(num_layers): + model.transformer.layers[i].mlp_topk = mlp_topk_lookup[i] + # model.transformer.layers[i].mlp_topk = 512 + model.transformer.layers[i].mha_router.topk = attn_topk + + # model.transformer.layers[i].skip_mlp_router = True + model.transformer.layers[0].mha_router.topk = 1.0 # dense attention in layer 0 + +if __name__ == "__main__": + args = arg_parser() + model_name = OPT_MODELS[args.model_index-1] + dtype = torch.float16 + device= _get_device(args.gpu) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + + # args.mlp_ckpt_dir = None + # args.attn_ckpt_dir = None + + model = build_sparse_opt(args, model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device = device, dtype=dtype) + model.eval() + print(model) + print("Model loaded with sparse routers") + + # mlp_topk_lookup = build_mlp_topk_lookup("results/mlp_results/batch_activations/opt-6.7b", args.batch_size, args.delta) + mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, args.batch_size) + print("MLP topk values updated: ", mlp_topk_lookup) + update_router_config(model, model.config.n_layer, mlp_topk_lookup, args.attn_topk) # this sets the router config for all layers using a single config + # update_router_config(model, model.config.n_layer, 2048, args.attn_topk) + print("Router config updated \n") + + + max_length = args.seq_len + 20 + batch_size = args.batch_size + + # input_texts = ["Hello, my dog is cute and", "The future of AI is", "In a distant galaxy, a spaceship", "The cat is sleeping on the "] + # input_texts = ["In a distant galaxy, a spaceship"] + # tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=False).to(device) + # input_ids=tokenized_inputs["input_ids"] + + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + tokens = tokenize_dataset(dataset, tokenizer) + input_ids = get_random_batch(tokens, args.batch_size, args.seq_len).to(device) + + print("Input ids generated, starting inference") + + # input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(device) + position_ids = None + eos_token_id = tokenizer.eos_token_id + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + with torch.no_grad(): + # warm up + _ = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + cg=False, + ) + + print("Warm up done") + + start_event.record() + for i in range(args.iterations): + out = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + cg=False, + ) + + end_event.record() + + torch.cuda.synchronize() + print("Without CUDA graph") + elapsed_time = start_event.elapsed_time(end_event) / args.iterations + print(f"Average time per genearation : {elapsed_time:.1f} ms") + + # Compute throughput and latency per token + num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1] + throughput = batch_size * num_tokens_generated / (elapsed_time / 1000) # tokens per second + latency_per_token = elapsed_time / num_tokens_generated # ms per token + + print(f"Number of tokens generated: {num_tokens_generated}") + print(f"Throughput: {throughput:.1f} tokens/second") + print(f"Latency per token: {latency_per_token:.1f} ms") + + # print(tokenizer.batch_decode(out.sequences.tolist())) + print("\n") + + # print only the new tokens generated + print("New tokens generated:") + print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist())) + + # ====================== With CUDA graph ====================== + if args.use_cuda_graph: + batch_size, seqlen_og = input_ids.shape + model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) + print("With CUDA graph") + torch.cuda.synchronize() + + start_event.record() + + for i in range(args.iterations): + out = model.generate( + input_ids=input_ids, + max_length=max_length, + cg=True, + eos_token_id=eos_token_id, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + ) + + end_event.record() + + torch.cuda.synchronize() + + + elapsed_time = start_event.elapsed_time(end_event) / args.iterations + print(f"Average time per genearation : {elapsed_time:.1f} ms") + + # Compute throughput and latency per token + num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1] + throughput = batch_size * num_tokens_generated / (elapsed_time / 1000) # tokens per second + latency_per_token = elapsed_time / num_tokens_generated # ms per token + + print(f"Number of tokens generated: {num_tokens_generated}") + print(f"Throughput: {throughput:.1f} tokens/second") + print(f"Latency per token: {latency_per_token:.1f} ms") + + # print(tokenizer.batch_decode(out.sequences.tolist())) + print("New tokens generated:") + print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist())) + \ No newline at end of file diff --git a/HybridTensor/benchmarks/model_eval.py b/HybridTensor/benchmarks/model_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..37490cbbe420de33a89160fb9950559f419c7ff9 --- /dev/null +++ b/HybridTensor/benchmarks/model_eval.py @@ -0,0 +1,313 @@ +import torch +import argparse +import os +import json +import logging +import numpy as np +import csv + +# from hf_models.opt.modeling_opt_routers import ( +# SparseOPTForCausalLM, +# create_hf_mha_router_state_dict, +# create_hf_mlp_router_state_dict +# ) + +from hf_models.opt.modeling_opt_routers_topk import ( + SparseOPTForCausalLM, + create_hf_mha_router_state_dict, + create_hf_mlp_router_state_dict +) + +from hf_models.llama.modeling_sparse_llama_routers import ( + SparseLlamaForCausalLM, + create_hf_attn_router_state_dict +) + +from hf_models.opt.modeling_sparse_opt_topk import SparseOPTForCausalLM as SparseOPTTopKAttn +from hf_models.llama.modeling_sparse_llama_mha_topk import SparseLlamaForCausalLM as SparseLlamaTopKAttn +from HybridTensor.benchmarks.opt_attn_sparse_topk_perplexity import _update_model_attn_thresholds +from HybridTensor.benchmarks.model_perplexity import compute_attn_layer_sparsity, compute_average_activation +from HybridTensor.utils.activations import ActivationThresholds, build_mlp_topk_lookup, _update_hf_mlp_topk, CONFIGS, MODELS +from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv +from HybridTensor.utils.utils import extract_model_name + +from transformers import AutoTokenizer, AutoModelForCausalLM + +from lm_eval.models.huggingface import HFLM +from lm_eval.tasks import TaskManager +import lm_eval + +import pandas as pd +from tabulate import tabulate + + +import logging +logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) + +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) + + +from huggingface_hub import login + +def read_and_print_results(filepath='results.csv'): + """ + Reads the CSV file containing evaluation results and prints them in a formatted table. + """ + if not os.path.exists(filepath): + print(f"File '{filepath}' not found.") + return + + df = pd.read_csv(filepath) + print(tabulate(df, headers='keys', tablefmt='psql', showindex=False)) + +def save_results_to_csv(results, attn_topk, filepath='eval_results.csv'): + """ + Extracts benchmark accuracies from results and saves them along with the attn_topk config. + + Parameters: + results: dict, evaluation results with structure results['results'][]['acc,none'] + attn_topk: float, the attention top-k value used for this run + filepath: str, CSV file to write to (appends if it exists) + """ + # Build a dictionary row with attn_topk and each benchmark's accuracy + row = {'attn_topk': attn_topk} + for benchmark, data in results['results'].items(): + # Default to None if the key is missing + row[benchmark] = data.get('acc,none', None) + + # Check if file exists to decide on writing header + file_exists = os.path.isfile(filepath) + with open(filepath, 'a', newline='') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=row.keys()) + if not file_exists: + writer.writeheader() + writer.writerow(row) + +def _update_model_attn_sparsity(model, attn_th): + num_layers = model.config.num_hidden_layers + + # Use the 'decoder' attribute if it exists; otherwise use model.model.layers + layers = model.model.decoder.layers if hasattr(model.model, 'decoder') else model.model.layers + attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=attn_th) + + for i in range(num_layers): + layers[i].self_attn.sp_threshold = attn_sparsity_map[i] + + average_act = compute_average_activation(attn_sparsity_map) + print(f"Attention sparsity {attn_th}: {attn_sparsity_map}") + print(f"Average activation: {average_act:.2f}") + + return model + +def _evaluate_model(model, tokenizer, benchmarks: list, device: str, batch_size: int = 8): + logging.info("Evaluating on benchmarks: %s", benchmarks) + lm_obj = HFLM( + pretrained=model, + tokenizer=tokenizer, + device=device, + batch_size=batch_size + ) + task_manager = TaskManager() + num_fewshot = 5 + print(f"Number of fewshot examples: {num_fewshot}") + results = lm_eval.simple_evaluate( + model=lm_obj, + tasks=benchmarks, + num_fewshot=num_fewshot, # change this + task_manager=task_manager + ) + logging.info("Evaluation complete.") + for benchmark, benchmark_results in results['results'].items(): + logging.info("Results for %s: %s", benchmark.upper(), benchmark_results) + return results + +def _load_model(model_name, num_layers, device, args): + if args.mode == 'sparse': + logging.info("Loading sparse model...") + sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th= args.attn_topk, mlp_th=args.mlp_topk) + + if args.model_index <=8: + # OPT models + model = SparseOPTForCausalLM.from_pretrained( + model_name, + device_map=device, + torch_dtype=torch.float16, + sp_thresholds=sp_thresholds.activation_threshold, + mlp_thresholds=sp_thresholds.mlp_threshold, + attn_implementation="flash_attention_2" + ) + logging.info("Loading router states...") + mlp_router_state = create_hf_mlp_router_state_dict(args.mlp_ckpt_dir) + mha_router_state = create_hf_mha_router_state_dict(args.attn_ckpt_dir) + model_state = model.state_dict() + model_state.update(mlp_router_state) + model_state.update(mha_router_state) + model.load_state_dict(model_state) + logging.info("Sparse model loaded with routers!") + + # load topk values for mlp and attn here + # mlp_topk_lookup = build_mlp_topk_lookup(args.batch_stats_dir, args.batch_size, args.delta) + # mlp_topk_lookup = build_mlp_topk_lookup(args.batch_stats_dir, 1, args.delta) + mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, 1) + + _update_hf_mlp_topk(model, mlp_topk_lookup) + # print("MLP topk values updated.") + # print("MLP topk values: ", mlp_topk_lookup) + logging.info("Using MLP topk values: %s", mlp_topk_lookup) + + # print("Using delta value: ", args.delta) + + # the first layer should use dense attention + model.model.decoder.layers[0].self_attn.sp_threshold = 1.0 + else: + # Llama models + + if not args.static_thresholds: + attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=args.attn_topk) + sp_thresholds.load_thresholds(attn_sparsity_map) + average_act = compute_average_activation(attn_sparsity_map) + print(f"Layer imporatance weights attention activations {sp_thresholds.activation_threshold}") + print(f"Average activation: {average_act:.2f}") + + model = SparseLlamaForCausalLM.from_pretrained(model_name, + device_map = device, + torch_dtype=torch.float16, + sp_thresholds = sp_thresholds.activation_threshold, + attn_implementation="flash_attention_2") + logging.info("Loading router states...") + model_state = model.state_dict() + attn_router_states = create_hf_attn_router_state_dict(args.attn_ckpt_dir) + model_state.update(attn_router_states) + model.load_state_dict(model_state) + logging.info("Sparse model loaded with routers!") + + # the first layer should use dense attetnion + _update_model_attn_thresholds(model, args.attn_topk) + + # load topk values for mha here + # TODO: create a function to update the topk values for mha + + elif args.mode == 'sparse_attn': + logging.info("Loading model with sparse attention") + sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=args.attn_topk) + + if not args.static_thresholds: + attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=args.attn_topk) + sp_thresholds.load_thresholds(attn_sparsity_map) + average_act = compute_average_activation(attn_sparsity_map) + print(f"Layer imporatance weights attention activations {sp_thresholds.activation_threshold}") + print(f"Average activation: {average_act:.2f}") + + if args.model_index <= 8: + # opt models + model = SparseOPTTopKAttn.from_pretrained(model_name, device_map = device, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2") + else: + # llama models + model = SparseLlamaTopKAttn.from_pretrained(model_name, device_map = device, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2") + else: + logging.info("Loading dense model...") + model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.float16) + return model + +def arg_parser(): + parser = argparse.ArgumentParser(description='Inference benchmarking') + parser.add_argument('--batch_size', type=int, default=8) + parser.add_argument('--model_index', type=int, default=5) + parser.add_argument('--print_results', type=bool, default=True) + parser.add_argument('--results_dir', type=str, default='results/eval') + parser.add_argument('--device', type=int, default=100) + parser.add_argument('--mode', type=str, default='sparse', choices=['sparse', 'dense', 'sparse_attn']) + parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model') + parser.add_argument('--mlp_topk', type=int, default=2048, help='MLP topk for sparse model') + parser.add_argument('--delta', type=int, default=128, help='Delta value for MLP topk calculation') + parser.add_argument('--mlp_ckpt_dir', type=str, default='') + parser.add_argument('--attn_ckpt_dir', type=str, default='') + parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router') + parser.add_argument('--data_collection', type=bool, default=False, help='Collect data for different activation thresholds') + parser.add_argument('--benchmark', type=str, default='all', help='Options: all, or a single benchmark name') + parser.add_argument('--note', type=str, default='', help='Note to add to the results filename') + parser.add_argument('--static_thresholds', type=bool, default=True, help='Use static thresholds for attention layers') + return parser.parse_args() + +if __name__ == "__main__": + args = arg_parser() + + login_token = None # insert your token here + assert login_token is not None, "Please provide a valid Hugging Face token." + login(token=login_token) + + # Setup logging + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + model_name = MODELS[args.model_index - 1] + # print(f"Evaluating Model: {model_name}") + logging.info("Evaluating Model: %s", model_name) + logging.info("Mode: %s", args.mode) + + num_layers = CONFIGS[model_name]['num_layer'] + device = 'auto' if args.device == 100 else f'cuda:{args.device}' + + # Load tokenizer and model + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) + model = _load_model(model_name, num_layers, device, args) + model.eval() + + # Determine benchmarks to evaluate + if args.benchmark == 'all': + benchmarks = ["piqa", "winogrande", "copa", "rte", "openbookqa", "arc_easy", "arc_challenge", "mmlu", "hellaswag"] + else: + benchmarks = [args.benchmark] + + model_name_clean = extract_model_name(model_name) + + if args.data_collection: + # make sure the model is not dense + assert args.mode != 'dense', "Data collection is only available for sparse models" + logging.info("Data collection mode enabled.") + if args.mode == 'sparse': + filepath = f"{args.results_dir}/eval_results_{model_name_clean}_sparse_sweep_dpsd.csv" + else: # sparse_attn + filepath = f"{args.results_dir}/eval_results_{model_name_clean}_attn_sweep_dpsd.csv" + + if args.note != '': + filepath = filepath.replace('.csv', f"_{args.note}.csv") + # attn_topk_values = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] # MHA + attn_topk_values = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1] + # attn_topk_values = [7/8, 6/8, 5/8, 4/8, 3/8, 2/8, 1/8] # GQA + for attn_topk in attn_topk_values: + logging.info("Evaluating with attention top-k value: %s", attn_topk) + if args.static_thresholds: + _update_model_attn_thresholds(model, attn_topk, mode=args.mode) + else: + _update_model_attn_sparsity(model, attn_topk) + + results = _evaluate_model( + model=model, + tokenizer=tokenizer, + benchmarks=benchmarks, + device=device, + batch_size=args.batch_size + ) + save_results_to_csv(results, attn_topk, filepath = filepath) + else: + logging.info("Evaluating with attention top-k value: %s", args.attn_topk) + if args.mode == 'dense': + filepath = f"{args.results_dir}/eval_results_{model_name_clean}_dense.csv" + elif args.mode == 'sparse_attn': + filepath = f"{args.results_dir}/eval_results_{model_name_clean}_sparse_attn_{args.attn_topk}_dpsd.csv" + else: + filepath = f"{args.results_dir}/eval_results_{model_name_clean}_test_attn_{args.attn_topk}_dpsd.csv" + if args.note != '': + filepath = filepath.replace('.csv', f"_{args.note}.csv") + results = _evaluate_model( + model=model, + tokenizer=tokenizer, + benchmarks=benchmarks, + device=device, + batch_size=args.batch_size + ) + save_results_to_csv(results, args.attn_topk, filepath = filepath) + + if args.print_results: + read_and_print_results(filepath=filepath) \ No newline at end of file diff --git a/HybridTensor/benchmarks/model_perplexity.py b/HybridTensor/benchmarks/model_perplexity.py new file mode 100644 index 0000000000000000000000000000000000000000..199c69621d80607e1bd52b9355e741e0ec76d0c8 --- /dev/null +++ b/HybridTensor/benchmarks/model_perplexity.py @@ -0,0 +1,165 @@ +# python -m HybridTensor.benchmarks.model_perplexity --model_index 14 --batch_size 4 --max_length 512 --attn_th 1 --static_thresholds True + +import sys +import math +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig + +from hf_models.opt.modeling_sparse_opt_topk import SparseOPTForCausalLM as SparseOPTTopkAttn +from hf_models.llama.modeling_sparse_llama_mha_topk import SparseLlamaForCausalLM as SparseLlamaTopKAttn +from HybridTensor.utils.activations import ActivationThresholds, identify_model_type, MODELS, CONFIGS +from HybridTensor.utils.utils import extract_model_name, compute_perplexity +import argparse +from datasets import load_dataset +import json +from torch.utils.data import DataLoader +from tqdm import tqdm +import pandas as pd + + +from HybridTensor.benchmarks.opt_attn_sparse_topk_perplexity import (_update_model_attn_thresholds, + build_data_loader, + compute_sparse_perplexity, + compute_perplexity_data_collection, + display_model_menu, + _interactive_mode, + arg_parser, + ) + + +results_dir = "results/activations" + +def compute_attn_layer_sparsity(model_name, min_th, critical_th, attn_sparsity): + # Get model configuration + # model_name = MODELS[model_index - 1] + model_config = CONFIGS[model_name] + num_layers = model_config['num_layer'] + + # Load the importance scores from the file specified in the configuration + file_path = model_config['layer_imp'] + with open(file_path, 'r') as f: + attn_layer_imp = json.load(f) + layer_importance = attn_layer_imp['importance_scores'] + + # Classify layers as critical or sparse + critical_layers = [i for i, imp in enumerate(layer_importance) if imp >= critical_th] + sparse_layers = [i for i, imp in enumerate(layer_importance) if imp < critical_th] + + # Calculate total sparse importance and the attention value + sum_sparse_importance = sum(layer_importance[i] for i in sparse_layers) + attn_val = attn_sparsity * len(sparse_layers) + + # Compute the sparsity map per layer + layer_sparsity_map = {} + for layer_idx in range(num_layers): + if layer_idx in critical_layers: + layer_sparsity_map[layer_idx] = 1.0 # Fully dense for critical layers + else: + if sum_sparse_importance > 0: + raw_fraction = (layer_importance[layer_idx] / sum_sparse_importance) * attn_val + else: + raw_fraction = attn_sparsity + # Clamp the fraction between min_th and 1.0 + fraction = max(raw_fraction, min_th) + fraction = min(fraction, 1.0) + layer_sparsity_map[layer_idx] = fraction + + return layer_sparsity_map + +def compute_average_activation(layer_sparsity_map): + """ + Computes the average activation for each layer based on the sparsity map. + """ + total_activation = 0.0 + for layer_idx, fraction in layer_sparsity_map.items(): + total_activation += fraction + + average_activation = total_activation / len(layer_sparsity_map) + return average_activation + +def compute_sparse_perplexity(model_name='facebook/opt-125m', + dataset_name='wikitext', + dataset_config='wikitext-2-raw-v1', + batch_size=8, + max_length=512, + attn_th=0.0, + static_thresholds=True, + device_map="cuda:0"): + # Set device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f'Using device: {device}') + + # load the activation thresholds + num_layers = CONFIGS[model_name]['num_layer'] + sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=attn_th) + + print(f"Static attention activations: {sp_thresholds.activation_threshold}") + if not static_thresholds: + # act_threshold_filepath = CONFIGS[model_name]['sp_config'] + attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=attn_th) + sp_thresholds.load_thresholds(attn_sparsity_map) + average_act = compute_average_activation(attn_sparsity_map) + print(f"Layer imporatance weights attention activations {sp_thresholds.activation_threshold}") + print(f"Average activation: {average_act:.2f}") + + # Load tokenizer and model + # tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) + model_type = identify_model_type(model_name) + if model_type == 'OPT': + print(f"Loading OPT model: {model_name}") + model = SparseOPTTopkAttn.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2") + elif model_type == 'Llama': + print(f"Loading Llama model: {model_name}") + model = SparseLlamaTopKAttn.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2") + model.eval() + + # # Load dataset + dataloader = build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length) + perplexity = compute_perplexity(model, dataloader, device) + return perplexity + + +def arg_parser(): + parser = argparse.ArgumentParser(description='Sparse Perplexity Evaluation') + parser.add_argument('--model_index', type=int, default=5, help='Index of the model to evaluate') + parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation') + parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length') + parser.add_argument('--attn_th', type=float, default=0.0, help='Activation threshold for attention layers') + parser.add_argument('--data_collection', type=bool, default=False, help='Collect data for different activation thresholds') + parser.add_argument('--device_map', type=str, default='auto', help='Device to use for evaluation') + parser.add_argument('--interactive', type=bool, default=False, help='Interactive mode for model selection') + parser.add_argument('--static_thresholds', type=bool, default=False, help='Use static thresholds for attention layers') + + return parser.parse_args() + +def main(): + """ + Main function to execute the perplexity computation with user-selected OPT model. + """ + print("=== OPT Models Perplexity Evaluation ===\n") + args = arg_parser() + + if args.interactive: + selected_model, batch_size, max_length, data_collection, device_map, attn_th = _interactive_mode() + + else: + selected_model, batch_size, max_length, data_collection, device_map, attn_th = MODELS[args.model_index-1], args.batch_size, args.max_length, args.data_collection, args.device_map, args.attn_th + print(f"Selected model: {selected_model}, batch size: {batch_size}, max length: {max_length}, attn_th: {attn_th}, data_collection: {data_collection}, device: {device_map}") + + if data_collection: + print("\nStarting data collection...\n") + compute_perplexity_data_collection(model_name=selected_model, batch_size=batch_size, max_length=max_length, device_map=device_map) + print("\nData collection complete.\n") + + else: + print("\nStarting perplexity computation...\n") + perplexity = compute_sparse_perplexity(model_name=selected_model, batch_size=batch_size, max_length=max_length, + attn_th=attn_th, + device_map=device_map, + static_thresholds=args.static_thresholds) + print(f"\n=== Perplexity Results ===") + print(f"Model: {selected_model}") + print(f"Perplexity: {perplexity:.2f}\n") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/HybridTensor/benchmarks/opt_attn_sparse_topk_perplexity.py b/HybridTensor/benchmarks/opt_attn_sparse_topk_perplexity.py new file mode 100644 index 0000000000000000000000000000000000000000..717bcfc69b67a761dfefcc0902c1103be1135b13 --- /dev/null +++ b/HybridTensor/benchmarks/opt_attn_sparse_topk_perplexity.py @@ -0,0 +1,264 @@ +import sys + +import math +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +# from hf_models.opt.modeling_sparse_opt import SparseOPTForCausalLM +from hf_models.opt.modeling_sparse_opt_topk import SparseOPTForCausalLM +from HybridTensor.utils.activations import ActivationThresholds, MODELS, CONFIGS +from HybridTensor.utils.utils import extract_model_name, compute_perplexity + +import argparse +from datasets import load_dataset + +from torch.utils.data import DataLoader +from tqdm import tqdm +import pandas as pd + +results_dir = "results/activations" + +def _update_model_attn_thresholds(model, attn_th, mode='sparse'): + num_layers = model.config.num_hidden_layers + + # Use the 'decoder' attribute if it exists; otherwise use model.model.layers + layers = model.model.decoder.layers if hasattr(model.model, 'decoder') else model.model.layers + + for i in range(num_layers): + layers[i].self_attn.sp_threshold = attn_th + + # For non-sparse attention, layer 0 should use a threshold of 1.0 + # if mode != 'sparse_attn': + # layers[0].self_attn.sp_threshold = 1.0 + layers[0].self_attn.sp_threshold = 1.0 + + return model + + +def build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length, split='test'): + """ + Build a DataLoader for the specified dataset. + + Args: + - model_name (str): The Hugging Face identifier of the model. + - dataset_name (str): The name of the dataset. + - dataset_config (str): The configuration of the dataset. + - batch_size (int): The batch size for the DataLoader. + - max_length (int): The maximum sequence length. + - split (str): The split of the dataset to use (default='test'). options: 'train', 'validation', 'test' + + Returns: + - dataloader (DataLoader): The DataLoader for the specified dataset. + """ + # Load tokenizer and model + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + # Load dataset + dataset = load_dataset(dataset_name, dataset_config, split=split) + dataset = dataset.filter(lambda x: len(x["text"]) >= max_length) + + # Tokenize the dataset + def tokenize_function(examples): + return tokenizer(examples['text'], return_special_tokens_mask=True, truncation=True, max_length=max_length) + + tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['text']) + + # Create DataLoader + def collate_fn(batch): + input_ids = [torch.tensor(example['input_ids']) for example in batch] + attention_mask = [torch.tensor(example['attention_mask']) for example in batch] + + # Pad sequences + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0) + + return {'input_ids': input_ids, 'attention_mask': attention_mask} + + dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn) + + return dataloader + +def compute_sparse_perplexity(model_name='facebook/opt-125m', dataset_name='wikitext', dataset_config='wikitext-2-raw-v1', batch_size=8, max_length=512, attn_th=0.0, static_thresholds=True, device_map="cuda:0"): + # Set device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f'Using device: {device}') + + # load the activation thresholds + num_layers = CONFIGS[model_name]['num_layer'] + + sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=attn_th) + + if not static_thresholds: + act_threshold_filepath = CONFIGS[model_name]['sp_config'] + sp_thresholds.load_thresholds(act_threshold_filepath) + print(f'Activation thresholds loaded from {act_threshold_filepath}') + + print(sp_thresholds.activation_threshold) + + # Load tokenizer and model + # tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) + model = SparseOPTForCausalLM.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2") + model.eval() + + # # Load dataset + dataloader = build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length) + + perplexity = compute_perplexity(model, dataloader, device) + + return perplexity + +def compute_perplexity_data_collection(model_name='facebook/opt-125m', dataset_name='wikitext', dataset_config='wikitext-2-raw-v1', batch_size=8, max_length=512, device_map="cuda:0"): + # Set device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f'Using device: {device}') + + # Load dataset + dataset = load_dataset(dataset_name, dataset_config, split='test') + dataset = dataset.filter(lambda x: len(x["text"]) >= 512) + dataloader = build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length) + + attn_thresholds = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] + # attn_thresholds = [1, 0.5, 0.2] + + print(f"Computing perplexity for the following attention thresholds: {attn_thresholds}") + + # load the model + num_layers = CONFIGS[model_name]['num_layer'] + + sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=0.1) + model = SparseOPTForCausalLM.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2") + model.eval() + + results = [] + for attn_th in attn_thresholds: + + print(f'Computing perplexity for attn top k: {attn_th}') + + # update the model with new threshold + model = _update_model_attn_thresholds(model, attn_th) + + # compute and store the perplexity + perplexity = compute_perplexity(model, dataloader, device) + print(f'Perplexity: {perplexity:.2f}\n') + results.append({ + 'model': model_name, + 'top_k': attn_th, + 'perplexity': perplexity + }) + + + # save the results to a csv file + results_df = pd.DataFrame(results) + model_name_str = extract_model_name(model_name) + + # save the results to a csv file in the results directory + results_df.to_csv(f'{results_dir}/sparse_perplexity_results_{model_name_str}_topk.csv', index=False) + + +def display_model_menu(): + """ + Displays a numbered menu of available OPT models and prompts the user to make a selection. + + Returns: + - selected_model (str): The Hugging Face identifier of the selected model. + """ + print("Available OPT Models:") + for idx, model in enumerate(MODELS, 1): + print(f"{idx}. {model}") + + while True: + try: + choice = input("\nEnter the number corresponding to the model you want to evaluate (e.g., 1): ") + if choice.lower() in ['q', 'quit', 'exit']: + print("Exiting the program.") + sys.exit(0) + choice = int(choice) + if 1 <= choice <= len(MODELS): + selected_model = MODELS[choice - 1] + print(f"\nYou have selected: {selected_model}\n") + return selected_model + else: + print(f"Please enter a number between 1 and {len(MODELS)}.") + except ValueError: + print("Invalid input. Please enter a valid number.") + + +def _interactive_mode(): + selected_model = display_model_menu() + + # Optional: Allow user to adjust batch size and max sequence length + try: + batch_size_input = input("Enter batch size (default=8): ").strip() + batch_size = int(batch_size_input) if batch_size_input else 8 + except ValueError: + print("Invalid input for batch size. Using default value of 8.") + batch_size = 8 + + max_length = 512 + + try: + data_collection = input("Do you want to collect data for different activation thresholds? (y/n): ").strip() + data_collection = True if data_collection.lower() == 'y' else False + except ValueError: + print("Invalid input for data collection. Using default value of False.") + data_collection = False + + + # select device + device_map = input("Enter device (cuda:0/auto) [default=cuda:0]: ").strip() + if not device_map: + device_map = "cuda:0" + + # select attention threshold + attn_th = 0.0 + if not data_collection: + try: + attn_th = input("Enter activation threshold for attention layers: ").strip() + attn_th = float(attn_th) if attn_th else 0.0 + except ValueError: + print("Invalid input for attention threshold. Using default value of 0.") + attn_th = 0.0 + + return selected_model, batch_size, max_length, data_collection, device_map, attn_th + + +def arg_parser(): + parser = argparse.ArgumentParser(description='Sparse Perplexity Evaluation') + parser.add_argument('--model_index', type=int, default=5, help='Index of the model to evaluate') + parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation') + parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length') + parser.add_argument('--attn_th', type=float, default=0.0, help='Activation threshold for attention layers') + parser.add_argument('--data_collection', type=bool, default=False, help='Collect data for different activation thresholds') + parser.add_argument('--device_map', type=str, default='cuda:0', help='Device to use for evaluation') + parser.add_argument('--interactive', type=bool, default=False, help='Interactive mode for model selection') + + return parser.parse_args() + +def main(): + """ + Main function to execute the perplexity computation with user-selected OPT model. + """ + print("=== OPT Models Perplexity Evaluation ===\n") + args = arg_parser() + + if args.interactive: + selected_model, batch_size, max_length, data_collection, device_map, attn_th = _interactive_mode() + + else: + selected_model, batch_size, max_length, data_collection, device_map, attn_th = MODELS[args.model_index-1], args.batch_size, args.max_length, args.data_collection, args.device_map, args.attn_th + print(f"Selected model: {selected_model}, batch size: {batch_size}, max length: {max_length}, attn_th: {attn_th}, data_collection: {data_collection}, device: {device_map}") + + if data_collection: + print("\nStarting data collection...\n") + compute_perplexity_data_collection(model_name=selected_model, batch_size=batch_size, max_length=max_length, device_map=device_map) + print("\nData collection complete.\n") + + else: + print("\nStarting perplexity computation...\n") + perplexity = compute_sparse_perplexity(model_name=selected_model, batch_size=batch_size, max_length=max_length, attn_th=attn_th, device_map=device_map) + print(f"\n=== Perplexity Results ===") + print(f"Model: {selected_model}") + print(f"Perplexity: {perplexity:.2f}\n") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/HybridTensor/benchmarks/select_block_decode.py b/HybridTensor/benchmarks/select_block_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..1ee14f572f03a38d73f1bf1082bf48f0629a20d3 --- /dev/null +++ b/HybridTensor/benchmarks/select_block_decode.py @@ -0,0 +1,218 @@ +import torch + +from tests.test_select_block import create_block, Config, SparseConfig +import csv +import time +import torch +import torch.nn as nn +from flash_attn.utils.generation import InferenceParams +from HybridTensor.utils.utils import arg_parser, _get_device, sparse_index, generate_random_BH_index, get_gpu_name +from HybridTensor.utils.profiling import cuda_profiler +import math +from tqdm import tqdm + +def run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype): + config = Config() + sp_config = SparseConfig() + sp_config.attn_topk = attn_topk + + config.hidden_size = args.in_features + config.num_attention_heads = args.in_features // 128 + config.use_heuristic = False # use pre-compiled heuristic or complie new one during runtime + + # Test create_block + sparse_block = create_block(config, sp_config, layer_idx=0, process_group=None, device=device, dtype=dtype) + sparse_block.eval() + sparse_block.mlp_topk = index_size + + regular_config = config + regular_config.att_sparse = False + regular_config.mlp_sparse = False + regular_block = create_block(regular_config, None, layer_idx=0, process_group=None, device=device, dtype=dtype) + regular_block.eval() + + # inference simulation with select block + max_seqlen = seq_len + 16 + max_batch_size = batch_size + in_features = args.in_features + head_dim = 128 + + inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size) + process_group = None + sequence_parallel = False + + # for testing and debugging + heads = config.num_attention_heads + selected_heads = heads // 2 + + # Create a static index vector (length equals total columns in B). + total_neurons = args.in_features * 4 + test_index_vec = torch.empty((total_neurons,), device='cuda', dtype=torch.int32) + active_indices = sparse_index(args.index_size, total_neurons)[0] + test_index_vec[:args.index_size] = active_indices + if args.index_size < total_neurons: + test_index_vec[args.index_size:] = 0 # Fill the rest with dummy values. + + # test_index_vec = sparse_index(args.in_features, args.in_features*4)[0].cuda() + test_bh_idx = generate_random_BH_index(args.batch_size, heads, selected_heads) + test_index_size = args.index_size + + mixer_kwargs = ( + {"seqlen": seq_len} + if process_group is not None and sequence_parallel + else {} + ) + if inference_params is not None: + mixer_kwargs["inference_params"] = inference_params + + with torch.no_grad(): + # prefill stage + original_seq = torch.randn(batch_size, seq_len, in_features, device='cuda', dtype=torch.float16) + + # Test prefill + output_sparse = sparse_block(original_seq, mixer_kwargs=mixer_kwargs) + output_regular = regular_block(original_seq, mixer_kwargs=mixer_kwargs) + + # need to update inference_params to reflect the new sequence length + mixer_kwargs["inference_params"].seqlen_offset = seq_len + + # Decode stage + input_x = torch.randn(batch_size, 1, in_features, device='cuda', dtype=torch.float16) + + out_decode_sparse = sparse_block(input_x, mixer_kwargs=mixer_kwargs) + + mixer_kwargs["inference_params"].seqlen_offset = seq_len + + out_decode_regular = regular_block(input_x, mixer_kwargs=mixer_kwargs) + + # mesure decode stage time in ms + # print("Without CUDA Graphs") + # out_decode_regular, regular_time = cuda_profiler(regular_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2) + # print(f"Regular time: {regular_time} ms") + + # out_decode_sparse, sparse_time = cuda_profiler(sparse_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2) + # print(f"Sparse time: {sparse_time} ms") + + # speedup = regular_time / sparse_time + # print(f"Speedup: {speedup}") + + # --- CUDA Graph Capture for Decode Stage --- + # Allocate static buffer for regular block (shape assumed fixed) + input_x_static = input_x.clone() + output_regular_static = torch.empty((batch_size, 1, in_features), device=device, dtype=dtype) + + # Capture regular block graph + _ = regular_block(input_x_static, mixer_kwargs=mixer_kwargs) + torch.cuda.synchronize() + graph_regular = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph_regular): + res = regular_block(input_x_static, mixer_kwargs=mixer_kwargs) + if isinstance(res, tuple): + res = res[0] + output_regular_static.copy_(res) + + # For the sparse block, run a dummy call to determine its output shape. + # Also, reset the inference parameter to ensure consistent behavior. + mixer_kwargs["inference_params"].seqlen_offset = seq_len + temp = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs) + if isinstance(temp, tuple): + temp = temp[0] + # print("Captured sparse block output shape:", temp.shape) + # Allocate static buffer matching the dummy run's shape. + output_sparse_static = torch.empty_like(temp) + # print("output_sparse_static shape:", output_sparse_static.shape) + torch.cuda.synchronize() + + mixer_kwargs["inference_params"].seqlen_offset = seq_len + graph_sparse = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph_sparse): + res = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs) + if isinstance(res, tuple): + res = res[0] + output_sparse_static.copy_(res) + + # Warmup CUDA Graph replays + for _ in range(5): + graph_regular.replay() + graph_sparse.replay() + torch.cuda.synchronize() + + # --- Measure CUDA Graph Replay Latency --- + num_replays = 10 + + start = time.time() + for _ in range(num_replays): + graph_regular.replay() + torch.cuda.synchronize() + regular_graph_time = (time.time() - start) * 1000 / num_replays + + start = time.time() + for _ in range(num_replays): + graph_sparse.replay() + torch.cuda.synchronize() + sparse_graph_time = (time.time() - start) * 1000 / num_replays + speedup = regular_graph_time / sparse_graph_time + # print() + # print("With CUDA Graphs") + # print(f"Regular block time (CUDA Graphs): {regular_graph_time} ms") + # print(f"Sparse block time (CUDA Graphs): {sparse_graph_time} ms") + # print(f"Speedup (CUDA Graphs): {speedup}") + + return regular_graph_time, sparse_graph_time, speedup + +if __name__ == "__main__": + + args = arg_parser() + device = _get_device(0) + dtype = torch.float16 + gpu_name = get_gpu_name() + + # Parameter grids. + # batch_sizes = [1, 4, 8, 16] + # seq_lengths = [128, 512] + # index_sizes = [512, 1024, 2048, 4096] + # attn_topks = [0.3, 0.4, 0.5] + + batch_sizes = [1, 8, 16, 32] + seq_lengths = [1024, 2048] + # index_sizes = [512, 1024, 2048, 4096, 8192] + index_size_p = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5] + total_neurons = args.in_features * 4 + + # Calculate initial index_size values + index_sizes = [int(total_neurons * i) for i in index_size_p] + + # Round up to the nearest multiple of 128 if necessary + index_sizes = [math.ceil(size / 128) * 128 if size % 128 != 0 else size for size in index_sizes] + + attn_topks = [0.3, 0.4, 0.5] + + # Calculate total number of simulations. + total_runs = len(batch_sizes) * len(seq_lengths) * len(index_sizes) * len(attn_topks) + output_file = f"results/simulations/{gpu_name}_select_block_{args.in_features}_inference_sim.csv" + + with open(output_file, mode='w', newline='') as csv_file: + fieldnames = ["in_features", "batch_size", "seq_len", "index_size", "neuron_activation", "attn_topk", + "regular_graph_time_ms", "sparse_graph_time_ms", "speedup"] + writer = csv.DictWriter(csv_file, fieldnames=fieldnames) + writer.writeheader() + + # Iterate over all combinations with tqdm progress bar. + for batch_size in tqdm(batch_sizes, desc="Batch Sizes"): + for seq_len in seq_lengths: + for index_size in index_sizes: + for attn_topk in attn_topks: + reg_time, spa_time, speedup = run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype) + writer.writerow({ + "in_features": args.in_features, + "batch_size": batch_size, + "seq_len": seq_len, + "index_size": index_size, + "neuron_activation": index_size / total_neurons, + "attn_topk": attn_topk, + "regular_graph_time_ms": reg_time, + "sparse_graph_time_ms": spa_time, + "speedup": speedup + }) + csv_file.flush() + print(f"Simulation complete. Results saved to {output_file}") \ No newline at end of file diff --git a/HybridTensor/models/__pycache__/create_sparse_model.cpython-310.pyc b/HybridTensor/models/__pycache__/create_sparse_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf10666e9d7f6c362fc572234c6b39b84584b4fd Binary files /dev/null and b/HybridTensor/models/__pycache__/create_sparse_model.cpython-310.pyc differ diff --git a/HybridTensor/models/__pycache__/create_sparse_model.cpython-39.pyc b/HybridTensor/models/__pycache__/create_sparse_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dc1b60de6cb3a52ddef23f92ca7739cb2a02716 Binary files /dev/null and b/HybridTensor/models/__pycache__/create_sparse_model.cpython-39.pyc differ diff --git a/HybridTensor/models/__pycache__/helper.cpython-310.pyc b/HybridTensor/models/__pycache__/helper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c2e75182525030ec9066210fd23eb425d99f8f7 Binary files /dev/null and b/HybridTensor/models/__pycache__/helper.cpython-310.pyc differ diff --git a/HybridTensor/models/__pycache__/helper.cpython-39.pyc b/HybridTensor/models/__pycache__/helper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9dbfa1c123cae3a9b3a47e3b2749b90139eef38 Binary files /dev/null and b/HybridTensor/models/__pycache__/helper.cpython-39.pyc differ diff --git a/HybridTensor/models/__pycache__/llama.cpython-39.pyc b/HybridTensor/models/__pycache__/llama.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d8f3a349f38d9b8037162c1b5ea39b75a92b159 Binary files /dev/null and b/HybridTensor/models/__pycache__/llama.cpython-39.pyc differ diff --git a/HybridTensor/models/__pycache__/opt.cpython-310.pyc b/HybridTensor/models/__pycache__/opt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c77ce4a1af155cfee6925f64bdb5fc0f2096915f Binary files /dev/null and b/HybridTensor/models/__pycache__/opt.cpython-310.pyc differ diff --git a/HybridTensor/models/__pycache__/opt.cpython-39.pyc b/HybridTensor/models/__pycache__/opt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adfb1940aea8a019e7047ef58229212c7e04545e Binary files /dev/null and b/HybridTensor/models/__pycache__/opt.cpython-39.pyc differ diff --git a/HybridTensor/models/create_sparse_model.py b/HybridTensor/models/create_sparse_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ffaf8de47b937295b245a5588c17863144dcdbd8 --- /dev/null +++ b/HybridTensor/models/create_sparse_model.py @@ -0,0 +1,854 @@ +import math +import torch +import torch.nn as nn +from functools import partial + +from einops import rearrange + +from transformers import GPT2Config +from collections import namedtuple +from HybridTensor.modules.SelectiveMHA import SMHA, SelectMHA, ParallelSelectMHA, MHARouter, ParallelMHARouter +from HybridTensor.modules.SelectiveMLP import SelectiveMLP, ParallelSelectiveMLP, MLPRouter, ParallelMLPRouter +from HybridTensor.modules.SelectiveBlock import SelectBlock +# from HybridTensor.modules.SelectiveBlock_v1 import SelectBlock +import torch.nn.functional as F +from flash_attn.utils.distributed import ( + all_gather, + all_gather_raw, + get_dim_for_local_rank, + sync_shared_params, +) + +from collections.abc import Sequence +from flash_attn.modules.mha import MHA, ParallelMHA +from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP, GatedMlp, ParallelGatedMlp, Mlp, ParallelMLP +from flash_attn.ops.activations import sqrelu_fwd +from flash_attn.modules.block import Block + +try: + from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm +except ImportError: + layer_norm_fn, RMSNorm = None, None + +from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings +from flash_attn.utils.distributed import sync_shared_params, all_gather_raw +from flash_attn.utils.pretrained import state_dict_from_pretrained +from flash_attn.utils.generation import GenerationMixin +from flash_attn.models.opt import remap_state_dict_hf_opt + +try: + from flash_attn.ops.fused_dense import ColumnParallelLinear +except ImportError: + ColumnParallelLinear = None + +try: + from flash_attn.ops.triton.mlp import FusedDenseSqreluDense +except ImportError: + FusedDenseSqreluDense = None + +try: + from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm +except ImportError: + layer_norm_fn, RMSNorm = None, None + +from HybridTensor.models.helper import remap_state_dict_gpt2, shard_state_dict_tp + +def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0 + softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power)) + softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0) + if config.scale_attn_by_inverse_layer_idx: + assert layer_idx is not None + softmax_scale /= float(layer_idx + 1) + dwconv = getattr(config, "attn_dwconv", False) + if dwconv: + assert process_group is None, "TensorParallel MHA does not support dwconv yet" + qkv_proj_bias = getattr(config, "qkv_proj_bias", True) + out_proj_bias = getattr(config, "out_proj_bias", True) + rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim) + rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0) + rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None) + rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False) + use_alibi = getattr(config, "use_alibi", False) + use_triton = getattr(config, "use_triton", True) # toggle cuda or triton decode kernels + window_size = getattr(config, "window_size", (-1, -1)) + use_flash_attn = getattr(config, "use_flash_attn", False) + fused_bias_fc = getattr(config, "fused_bias_fc", False) + if not fused_bias_fc: + assert process_group is None, "TensorParallel MHA requires fused_bias_fc" + + mlp_sparse = getattr(config, "mlp_sparse", False) + att_sparse = getattr(config, "att_sparse", False) + num_heads = getattr(config, "num_attention_heads", None) + n_head_kv = getattr(config, "n_head_kv", num_heads) + + if num_heads != n_head_kv: + att_sparse = False + + if process_group is None: + mha_cls = SMHA # SelectMHA if att_sparse else MHA + else: + mha_cls = ParallelSelectMHA if att_sparse else ParallelMHA + + # mha_cls = SelectMHA if process_group is None else ParallelSelectMHA + serial_kwargs = ( + {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {} + ) + parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", False), + } + if process_group is not None + else {} + ) + num_heads_kv = getattr(config, "n_head_kv", None) + mixer_cls = partial( + mha_cls, + num_heads=config.num_attention_heads, + num_heads_kv=num_heads_kv, + qkv_proj_bias=qkv_proj_bias, + out_proj_bias=out_proj_bias, + dropout=config.attn_pdrop, + softmax_scale=softmax_scale, + causal=True, + layer_idx=layer_idx, + rotary_emb_dim=rotary_emb_dim, + rotary_emb_base=rotary_emb_base, + rotary_emb_scale_base=rotary_emb_scale_base, + rotary_emb_interleaved=rotary_emb_interleaved, + use_alibi=use_alibi, + window_size=window_size, + use_flash_attn=use_flash_attn, + **serial_kwargs, + **parallel_kwargs, + **factory_kwargs, + ) + return mixer_cls + +def create_mlp_cls_old(config, layer_idx=None, process_group=None, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + fused_mlp = getattr(config, "fused_mlp", False) + if fused_mlp: + assert config.activation_function in [ + "gelu_new", + "gelu_fast", + "gelu_approx", + "gelu_pytorch_tanh", + "relu", + "sqrelu", + ] + assert fused_mlp == True, "Not supported not fused mlp for now" + + mlp_sparse = getattr(config, "mlp_sparse", False) + use_heuristic = getattr(config, "use_heuristic", True) + + mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0) + # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer + if isinstance(mlp_checkpoint_lvl, Sequence): + assert layer_idx is not None + mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] + + if fused_mlp: + if FusedMLP is None: + raise ImportError("fused_dense is not installed") + # activation = ( + # "gelu_approx" + # if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"] + # else "relu" + # ) + + if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]: + activation = "gelu_approx" + else: + activation = "relu" # config.activation_function + + if process_group is None: + mlp_cls = SelectiveMLP if mlp_sparse else FusedMLP + else: + mlp_cls = ParallelSelectiveMLP if mlp_sparse else ParallelFusedMLP + + parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", True), + } + if process_group is not None + else {} + ) + + sparsity_kwargs = ( + { + "use_heuristic": use_heuristic, + } + if mlp_sparse + else {} + ) + + mlp_cls = partial( + mlp_cls, + hidden_features=inner_dim, + activation=activation, + checkpoint_lvl=mlp_checkpoint_lvl, + # layer_idx=layer_idx, + **parallel_kwargs, + **factory_kwargs, + **sparsity_kwargs, + ) + + else: + raise RuntimeError("MLP type not supported") + return mlp_cls + +def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): + """ + Create an MLP class that supports both sparse MLPs (via fused mlp) and GatedMLPs. + If the activation function is one of "glu", "swiglu", or "geglu", then GatedMlp is used + (and mlp_sparse is ignored). Otherwise, fused_mlp is used to decide between sparse and + dense implementations. + """ + from functools import partial + factory_kwargs = {"device": device, "dtype": dtype} + mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True) + mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True) + + + # Check for gated activations + if config.activation_function in ["glu", "swiglu", "geglu"]: + # For gated activations we do not support sparsity yet. + activation = ( + F.sigmoid if config.activation_function == "glu" + else (F.silu if config.activation_function == "swiglu" else F.gelu) + ) + mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp + parallel_kwargs = ( + {"process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True)} + if process_group is not None else {} + ) + mlp_multiple_of = getattr(config, "mlp_multiple_of", 128) + mlp_cls = partial( + mlp_cls, + hidden_features=config.n_inner, + activation=activation, + bias1=mlp_fc1_bias, + bias2=mlp_fc2_bias, + multiple_of=mlp_multiple_of, + **parallel_kwargs, + **factory_kwargs, + ) + return mlp_cls + + # For non-gated activations: + fused_mlp = getattr(config, "fused_mlp", False) + fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False) + if fused_dense_sqrelu_dense: + assert config.activation_function == "sqrelu", ( + "fused_dense_sqrelu_dense only supports approximate activation_function sqrelu" + ) + assert not (fused_dense_sqrelu_dense and fused_mlp) + + if fused_mlp: + # Ensure valid activation function. + assert config.activation_function in [ + "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh", "relu", "sqrelu" + ] + # Support checkpoint level (possibly a list) + mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0) + if isinstance(mlp_checkpoint_lvl, (list, tuple)): + assert layer_idx is not None + mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] + # Choose activation string. + if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]: + activation = "gelu_approx" + else: + activation = "relu" + # Determine inner dim. + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + mlp_sparse = getattr(config, "mlp_sparse", False) + use_heuristic = getattr(config, "use_heuristic", True) + if process_group is None: + mlp_cls = SelectiveMLP if mlp_sparse else FusedMLP + else: + mlp_cls = ParallelSelectiveMLP if mlp_sparse else ParallelFusedMLP + parallel_kwargs = ( + {"process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True)} + if process_group is not None else {} + ) + sparsity_kwargs = {"use_heuristic": use_heuristic} if mlp_sparse else {} + mlp_cls = partial( + mlp_cls, + hidden_features=inner_dim, + activation=activation, + checkpoint_lvl=mlp_checkpoint_lvl, + bias1=mlp_fc1_bias, + bias2=mlp_fc2_bias, + **parallel_kwargs, + **factory_kwargs, + **sparsity_kwargs, + ) + return mlp_cls + + elif fused_dense_sqrelu_dense: + if process_group is not None: + assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense" + assert FusedDenseSqreluDense is not None + mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0) + if isinstance(mlp_checkpoint_lvl, (list, tuple)): + assert layer_idx is not None + mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] + mlp_cls = partial( + FusedDenseSqreluDense, + hidden_features=config.n_inner, + checkpoint_lvl=mlp_checkpoint_lvl, + **factory_kwargs, + ) + return mlp_cls + + else: + # Non-fused, non-sparse branch. + assert config.activation_function in [ + "gelu", "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh", "relu", "sqrelu" + ] + if config.activation_function == "relu": + activation = partial(F.relu, inplace=True) + elif config.activation_function == "sqrelu": + activation = sqrelu_fwd + else: + approximate = "tanh" if config.activation_function in [ + "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh" + ] else "none" + activation = partial(F.gelu, approximate=approximate) + mlp_sparse = getattr(config, "mlp_sparse", False) + mlp_cls = Mlp if process_group is None else ParallelMLP + parallel_kwargs = ( + {"process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True)} + if process_group is not None else {} + ) + mlp_cls = partial( + mlp_cls, + hidden_features=config.n_inner, + activation=activation, + bias1=mlp_fc1_bias, + bias2=mlp_fc2_bias, + **parallel_kwargs, + **factory_kwargs, + ) + return mlp_cls + +def create_mlp_router_cls(config, sp_config = None, layer_idx=None, process_group=None, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + num_neurons = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + + # this can be made different per layer by adding mlp_low_rank_dim_{layer_idx} in the sp_config + low_rank_dim = getattr(sp_config, "mlp_low_rank_dim", 1024) + + # per layer activation threshold + act_th = getattr(config, "mlp_act_th", 0.5) + + if process_group is None: + mlp_router_cls = MLPRouter + else: + mlp_router_cls = ParallelMLPRouter + + parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", True), + } + if process_group is not None + else {} + ) + + mlp_router_cls = partial(mlp_router_cls, + low_rank_dim = low_rank_dim, + out_dim = num_neurons, + act_th = act_th, + **parallel_kwargs, + **factory_kwargs) + + return mlp_router_cls + +def create_mha_router_cls(config, sp_config = None, layer_idx=None, process_group=None, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + num_heads = config.num_attention_heads + n_head_kv = getattr(config, "n_head_kv", num_heads) + if num_heads != n_head_kv: + out_dim = n_head_kv + else: + out_dim = num_heads + + low_rank_dim = getattr(sp_config, "attn_low_rank_dim", 128) # optional, default to 128 + + # per layer activation topk, to make this different per layer, add a different attn_topk_{layer_idx} in the sp_config + attn_topk = getattr(sp_config, "attn_topk", 0.5) + if process_group is None: + mha_router_cls = MHARouter + else: + mha_router_cls = ParallelMHARouter + + parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", True), + } + if process_group is not None + else {} + ) + + + mha_router_cls = partial(mha_router_cls, + low_rank_dim = low_rank_dim, + out_dim = out_dim, + top_k = attn_topk, + **parallel_kwargs, + **factory_kwargs) + + return mha_router_cls + +def create_block(config, sp_config, layer_idx=None, process_group=None, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + sequence_parallel = getattr(config, "sequence_parallel", True) + mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs) + mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs) + + use_rms_norm = getattr(config, "rms_norm", False) + norm_cls = partial( + nn.LayerNorm if not use_rms_norm else RMSNorm, + eps=config.layer_norm_epsilon, + **factory_kwargs, + ) + + # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable + residual_in_fp32 = getattr(config, "residual_in_fp32", False) + resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop + prenorm = getattr(config, "prenorm", True) + parallel_block = getattr(config, "parallel_block", False) + mlp_sparse = getattr(config, "mlp_sparse", False) + att_sparse = getattr(config, "att_sparse", False) + block_sparse = mlp_sparse or att_sparse + + if not parallel_block: + if block_sparse: + mha_router_cls = create_mha_router_cls(config, sp_config, layer_idx, process_group=process_group, **factory_kwargs) if att_sparse else None + mlp_router_cls = create_mlp_router_cls(config, sp_config, layer_idx, process_group=process_group, **factory_kwargs) if mlp_sparse else None + + block = SelectBlock( + config.hidden_size, + mixer_cls, + mlp_cls, + mlp_router = mlp_router_cls, + mha_router = mha_router_cls, + norm_cls=norm_cls, + prenorm=prenorm, + resid_dropout1=resid_dropout1, + resid_dropout2=config.resid_pdrop, + fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), + residual_in_fp32=residual_in_fp32, + sequence_parallel=sequence_parallel and process_group is not None, + mark_shared_params=process_group is not None, + ) + else: + block = Block( + config.hidden_size, + mixer_cls, + mlp_cls, + norm_cls=norm_cls, + prenorm=prenorm, + resid_dropout1=resid_dropout1, + resid_dropout2=config.resid_pdrop, + fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), + residual_in_fp32=residual_in_fp32, + sequence_parallel=sequence_parallel and process_group is not None, + mark_shared_params=process_group is not None, + ) + + else: + # not implemented + raise RuntimeError("ParallelBlock not implemented") + block.layer_idx = layer_idx + return block + + +class GPTPreTrainedModel(nn.Module): + """An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + + def __init__(self, config, *inputs, **kwargs): + super().__init__() + if not isinstance(config, GPT2Config): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + ) + ) + self.config = config + + @classmethod + def from_pretrained( + cls, + model_name, + config, + sp_config, + *args, + strict=True, + device=None, + dtype=None, + world_size=1, + rank=0, + **kwargs, + ): + """ + Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + """ + # Instantiate model. + model = cls(config, sp_config, *args, device=device, dtype=dtype, **kwargs) + # Load state_dict in cpu because we already initialized the model in GPU, and we don't + # want extra stuff taking up more GPU memory + state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype) + if model_name.startswith("gpt2"): + state_dict = remap_state_dict_gpt2(state_dict, config) + elif model_name.startswith("facebook/opt"): + state_dict = remap_state_dict_hf_opt(state_dict, config) + else: + raise NotImplementedError(f"Model {model_name} not supported") + if world_size > 1: + state_dict = shard_state_dict_tp(state_dict, config, world_size, rank) + load_return = model.load_state_dict(state_dict, strict=strict) + # logger.info(load_return) + return model + + +# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 +def _init_weights( + module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True +): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + nn.init.normal_( + p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) + ) + + +class GPTModel(GPTPreTrainedModel): + def __init__(self, config: GPT2Config, sp_config=None, process_group=None, device=None, dtype=None): + super().__init__(config) + factory_kwargs = {"device": device, "dtype": dtype} + self.process_group = process_group + self.sequence_parallel = getattr(config, "sequence_parallel", True) + assert config.activation_function in [ + "gelu", + "gelu_new", + "gelu_fast", + "gelu_approx", + "relu", + "sqrelu", + "glu", + "swiglu", + "geglu", + ] + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(config.vocab_size / pad_vocab_size_multiple) + * pad_vocab_size_multiple + ) + # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable + self.residual_in_fp32 = getattr(config, "residual_in_fp32", False) + # These 2 options are for OPT-350m + self.prenorm = getattr(config, "prenorm", True) + use_rms_norm = getattr(config, "rms_norm", False) + word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None) + + if process_group is None: + self.embeddings = GPT2Embeddings( + config.hidden_size, + vocab_size, + config.max_position_embeddings, + word_embed_proj_dim=word_embed_proj_dim, + **factory_kwargs, + ) + else: + self.embeddings = ParallelGPT2Embeddings( + config.hidden_size, + vocab_size, + config.max_position_embeddings, + process_group=process_group, + sequence_parallel=self.sequence_parallel, + **factory_kwargs, + ) + + + # We change the order of dropout, residual and layer norm: + # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: + # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and + # the main branch (output of MLP). The model definition is unchanged, but the mapping of the + # nn.Dropout probabilities are changed. + # This is for performance reason: we can fuse dropout + add + layer_norm. + self.layers = nn.ModuleList( + [ + create_block( + config, sp_config, layer_idx=i, process_group=process_group, **factory_kwargs + ) + for i in range(config.num_hidden_layers) + ] + ) + + + self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False) + if self.fused_dropout_add_ln: + if layer_norm_fn is None: + raise ImportError("Triton is not installed") + if self.prenorm: + self.drop_f = nn.Dropout(config.resid_pdrop) + norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm + # self.ln_f = nn.LayerNorm( + # config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs + # ) + self.ln_f = norm_cls( + config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs + ) + + + if process_group is not None: + for p in self.ln_f.parameters(): + # Mark the norm parameters as "shared_params" so that we sync their values at init. + p._shared_params = True + # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads. + if self.sequence_parallel: + p._sequence_parallel = True + + self.apply( + partial( + _init_weights, + n_layer=config.num_hidden_layers, + initializer_range=config.initializer_range, + ) + ) + self.tie_weights() + + self.sparse = False + if config.mlp_sparse or config.att_sparse: + self.sparse = True + + def tie_weights(self): + if self.process_group is not None: + sync_shared_params(self, self.process_group) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return { + i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + for i, layer in enumerate(self.layers) + } + + def forward(self, input_ids, position_ids=None, inference_params=None): + # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen + # dimensions so that we can split on it easily, in case of small batch size. + # Only the attention layers need to know the seqlen. + embedding_kwargs = ( + {"combine_batch_seqlen_dim": True} + if self.process_group is not None and self.sequence_parallel + else {} + ) + hidden_states = self.embeddings( + input_ids, position_ids=position_ids, **embedding_kwargs + ) + residual = None + mixer_kwargs = ( + {"seqlen": input_ids.shape[1]} + if self.process_group is not None and self.sequence_parallel + else {} + ) + if inference_params is not None: + mixer_kwargs["inference_params"] = inference_params + else: + mixer_kwargs["inference_params"] = None + + # else: + for layer in self.layers: + if self.prenorm: + hidden_states, residual = layer( + hidden_states, + residual, + mixer_kwargs=mixer_kwargs, + ) + else: + hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) + + if self.prenorm: + if not self.fused_dropout_add_ln: + dropped = self.drop_f(hidden_states) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + if hidden_states.shape != residual.shape: + hidden_states = hidden_states.view(residual.shape) + + hidden_states = layer_norm_fn( + hidden_states, + self.ln_f.weight, + self.ln_f.bias, + residual=residual, + x1=None, + eps=self.ln_f.eps, + dropout_p=self.drop_f.p if self.training else 0.0, + prenorm=False, + is_rms_norm=isinstance(self.ln_f, RMSNorm) + ) + return hidden_states + + +class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): + def __init__(self, config: GPT2Config, sp_config = None, process_group=None, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(config) + self.process_group = process_group + + self.transformer = GPTModel( + config, sp_config, process_group=process_group, **factory_kwargs + ) + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True) + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(config.vocab_size / pad_vocab_size_multiple) + * pad_vocab_size_multiple + ) + # This option is for OPT-350m + word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None) + embed_dim = ( + config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim + ) + if word_embed_proj_dim is not None: + self.project_out = nn.Linear( + config.n_embd, embed_dim, bias=False, **factory_kwargs + ) + else: + self.project_out = None + mup_width_scale = getattr(config, "mup_width_scale", 1.0) + mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0) + self.output_scale = mup_output_multiplier * mup_width_scale + + if process_group is None: + self.lm_head = nn.Linear( + embed_dim, vocab_size, bias=False, **factory_kwargs + ) + else: + if ColumnParallelLinear is None: + raise ImportError("fused_dense_lib is not installed") + self.lm_head = ColumnParallelLinear( + embed_dim, + vocab_size, + process_group, + bias=False, + sequence_parallel=getattr(config, "sequence_parallel", True), + **factory_kwargs, + ) + + self.norm_head = getattr(config, "norm_head", False) + # Initialize weights and apply final processing + self.apply( + partial( + _init_weights, + n_layer=config.num_hidden_layers, + initializer_range=config.initializer_range, + ) + ) + self.tie_weights() + + def tie_weights(self): + if self.tie_word_embeddings: + self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight # llama does not use tied weights + if self.process_group is not None: + sync_shared_params(self, self.process_group) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.transformer.allocate_inference_cache( + batch_size, max_seqlen, dtype=dtype, **kwargs + ) + + def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): + """ + input_ids: (batch, seqlen) int tensor + inference_params: for generation. Adapted from Megatron-LM (and Apex) + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 + num_last_tokens: if > 0, only return the logits for the last n tokens + """ + assert ( + input_ids.ndim == 2 + ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}" + b, slen = input_ids.shape + hidden_states = self.transformer( + input_ids, position_ids=position_ids, inference_params=inference_params + ) + if inference_params is not None: + assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode" + if num_last_tokens > 0: + hidden_states = hidden_states[:, -num_last_tokens:] + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + if self.output_scale != 1.0: + hidden_states = hidden_states * self.output_scale + if not self.norm_head: + lm_logits = self.lm_head(hidden_states) + else: + lm_head_weight = F.normalize(self.lm_head.weight) + if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel: + hidden_states = all_gather(hidden_states, self.lm_head.process_group) + lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias) + # During inference, we want the full logit for sampling + if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None: + lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group) + lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b) + CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) + return CausalLMOutput(logits=lm_logits) + + def load_state_dict(self, state_dict, strict=True): + # Remapping from our checkpoints that used a different ordering of layers in the block + # Previous: Attn / MLP -> Dropout -> Add -> LN + # Current: Dropout -> Add -> LN -> Attn / MLP + if "transformer.ln_0.weight" in state_dict: + n_layers = len(self.transformer.layers) + ln_weight = state_dict.pop( + f"transformer.layers.{n_layers - 1}.norm2.weight" + ) + ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias") + state_dict["transformer.ln_f.weight"] = ln_weight + state_dict["transformer.ln_f.bias"] = ln_bias + for l in reversed(range(n_layers)): + ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight") + ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias") + state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight + state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias + if l > 0: + ln_weight = state_dict.pop( + f"transformer.layers.{l - 1}.norm2.weight" + ) + ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias") + state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight + state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias + ln_weight = state_dict.pop("transformer.ln_0.weight") + ln_bias = state_dict.pop("transformer.ln_0.bias") + state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight + state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias + return super().load_state_dict(state_dict, strict=strict) \ No newline at end of file diff --git a/HybridTensor/models/helper.py b/HybridTensor/models/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..f977bc5c1f88b90de1ecfa9613da464d263c1437 --- /dev/null +++ b/HybridTensor/models/helper.py @@ -0,0 +1,125 @@ +import math +import re +from collections import OrderedDict + +from einops import rearrange + + +def remap_state_dict_gpt2(state_dict, config): + # Word embedding and position embedding + def key_mapping_pos_emb(key): + return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key) + + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("wte.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict["lm_head.weight"] = state_dict[ + "transformer.embeddings.word_embeddings.weight" + ] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key) + key = re.sub( + r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + for d in range(config.num_hidden_layers): + W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight") + state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t() + W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight") + state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t() + + def key_mapping_mlp(key): + key = re.sub( + r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key + ) + key = re.sub( + r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key + ) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for d in range(config.num_hidden_layers): + state_dict.pop(f"h.{d}.attn.bias") # We don't store this bias + Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight") + state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t() + Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight") + state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t() + + def key_mapping_attn(key): + key = re.sub( + r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key + ) + key = re.sub( + r"^h.(\d+).attn.c_proj.bias", + r"transformer.layers.\1.mixer.out_proj.bias", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def shard_state_dict_tp(state_dict, config, world_size, rank): + """Convert the state_dict of a standard GPT model to the state_dict of a GPT model + with tensor parallel. + """ + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + assert vocab_size % world_size == 0 + assert config.hidden_size % world_size == 0 + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + assert inner_dim % world_size == 0 + + def shard_first_dim(state_dict, key): + x = state_dict[key] + dim = x.shape[0] // world_size + state_dict[key] = x[rank * dim : (rank + 1) * dim] + + def shard_last_dim(state_dict, key): + x = state_dict[key] + dim = x.shape[-1] // world_size + state_dict[key] = x[..., rank * dim : (rank + 1) * dim] + + def shard_qkv_headdim(state_dict, key): + x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3) + dim = x.shape[1] // world_size + state_dict[key] = rearrange( + x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..." + ) + + shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight") + if "lm_head.weight" in state_dict: + shard_first_dim(state_dict, "lm_head.weight") + if "transformer.embeddings.position_embeddings.weight" in state_dict: + shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight") + for i in range(config.num_hidden_layers): + shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight") + shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias") + shard_last_dim(state_dict, f"transformer.layers.{i}.mixer.out_proj.weight") + if rank != 0: + state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias") + shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight") + shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias") + shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight") + if rank != 0: + state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias") + return state_dict diff --git a/HybridTensor/models/llama.py b/HybridTensor/models/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..2e190297cd7072476f795315dc76dca777e16d93 --- /dev/null +++ b/HybridTensor/models/llama.py @@ -0,0 +1,74 @@ +from transformers import LlamaConfig, LlamaTokenizer + +import torch +import torch.nn as nn +from HybridTensor.models.create_sparse_model import GPTLMHeadModel +from HybridTensor.modules.SelectiveRouters import create_mlp_router_state_dict, create_attn_router_state_dict + +# from flash_attn.models.gpt import GPTLMHeadModel +from transformers import AutoConfig, AutoTokenizer +from flash_attn.utils.pretrained import state_dict_from_pretrained + +from flash_attn.models.llama import ( + config_from_checkpoint, + inv_remap_state_dict_hf_llama, + llama_config_to_gpt2_config, + remap_state_dict_hf_llama, + remap_state_dict_meta_llama, + state_dicts_from_checkpoint, +) + +class SparseConfig: + def __init__(self): + self.mlp_low_rank_dim = 1024 + self.attn_low_rank_dim = 128 + self.mlp_act_th = 0.5 + self.attn_topk = 0.3 + +def build_dense_llama(model_name: str, device = None, dtype=torch.float16, process_group = None, world_size = None, rank = None, **kwargs): + config = llama_config_to_gpt2_config(AutoConfig.from_pretrained(model_name, trust_remote_code=True)) + config.use_flash_attn = True + config.fused_bias_fc = True + config.fused_mlp = False # We don't have fused GatedMLP yet + config.fused_dropout_add_ln = True + config.residual_in_fp32 = True + config.prenorm = True + + state_dict = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype) + state_dict = remap_state_dict_hf_llama(state_dict, config) + + model = GPTLMHeadModel(config, device=device, dtype=dtype) + model.load_state_dict(state_dict, strict=True) + model.eval() + + return model + +def build_sparse_llama(args, model_name: str, attn_ckpt_dir: str, device = None, dtype=torch.float16, process_group = None, world_size = None, rank = None, **kwargs): + config = llama_config_to_gpt2_config(AutoConfig.from_pretrained(model_name, trust_remote_code=True)) + config.use_flash_attn = True + config.fused_bias_fc = True + config.fused_mlp = False # We don't have fused GatedMLP yet + config.fused_dropout_add_ln = True + config.residual_in_fp32 = True + config.prenorm = True + + spconfig = SparseConfig() + spconfig.attn_topk = args.attn_topk + config.mlp_sparse = False + config.att_sparse = True + + state_dict = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype) + state_dict = remap_state_dict_hf_llama(state_dict, config) + + model = GPTLMHeadModel(config, sp_config= spconfig, device=device, dtype=dtype) + + if attn_ckpt_dir is not None: + attn_router_state_dict = create_attn_router_state_dict(attn_ckpt_dir) + merged_state_dict = {**state_dict, **attn_router_state_dict} + + # TODO: Add code for tensor parallel state dict sharding + + model.load_state_dict(merged_state_dict, strict=True) + model.eval() + + return model \ No newline at end of file diff --git a/HybridTensor/models/opt.py b/HybridTensor/models/opt.py new file mode 100644 index 0000000000000000000000000000000000000000..f9856061daa2e05222fc54bdfe0e07d7454899a9 --- /dev/null +++ b/HybridTensor/models/opt.py @@ -0,0 +1,229 @@ +from HybridTensor.utils.activations import OPT_MODELS +import torch +import math +from einops import rearrange + +from flash_attn.utils.pretrained import state_dict_from_pretrained +from flash_attn.models.opt import remap_state_dict_hf_opt +from HybridTensor.modules.SelectiveRouters import create_mlp_router_state_dict, create_attn_router_state_dict +from HybridTensor.models.create_sparse_model import GPTLMHeadModel as GPTLMHeadModelSparse +from flash_attn.models.gpt import GPTLMHeadModel + +from transformers.models.opt import OPTConfig +from flash_attn.models.opt import opt_config_to_gpt2_config + +class SparseConfig: + def __init__(self): + self.mlp_low_rank_dim = 1024 + self.attn_low_rank_dim = 128 + self.mlp_act_th = 0.5 + self.attn_topk = 0.3 + +def shard_state_dict_tp(state_dict, config, world_size, rank): + """Convert the state_dict of a standard GPT model to the state_dict of a GPT model + with tensor parallel. + """ + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + assert vocab_size % world_size == 0 + assert config.hidden_size % world_size == 0 + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + assert inner_dim % world_size == 0 + + shared_state_dict = {} + + def shard_first_dim(new, old, key): + x = old[key] + dim = x.shape[0] // world_size + new[key] = x[rank * dim : (rank + 1) * dim] + + def shard_last_dim(new, old, key): + x = old[key] + dim = x.shape[-1] // world_size + new[key] = x[..., rank * dim : (rank + 1) * dim] + + def shard_qkv_headdim(new, old, key): + x = rearrange(old[key], "(three d) ... -> three d ...", three=3) + dim = x.shape[1] // world_size + new[key] = rearrange( + x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..." + ) + + shard_first_dim(shared_state_dict, state_dict, "transformer.embeddings.word_embeddings.weight") + + if "lm_head.weight" in state_dict: + shard_first_dim(shared_state_dict, state_dict, "lm_head.weight") + if "transformer.embeddings.position_embeddings.weight" in state_dict: + shard_last_dim(shared_state_dict, state_dict, "transformer.embeddings.position_embeddings.weight") + + for i in range(config.num_hidden_layers): + # attention + shard_qkv_headdim(shared_state_dict, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight") + shard_qkv_headdim(shared_state_dict, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias") + shard_last_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight") + + # mlp + shard_first_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mlp.fc1.weight") + shard_first_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mlp.fc1.bias") + shard_last_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mlp.fc2.weight") + + if rank == 0: + shared_state_dict[f"transformer.layers.{i}.mlp.fc2.bias"] = state_dict[f"transformer.layers.{i}.mlp.fc2.bias"] + shared_state_dict[f"transformer.layers.{i}.mixer.out_proj.bias"] = state_dict[f"transformer.layers.{i}.mixer.out_proj.bias"] + + shared_state_dict[f"transformer.layers.{i}.norm1.weight"] = state_dict[f"transformer.layers.{i}.norm1.weight"] + shared_state_dict[f"transformer.layers.{i}.norm1.bias"] = state_dict[f"transformer.layers.{i}.norm1.bias"] + shared_state_dict[f"transformer.layers.{i}.norm2.weight"] = state_dict[f"transformer.layers.{i}.norm2.weight"] + shared_state_dict[f"transformer.layers.{i}.norm2.bias"] = state_dict[f"transformer.layers.{i}.norm2.bias"] + + # routers + + # mlp router + shared_state_dict[f"transformer.layers.{i}.mlp_router.fc1.weight"] = state_dict[f"transformer.layers.{i}.mlp_router.fc1.weight"] + shard_first_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mlp_router.fc2.weight") + + # mha router + shard_first_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mha_router.linear1.weight") + shard_first_dim(shared_state_dict, state_dict, f"transformer.layers.{i}.mha_router.linear1.bias") + + shared_state_dict[f"transformer.ln_f.weight"] = state_dict["transformer.ln_f.weight"] + shared_state_dict[f"transformer.ln_f.bias"] = state_dict["transformer.ln_f.bias"] + + # shared_state_dict[f"transformer.ln_f.weight"] = state_dict["transformer.final_layer_norm.weight"] + # shared_state_dict[f"transformer.ln_f.bias"] = state_dict["transformer.final_layer_norm.bias"] + + return shared_state_dict + +''' +def shard_state_dict_tp(state_dict, config, world_size, rank): + """Convert the state_dict of a standard GPT model to the state_dict of a GPT model + with tensor parallel. + """ + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + assert vocab_size % world_size == 0 + assert config.hidden_size % world_size == 0 + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + assert inner_dim % world_size == 0 + + def shard_first_dim(state_dict, key): + x = state_dict[key] + dim = x.shape[0] // world_size + state_dict[key] = x[rank * dim : (rank + 1) * dim] + + def shard_last_dim(state_dict, key): + x = state_dict[key] + dim = x.shape[-1] // world_size + state_dict[key] = x[..., rank * dim : (rank + 1) * dim] + + def shard_qkv_headdim(state_dict, key): + x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3) + dim = x.shape[1] // world_size + state_dict[key] = rearrange( + x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..." + ) + + shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight") + if "lm_head.weight" in state_dict: + shard_first_dim(state_dict, "lm_head.weight") + if "transformer.embeddings.position_embeddings.weight" in state_dict: + shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight") + for i in range(config.num_hidden_layers): + shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight") + shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias") + shard_last_dim(state_dict, f"transformer.layers.{i}.mixer.out_proj.weight") + if rank != 0: + state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias") + shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight") + shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias") + shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight") + if rank != 0: + state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias") + return state_dict + + +''' + +def build_sparse_opt(args, model_name, mlp_ckpt_dir, attn_ckpt_dir, device = None, dtype=torch.float16, process_group = None, world_size = None, rank = None): + # dtype = torch.float16 + + config = OPTConfig.from_pretrained(model_name) + config = opt_config_to_gpt2_config(config) + + if device in ('cpu', torch.device('cpu')): + config.fused_mlp = False + config.fused_dropout_add_ln = False + config.use_flash_attn = False + config.fused_bias_fc = False + else: + config.fused_mlp = True + config.fused_dropout_add_ln = True + config.use_flash_attn = True + config.fused_bias_fc = True + config.sequence_parallel = False + + config.residual_in_fp32 = getattr(config, "prenorm", True) + config.pad_vocab_size_multiple = 8 + config.mlp_sparse = True + config.att_sparse = True + + config.use_heuristic = True + if config.use_heuristic: + print("Using pre-compiled heuristic") + else: + print("Compiling new heuristic during runtime") + + spconfig = SparseConfig() + spconfig.mlp_act_th = 0.5 # sets the threshold for the MLP routers for all layers + spconfig.attn_topk = args.attn_topk # sets the topk for the attention routers for all layers + + # build model + print("Bulding Model with sparse routers") + model_sparse = GPTLMHeadModelSparse(config = config, sp_config = spconfig, process_group = process_group, device = device, dtype=dtype) + # print(model_sparse) + + # load pretrained weights into the sparse model + state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype) + state_dict = remap_state_dict_hf_opt(state_dict, config) + + # load the routers into the model + if mlp_ckpt_dir is not None and attn_ckpt_dir is not None: + mlp_router_state_dict = create_mlp_router_state_dict(mlp_ckpt_dir) + attn_router_state_dict = create_attn_router_state_dict(attn_ckpt_dir) + + # merge the state dict + merged_state_dict = {**state_dict, **mlp_router_state_dict, **attn_router_state_dict} + + if process_group is not None: + merged_state_dict = shard_state_dict_tp(merged_state_dict, config, world_size, rank) + + model_sparse.load_state_dict(merged_state_dict, strict=True) + else: + if process_group is not None: + state_dict = shard_state_dict_tp(state_dict, config, world_size, rank) + model_sparse.load_state_dict(state_dict, strict=False) + + return model_sparse + +def build_dense_opt(model_name, device = None, dtype=torch.float16, process_group = None, world_size = None, rank = None): + dtype = torch.float16 + + config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name)) + config.use_flash_attn = True + config.fused_bias_fc = True + config.fused_mlp = True + # config.fused_dropout_add_ln = True + config.sequence_parallel = False + # Only prenorm supports residual_in_fp32 + config.residual_in_fp32 = getattr(config, "prenorm", True) + config.pad_vocab_size_multiple = 8 + + # build model + print("Bulding Dense Model") + model = GPTLMHeadModel.from_pretrained(model_name, config, process_group = process_group, world_size = world_size, rank = rank, device=device, dtype=dtype) + + return model \ No newline at end of file diff --git a/HybridTensor/modules/SelectiveBlock.py b/HybridTensor/modules/SelectiveBlock.py new file mode 100644 index 0000000000000000000000000000000000000000..3781246c71337c80df14c6edcb0cfc90d40ac528 --- /dev/null +++ b/HybridTensor/modules/SelectiveBlock.py @@ -0,0 +1,960 @@ +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torchvision.ops import StochasticDepth + +try: + from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm +except ImportError: + layer_norm_fn, RMSNorm = None, None + +class SelectBlock(nn.Module): + def __init__( + self, + dim, + mixer_cls=None, + mlp_cls=None, + mlp_router=None, + mha_router=None, + norm_cls=nn.LayerNorm, + dropout_cls=nn.Dropout, + prenorm=True, + resid_dropout1=0.0, + resid_dropout2=0.0, + drop_path1=0.0, + drop_path2=0.0, + fused_dropout_add_ln=False, + return_residual=False, + residual_in_fp32=False, + sequence_parallel=False, + mark_shared_params=False, + ): + """ + For prenorm=True, this Block has a slightly different structure compared to a regular + prenorm Transformer block. + + The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. + Here we do: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, etc. + + If you want to do concurrency with CUDA graphs, your shapes must remain fixed + (batch_size, seq_len, etc.) across captures and replays. Also avoid any operations + that cause dynamic shape changes or memory allocations. + """ + super().__init__() + self.prenorm = prenorm + self.fused_dropout_add_ln = fused_dropout_add_ln + self.return_residual = return_residual + self.residual_in_fp32 = residual_in_fp32 + if self.residual_in_fp32: + assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True" + + assert mixer_cls is not None and mlp_cls is not None, ( + "mixer_cls and mlp_cls cannot be None in SelectBlock" + ) + + # MHA & MLP submodules + self.mixer = mixer_cls(dim) + self.dropout1 = dropout_cls(resid_dropout1) + self.drop_path1 = StochasticDepth(drop_path1, mode="row") + self.norm1 = norm_cls(dim) + self.mlp = mlp_cls(dim) + self.total_neurons = self.mlp.fc1.weight.shape[0] + + # Routers + if mlp_router is not None: + self.mlp_router = mlp_router(dim) + self.skip_attn_router = False + else: + self.mlp_router = None + self.skip_attn_router = True + + if mha_router is not None: + self.mha_router = mha_router(dim) + else: + self.mha_router = None + + if not isinstance(self.mlp, nn.Identity): + self.dropout2 = dropout_cls(resid_dropout2) + self.drop_path2 = StochasticDepth(drop_path2, mode="row") + self.norm2 = norm_cls(dim) + + if self.fused_dropout_add_ln: + assert layer_norm_fn is not None, "Triton layer_norm_fn not installed" + assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(self.dropout1, nn.Dropout) + + # Mark the norm parameters for sequence parallel / shared params if needed + if sequence_parallel: + for p in self.norm1.parameters(): + p._sequence_parallel = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._sequence_parallel = True + if mark_shared_params: + for p in self.norm1.parameters(): + p._shared_params = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._shared_params = True + + self.mlp_topk = None + self.skip_mlp_router = False + self.skip_attn_router = False + + # We'll use an extra stream for concurrency + self.sparse_stream = torch.cuda.Stream(device="cuda", priority=0) + self.main_stream = torch.cuda.Stream(device="cuda", priority=-5) + # We'll record events to coordinate concurrency + self.mha_event = torch.cuda.Event(enable_timing=False, blocking=False) + self.mlp_event = torch.cuda.Event(enable_timing=False, blocking=False) + + self.use_tensor_parallel = mark_shared_params + + if self.use_tensor_parallel: + # save the stream and events in the mixer and mlp classes + self.mlp.router = self.mlp_router + self.mixer.router = self.mha_router + + self.mlp_topk_layers = None # this will be a dictionary of layer_idx -> topk value + self.attn_topk_layers = None # this will be a dictionary of layer_idx -> topk value + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + def prefill_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_kwargs=None, mixer_subset=None): + hidden_states = self.mixer(hidden_states, **mixer_kwargs) + + if mixer_subset is not None: + residual = residual[:, mixer_subset] + + if not isinstance(self.mlp, nn.Identity): + if not self.fused_dropout_add_ln: + dropped = self.drop_path2(self.dropout2(hidden_states)) + if dropped.shape != residual.shape: + dropped = dropped.view(residual.shape) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + if hidden_states.shape != residual.shape: + hidden_states = hidden_states.view(residual.shape) + hidden_states, residual = layer_norm_fn( + hidden_states, + self.norm2.weight, + self.norm2.bias, + residual=residual, + eps=self.norm2.eps, + dropout_p=self.dropout2.p if self.training else 0.0, + rowscale=rowscale2, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm2, RMSNorm), + ) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + def decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None): + """ Single GPU Decode Forward + + Args: + hidden_states (Tensor): _description_ + residual (Optional[Tensor], optional): _description_. Defaults to None. + mixer_subset (_type_, optional): _description_. Defaults to None. + """ + curr_stream = torch.cuda.current_stream() + + # We want to run MHA & mlp_router in parallel on different streams + router_inputs = hidden_states.squeeze(1) # shape (batch_size, dim) + self.main_stream.wait_stream(curr_stream) + self.sparse_stream.wait_stream(curr_stream) + main_stream = self.main_stream + + # if mlp_topk > th * total_neurons, skip mlp router + + # if self.mlp_topk > 0.8 * self.total_neurons: + # self.skip_mlp_router = True + # else: + # self.skip_mlp_router = False + + # [Sparse stream] mlp_router + if not self.skip_mlp_router: + with torch.cuda.stream(self.sparse_stream): + index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk) + self.sparse_stream.record_event(self.mlp_event) + + # [Main stream] MHA + with torch.cuda.stream(main_stream): + batch_head_idx = self.mha_router._select_heads(router_inputs) + hidden_states = self.mixer( + hidden_states, + batch_head_idx=batch_head_idx, + **mixer_kwargs + ) + + main_stream.record_event(self.mha_event) + + # Now we unify after both are done, then do the next steps + with torch.cuda.stream(main_stream): + # Wait on router & MHA + curr_stream.wait_stream(main_stream) + main_stream.wait_event(self.mha_event) + + # normal residual / layernorm + if mixer_subset is not None: + residual = residual[:, mixer_subset] + + if not isinstance(self.mlp, nn.Identity): + if not self.fused_dropout_add_ln: + dropped = self.drop_path2(self.dropout2(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm2( + residual.to(dtype=self.norm2.weight.dtype) + ) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + if hidden_states.shape != residual.shape: + hidden_states = hidden_states.view(residual.shape) + hidden_states, residual = layer_norm_fn( + hidden_states, + self.norm2.weight, + self.norm2.bias, + residual=residual, + eps=self.norm2.eps, + dropout_p=self.dropout2.p if self.training else 0.0, + rowscale=rowscale2, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm2, RMSNorm), + ) + + # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size) + if self.skip_mlp_router: + hidden_states = self.mlp(hidden_states, index_vec=None) + else: + curr_stream.wait_stream(self.sparse_stream) + main_stream.wait_event(self.mlp_event) + hidden_states = self.mlp(hidden_states, index_vec=index_vec) + curr_stream.wait_stream(main_stream) + curr_stream.wait_stream(self.sparse_stream) + + return hidden_states, residual + + def tp_decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None): + """ + Tensor Parallel Decode Forward + + """ + + curr_stream = torch.cuda.current_stream() + self.sparse_stream.wait_stream(curr_stream) + # self.main_stream.wait_stream(curr_stream) + + router_inputs = hidden_states.squeeze(1) # shape (batch_size, dim) + + if self.mlp_topk > 0.8 * self.total_neurons: + self.skip_mlp_router = True + else: + self.skip_mlp_router = False + + # attention router is synchronous + batch_head_idx = self.mha_router._select_heads(router_inputs) + + # mlp router is asynchronous + if not self.skip_mlp_router: + with torch.cuda.stream(self.sparse_stream): + index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk) + self.sparse_stream.record_event(self.mlp_event) + + hidden_states = self.mixer(hidden_states, **mixer_kwargs, batch_head_idx=batch_head_idx) + + if mixer_subset is not None: + residual = residual[:, mixer_subset] + + if not isinstance(self.mlp, nn.Identity): + if not self.fused_dropout_add_ln: + dropped = self.drop_path2(self.dropout2(hidden_states)) + if dropped.shape != residual.shape: + dropped = dropped.view(residual.shape) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + if hidden_states.shape != residual.shape: + hidden_states = hidden_states.view(residual.shape) + hidden_states, residual = layer_norm_fn( + hidden_states, + self.norm2.weight, + self.norm2.bias, + residual=residual, + eps=self.norm2.eps, + dropout_p=self.dropout2.p if self.training else 0.0, + rowscale=rowscale2, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm2, RMSNorm), + ) + + # curr_stream.wait_stream(self.sparse_stream) + if self.skip_mlp_router: + hidden_states = self.mlp(hidden_states, index_vec=None) + else: + curr_stream.wait_event(self.mlp_event) + hidden_states = self.mlp(hidden_states, index_vec=index_vec) + + return hidden_states, residual + + def attn_sparse_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None): + """ + Decode Forward with Sparse Attention Router + """ + + # We want to run MHA & mlp_router in parallel on different streams + router_inputs = hidden_states.squeeze(1) # shape (batch_size, dim) + + batch_head_idx = self.mha_router._select_heads(router_inputs) + + # print(f"hidden_states shape: {hidden_states.shape}") + # print(f"hidden states: {hidden_states}") + hidden_states = self.mixer(hidden_states, batch_head_idx=batch_head_idx, **mixer_kwargs) + + # normal residual / layernorm + if mixer_subset is not None: + residual = residual[:, mixer_subset] + + if not isinstance(self.mlp, nn.Identity): + if not self.fused_dropout_add_ln: + dropped = self.drop_path2(self.dropout2(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm2( + residual.to(dtype=self.norm2.weight.dtype) + ) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2( + torch.ones(hidden_states.shape[:-1], device=hidden_states.device, dtype=hidden_states.dtype,) + ) + if hidden_states.shape != residual.shape: + hidden_states = hidden_states.view(residual.shape) + hidden_states, residual = layer_norm_fn(hidden_states, self.norm2.weight, self.norm2.bias, residual=residual, + eps=self.norm2.eps, dropout_p=self.dropout2.p if self.training else 0.0, + rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm2, RMSNorm),) + + # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + def mlp_sparse_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None): + """ Single GPU Decode Forward + + Args: + hidden_states (Tensor): _description_ + residual (Optional[Tensor], optional): _description_. Defaults to None. + mixer_subset (_type_, optional): _description_. Defaults to None. + """ + curr_stream = torch.cuda.current_stream() + + # We want to run MHA & mlp_router in parallel on different streams + router_inputs = hidden_states.squeeze(1) # shape (batch_size, dim) + self.main_stream.wait_stream(curr_stream) + self.sparse_stream.wait_stream(curr_stream) + main_stream = self.main_stream + + # if mlp_topk > th * total_neurons, skip mlp router + + if self.mlp_topk > 0.8 * self.total_neurons: + self.skip_mlp_router = True + else: + self.skip_mlp_router = False + + # [Sparse stream] mlp_router + if not self.skip_mlp_router: + with torch.cuda.stream(self.sparse_stream): + index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk) + self.sparse_stream.record_event(self.mlp_event) + + # [Main stream] MHA + with torch.cuda.stream(main_stream): + # batch_head_idx = self.mha_router._select_heads(router_inputs) + hidden_states = self.mixer( + hidden_states, + batch_head_idx=None, + **mixer_kwargs + ) + + main_stream.record_event(self.mha_event) + + # Now we unify after both are done, then do the next steps + with torch.cuda.stream(main_stream): + # Wait on router & MHA + curr_stream.wait_stream(main_stream) + main_stream.wait_event(self.mha_event) + + # normal residual / layernorm + if mixer_subset is not None: + residual = residual[:, mixer_subset] + + if not isinstance(self.mlp, nn.Identity): + if not self.fused_dropout_add_ln: + dropped = self.drop_path2(self.dropout2(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm2( + residual.to(dtype=self.norm2.weight.dtype) + ) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + if hidden_states.shape != residual.shape: + hidden_states = hidden_states.view(residual.shape) + hidden_states, residual = layer_norm_fn( + hidden_states, + self.norm2.weight, + self.norm2.bias, + residual=residual, + eps=self.norm2.eps, + dropout_p=self.dropout2.p if self.training else 0.0, + rowscale=rowscale2, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm2, RMSNorm), + ) + + # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size) + if self.skip_mlp_router: + hidden_states = self.mlp(hidden_states, index_vec=None) + else: + curr_stream.wait_stream(self.sparse_stream) + main_stream.wait_event(self.mlp_event) + hidden_states = self.mlp(hidden_states, index_vec=index_vec) + curr_stream.wait_stream(main_stream) + curr_stream.wait_stream(self.sparse_stream) + + return hidden_states, residual + + def forward( + self, + hidden_states: Tensor, + residual: Optional[Tensor] = None, + mixer_subset=None, + mixer_kwargs=None, + mlp_topk=None, + attn_topk=None, + ): + """ + This forward pass includes concurrency logic in the decode branch. + If you're capturing with a CUDA graph, the concurrency (two-stream usage) must be + inside the captured region so that the replay reproduces the parallel streams. + """ + + # simulation values + if mlp_topk is not None: + self.mlp_topk = mlp_topk + + if attn_topk is not None: + self.mha_router.topk = attn_topk + + if mixer_kwargs is None: + mixer_kwargs = {"inference_params": None} + else: + # Ensure 'inference_params' key exists + if "inference_params" not in mixer_kwargs: + mixer_kwargs["inference_params"] = None + + if self.prenorm: + # --- 1) Prenorm’s dropout/add/layernorm + if not self.fused_dropout_add_ln: + dropped = self.drop_path1(self.dropout1(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + # fused dropout + add + layernorm + if self.drop_path1.p == 0 or not self.training: + rowscale1 = None + else: + rowscale1 = self.drop_path1( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + if residual is not None and hidden_states.shape != residual.shape: + hidden_states = hidden_states.view(residual.shape) + hidden_states, residual = layer_norm_fn( + hidden_states, + self.norm1.weight, + self.norm1.bias, + residual=residual, + eps=self.norm1.eps, + dropout_p=self.dropout1.p if self.training else 0.0, + rowscale=rowscale1, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm1, RMSNorm), + ) + + if mixer_subset is not None: + mixer_kwargs["mixer_subset"] = mixer_subset + + # Check if we are in the prefill or decode stage + prefill_stage = ( + mixer_kwargs["inference_params"] is None + or mixer_kwargs["inference_params"].seqlen_offset == 0 + ) + + if prefill_stage: + # --- 2) Prefill stage (no concurrency): just do normal forward + hidden_states, residual = self.prefill_forward(hidden_states, residual, mixer_kwargs, mixer_subset) + + else: + # --- 3) Decode stage: + if self.mlp_router is None: + # decode stage with only attention router, works with both single gpu and tensor parallel + hidden_states, residual = self.attn_sparse_forward(hidden_states, residual, mixer_subset, mixer_kwargs) + else: + if not self.use_tensor_parallel: + if self.mha_router is None: + # decode stage with mlp routers (opt models and single gpu) + hidden_states, residual = self.mlp_sparse_forward(hidden_states, residual, mixer_subset, mixer_kwargs) + else: + # decode stage with mlp and attention routers (opt models and single gpu) + hidden_states, residual = self.decode_forward(hidden_states, residual, mixer_subset, mixer_kwargs) + else: + # uses both mlp and attention routers in tensor parallel + hidden_states, residual = self.tp_decode_forward(hidden_states, residual, mixer_subset, mixer_kwargs) + + return hidden_states, residual + + else: + # post-norm architecture not implemented here + raise NotImplementedError + + +# class SelectBlock(nn.Module): +# def __init__( +# self, +# dim, +# mixer_cls=None, +# mlp_cls=None, +# mlp_router=None, +# mha_router=None, +# norm_cls=nn.LayerNorm, +# dropout_cls=nn.Dropout, +# prenorm=True, +# resid_dropout1=0.0, +# resid_dropout2=0.0, +# drop_path1=0.0, +# drop_path2=0.0, +# fused_dropout_add_ln=False, +# return_residual=False, +# residual_in_fp32=False, +# sequence_parallel=False, +# mark_shared_params=False, +# ): +# """ +# For prenorm=True, this Block has a slightly different structure compared to a regular +# prenorm Transformer block. + +# The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. +# Here we do: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, etc. + +# If you want to do concurrency with CUDA graphs, your shapes must remain fixed +# (batch_size, seq_len, etc.) across captures and replays. Also avoid any operations +# that cause dynamic shape changes or memory allocations. +# """ +# super().__init__() +# self.prenorm = prenorm +# self.fused_dropout_add_ln = fused_dropout_add_ln +# self.return_residual = return_residual +# self.residual_in_fp32 = residual_in_fp32 +# if self.residual_in_fp32: +# assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True" + +# assert mixer_cls is not None and mlp_cls is not None, ( +# "mixer_cls and mlp_cls cannot be None in SelectBlock" +# ) + +# # MHA & MLP submodules +# self.mixer = mixer_cls(dim) +# self.dropout1 = dropout_cls(resid_dropout1) +# self.drop_path1 = StochasticDepth(drop_path1, mode="row") +# self.norm1 = norm_cls(dim) +# self.mlp = mlp_cls(dim) + +# # Routers +# self.mlp_router = mlp_router(dim) +# self.mha_router = mha_router(dim) + +# if not isinstance(self.mlp, nn.Identity): +# self.dropout2 = dropout_cls(resid_dropout2) +# self.drop_path2 = StochasticDepth(drop_path2, mode="row") +# self.norm2 = norm_cls(dim) + +# if self.fused_dropout_add_ln: +# assert layer_norm_fn is not None, "Triton layer_norm_fn not installed" +# assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(self.dropout1, nn.Dropout) + +# # Mark the norm parameters for sequence parallel / shared params if needed +# if sequence_parallel: +# for p in self.norm1.parameters(): +# p._sequence_parallel = True +# if hasattr(self, "norm2"): +# for p in self.norm2.parameters(): +# p._sequence_parallel = True +# if mark_shared_params: +# for p in self.norm1.parameters(): +# p._shared_params = True +# if hasattr(self, "norm2"): +# for p in self.norm2.parameters(): +# p._shared_params = True + +# self.mlp_topk = None +# self.skip_mlp_router = False +# self.skip_attn_router = False + +# # We'll use an extra stream for concurrency +# self.sparse_stream = torch.cuda.Stream(device="cuda", priority=0) +# self.main_stream = torch.cuda.Stream(device="cuda", priority=-5) +# # We'll record events to coordinate concurrency +# self.mha_event = torch.cuda.Event(enable_timing=False, blocking=False) +# self.mlp_event = torch.cuda.Event(enable_timing=False, blocking=False) + +# self.use_tensor_parallel = mark_shared_params + +# if self.use_tensor_parallel: +# # TODO: save the routers in the mixer and mlp classes +# # save the stream and events in the mixer and mlp classes +# pass + +# def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): +# return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + +# def prefill_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_kwargs=None, mixer_subset=None): +# hidden_states = self.mixer(hidden_states, **mixer_kwargs) + +# if mixer_subset is not None: +# residual = residual[:, mixer_subset] + +# if not isinstance(self.mlp, nn.Identity): +# if not self.fused_dropout_add_ln: +# dropped = self.drop_path2(self.dropout2(hidden_states)) +# if dropped.shape != residual.shape: +# dropped = dropped.view(residual.shape) +# residual = (dropped + residual) if residual is not None else dropped +# hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) +# if self.residual_in_fp32: +# residual = residual.to(torch.float32) +# else: +# if self.drop_path2.p == 0 or not self.training: +# rowscale2 = None +# else: +# rowscale2 = self.drop_path2( +# torch.ones( +# hidden_states.shape[:-1], +# device=hidden_states.device, +# dtype=hidden_states.dtype, +# ) +# ) +# if hidden_states.shape != residual.shape: +# hidden_states = hidden_states.view(residual.shape) +# hidden_states, residual = layer_norm_fn( +# hidden_states, +# self.norm2.weight, +# self.norm2.bias, +# residual=residual, +# eps=self.norm2.eps, +# dropout_p=self.dropout2.p if self.training else 0.0, +# rowscale=rowscale2, +# prenorm=True, +# residual_in_fp32=self.residual_in_fp32, +# is_rms_norm=isinstance(self.norm2, RMSNorm), +# ) +# hidden_states = self.mlp(hidden_states) +# return hidden_states, residual + +# def decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None): +# """ Single GPU Decode Forward + +# Args: +# hidden_states (Tensor): _description_ +# residual (Optional[Tensor], optional): _description_. Defaults to None. +# mixer_subset (_type_, optional): _description_. Defaults to None. +# """ +# curr_stream = torch.cuda.current_stream() + +# # We want to run MHA & mlp_router in parallel on different streams +# router_inputs = hidden_states.squeeze(1) # shape (batch_size, dim) +# self.main_stream.wait_stream(curr_stream) +# self.sparse_stream.wait_stream(curr_stream) + +# # We'll do MHA on the "main_stream" and the router on "sparse_stream" +# main_stream = self.main_stream +# # In a captured region, each 'with torch.cuda.stream(...)' block +# # is replayed in concurrency. The shape must remain consistent. + +# # [Sparse stream] mlp_router +# if not self.skip_mlp_router: +# with torch.cuda.stream(self.sparse_stream): # <-- CHANGED +# # index_size, index_vec = self.mlp_router._select_neurons_cuda_safe(router_inputs) # need to fix this; make CUDA Graph safe +# # vec = self.mlp_router(router_inputs) +# index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk) +# self.sparse_stream.record_event(self.mlp_event) + +# # [Main stream] MHA +# with torch.cuda.stream(main_stream): # <-- CHANGED +# batch_head_idx = self.mha_router._select_heads(router_inputs) +# hidden_states = self.mixer( +# hidden_states, +# batch_head_idx=batch_head_idx, +# # batch_head_idx=None, +# **mixer_kwargs +# ) +# main_stream.record_event(self.mha_event) + +# # Now we unify after both are done, then do the next steps +# with torch.cuda.stream(main_stream): # <-- CHANGED +# # Wait on router & MHA +# curr_stream.wait_stream(main_stream) +# main_stream.wait_event(self.mha_event) + +# # normal residual / layernorm +# if mixer_subset is not None: +# residual = residual[:, mixer_subset] + +# if not isinstance(self.mlp, nn.Identity): +# if not self.fused_dropout_add_ln: +# dropped = self.drop_path2(self.dropout2(hidden_states)) +# residual = (dropped + residual) if residual is not None else dropped +# hidden_states = self.norm2( +# residual.to(dtype=self.norm2.weight.dtype) +# ) +# if self.residual_in_fp32: +# residual = residual.to(torch.float32) +# else: +# if self.drop_path2.p == 0 or not self.training: +# rowscale2 = None +# else: +# rowscale2 = self.drop_path2( +# torch.ones( +# hidden_states.shape[:-1], +# device=hidden_states.device, +# dtype=hidden_states.dtype, +# ) +# ) +# if hidden_states.shape != residual.shape: +# hidden_states = hidden_states.view(residual.shape) +# hidden_states, residual = layer_norm_fn( +# hidden_states, +# self.norm2.weight, +# self.norm2.bias, +# residual=residual, +# eps=self.norm2.eps, +# dropout_p=self.dropout2.p if self.training else 0.0, +# rowscale=rowscale2, +# prenorm=True, +# residual_in_fp32=self.residual_in_fp32, +# is_rms_norm=isinstance(self.norm2, RMSNorm), +# ) + +# # Finally do MLP with the router's index vector +# curr_stream.wait_stream(self.sparse_stream) +# main_stream.wait_event(self.mlp_event) + +# # hidden_states = self.mlp(hidden_states, index_vec=test_index_vec, index_size=test_index_size) +# if self.skip_mlp_router: +# hidden_states = self.mlp(hidden_states, index_vec=None) +# else: +# hidden_states = self.mlp(hidden_states, index_vec=index_vec) +# curr_stream.wait_stream(main_stream) +# curr_stream.wait_stream(self.sparse_stream) + +# return hidden_states, residual + +# def tp_decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None): +# """ +# Tensor Parallel Decode Forward + +# Args: +# hidden_states (Tensor): _description_ +# residual (Optional[Tensor], optional): _description_. Defaults to None. +# mixer_subset (_type_, optional): _description_. Defaults to None. +# """ +# # TODO: need to add routing + +# hidden_states = self.mixer(hidden_states, **mixer_kwargs) + +# if mixer_subset is not None: +# residual = residual[:, mixer_subset] + +# if not isinstance(self.mlp, nn.Identity): +# if not self.fused_dropout_add_ln: +# dropped = self.drop_path2(self.dropout2(hidden_states)) +# if dropped.shape != residual.shape: +# dropped = dropped.view(residual.shape) +# residual = (dropped + residual) if residual is not None else dropped +# hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) +# if self.residual_in_fp32: +# residual = residual.to(torch.float32) +# else: +# if self.drop_path2.p == 0 or not self.training: +# rowscale2 = None +# else: +# rowscale2 = self.drop_path2( +# torch.ones( +# hidden_states.shape[:-1], +# device=hidden_states.device, +# dtype=hidden_states.dtype, +# ) +# ) +# if hidden_states.shape != residual.shape: +# hidden_states = hidden_states.view(residual.shape) +# hidden_states, residual = layer_norm_fn( +# hidden_states, +# self.norm2.weight, +# self.norm2.bias, +# residual=residual, +# eps=self.norm2.eps, +# dropout_p=self.dropout2.p if self.training else 0.0, +# rowscale=rowscale2, +# prenorm=True, +# residual_in_fp32=self.residual_in_fp32, +# is_rms_norm=isinstance(self.norm2, RMSNorm), +# ) +# hidden_states = self.mlp(hidden_states) +# return hidden_states, residual + +# def forward( +# self, +# hidden_states: Tensor, +# residual: Optional[Tensor] = None, +# mixer_subset=None, +# mixer_kwargs=None, +# ): +# """ +# This forward pass includes concurrency logic in the decode branch. +# If you're capturing with a CUDA graph, the concurrency (two-stream usage) must be +# inside the captured region so that the replay reproduces the parallel streams. +# """ + + +# if mixer_kwargs is None: +# mixer_kwargs = {"inference_params": None} +# else: +# # Ensure 'inference_params' key exists +# if "inference_params" not in mixer_kwargs: +# mixer_kwargs["inference_params"] = None + +# if self.prenorm: +# # --- 1) Prenorm’s dropout/add/layernorm +# if not self.fused_dropout_add_ln: +# dropped = self.drop_path1(self.dropout1(hidden_states)) +# residual = (dropped + residual) if residual is not None else dropped +# hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) +# if self.residual_in_fp32: +# residual = residual.to(torch.float32) +# else: +# # fused dropout + add + layernorm +# if self.drop_path1.p == 0 or not self.training: +# rowscale1 = None +# else: +# rowscale1 = self.drop_path1( +# torch.ones( +# hidden_states.shape[:-1], +# device=hidden_states.device, +# dtype=hidden_states.dtype, +# ) +# ) +# if residual is not None and hidden_states.shape != residual.shape: +# hidden_states = hidden_states.view(residual.shape) +# hidden_states, residual = layer_norm_fn( +# hidden_states, +# self.norm1.weight, +# self.norm1.bias, +# residual=residual, +# eps=self.norm1.eps, +# dropout_p=self.dropout1.p if self.training else 0.0, +# rowscale=rowscale1, +# prenorm=True, +# residual_in_fp32=self.residual_in_fp32, +# is_rms_norm=isinstance(self.norm1, RMSNorm), +# ) + +# if mixer_subset is not None: +# mixer_kwargs["mixer_subset"] = mixer_subset + +# # Check if we are in the prefill or decode stage +# prefill_stage = ( +# mixer_kwargs["inference_params"] is None +# or mixer_kwargs["inference_params"].seqlen_offset == 0 +# ) + +# if prefill_stage: +# # --- 2) Prefill stage (no concurrency): just do normal forward +# hidden_states, residual = self.prefill_forward(hidden_states, residual, mixer_kwargs, mixer_subset) + +# else: +# # # --- 3) Decode stage: +# if not self.use_tensor_parallel: +# hidden_states, residual = self.decode_forward(hidden_states, residual, mixer_subset, mixer_kwargs) +# else: +# # routing is slightly different in tensor parallel; we overlap the router with allreduce +# hidden_states, residual = self.tp_decode_forward(hidden_states, residual, mixer_subset) + +# return hidden_states, residual + +# else: +# # post-norm architecture not implemented here +# raise NotImplementedError diff --git a/HybridTensor/modules/SelectiveMHA.py b/HybridTensor/modules/SelectiveMHA.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d900812399031067c6e6a4769b01c3c4f32b4a --- /dev/null +++ b/HybridTensor/modules/SelectiveMHA.py @@ -0,0 +1,1579 @@ +import math +from functools import partial + +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from flash_attn.utils.distributed import get_dim_for_local_rank +from flash_attn.utils.distributed import all_reduce + +try: + from flash_attn import ( + flash_attn_kvpacked_func, + flash_attn_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_with_kvcache, + ) +except ImportError: + flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None + flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None + flash_attn_with_kvcache = None + +try: + from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear, fused_dense_func +except ImportError: + FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None + +try: + from flash_attn.layers.rotary import RotaryEmbedding +except ImportError: + RotaryEmbedding = None + +from flash_attn.modules.mha import SelfAttention, FlashSelfAttention, LinearResidual, FlashCrossAttention, CrossAttention +from flash_attn.modules.mha import get_alibi_slopes #, _update_kv_cache +from flash_attn.utils.generation import InferenceParams + +# from HybridTensor.modules.references.mha_dejavu import ParallelTracker # use this in the full implementation +# from HybridTensor.modules.references.mha_dejavu import ParallelMHASparseAttMlp +# from HybridTensor.triton.references.attention_proj_sparse import qkv_proj_sparse +# from HybridTensor.triton.select_attn import select_attn +# from HybridTensor.triton.select_attn_64b_kernel import select_attn +from HybridTensor.triton.attn_interface import flash_attn_with_kvcache_triton +from HybridTensor.triton.select_attn_v1 import select_attn +from HybridTensor.utils.utils import arg_parser, generate_BH_index, generate_random_BH_index +from HybridTensor.utils.profiling import cuda_profiler + + +class MHARouter(torch.nn.Module): + def __init__(self, embed_dim, low_rank_dim = None, out_dim = None, top_k = 0.5, device = None, dtype = None): + super(MHARouter, self).__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.model_dim = embed_dim + self.num_heads = out_dim + self.topk = top_k + + self.linear1 = torch.nn.Linear(embed_dim, out_dim, bias = True, **factory_kwargs) + + def forward(self, x): + out = self.linear1(x) + return out + + def _select_heads(self, x, topk = None): + if topk is None: + topk = int(self.topk * self.num_heads) + else: + topk = int(self.num_heads * topk) + head_scores = self.forward(x) + _, selected_heads = torch.topk(head_scores, topk, dim=1) + + return selected_heads + +class ParallelMHARouter(torch.nn.Module): + def __init__(self, embed_dim, low_rank_dim, out_dim, top_k, process_group, sequence_parallel=False, device = None, dtype = None): + super(ParallelMHARouter, self).__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.model_dim = embed_dim + self.num_heads = out_dim + self.topk = top_k + world_size = torch.distributed.get_world_size(process_group) + self.local_heads = out_dim // world_size + + self.linear1 = ColumnParallelLinear( + embed_dim, + out_dim, + process_group, + bias=True, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + + def forward(self, x): + out = self.linear1(x) + return out + + def _select_heads(self, x, topk = None): + if topk is None: + topk = int(self.topk * self.local_heads) + else: + topk = int(self.local_heads * topk) + head_scores = self.forward(x) + # head_scores = head_scores.squeeze(1) + # print(f"Head Scores.shape: {head_scores.shape}") + _, selected_heads = torch.topk(head_scores, topk, dim=1) + # print(f"Selected Heads: {selected_heads}") + return selected_heads + +def _update_kv_cache(kv, inference_params, layer_idx): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + # Pre-allocate memory for key-values for inference. + num_heads, head_dim = kv.shape[-2:] + if layer_idx not in inference_params.key_value_memory_dict: + kv_cache = torch.empty( + inference_params.max_batch_size, + inference_params.max_seqlen, + 2, + num_heads, + head_dim, + dtype=kv.dtype, + device=kv.device, + ) + inference_params.key_value_memory_dict[layer_idx] = kv_cache + else: + kv_cache = inference_params.key_value_memory_dict[layer_idx] + # Adjust key and value for inference + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + sequence_start = inference_params.seqlen_offset + sequence_end = sequence_start + kv.shape[1] + assert batch_end <= kv_cache.shape[0] + assert sequence_end <= kv_cache.shape[1] + assert kv_cache is not None + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + return kv_cache[batch_start:batch_end, :sequence_end, ...] + +class SMHA(nn.Module): + """Multi-head self-attention and cross-attention with Triton decode kernels + Selective Attention""" + + def __init__( + self, + embed_dim, + num_heads, + num_heads_kv=None, + cross_attn=False, + qkv_proj_bias=True, + out_proj_bias=True, + dropout=0.0, + softmax_scale=None, + causal=False, + layer_idx=None, + dwconv=False, + rotary_emb_dim=0, + rotary_emb_base=10000.0, + rotary_emb_scale_base=None, + rotary_emb_interleaved=False, + use_alibi=False, + window_size=(-1, -1), + fused_bias_fc=False, + use_flash_attn=False, + return_residual=False, + checkpointing=False, + use_triton=True, + device=None, + dtype=None, + ) -> None: + """ + num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.cross_attn = cross_attn + self.causal = causal + self.layer_idx = layer_idx + self.dwconv = dwconv + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.return_residual = return_residual + self.checkpointing = checkpointing + self.use_triton = use_triton + if use_alibi: + assert use_flash_attn, "ALiBi code path requires flash_attn" + alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device) + else: + alibi_slopes = None + if window_size != (-1, -1): + assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn" + + self.num_heads = num_heads + self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads + assert ( + self.num_heads % self.num_heads_kv == 0 + ), "num_heads must be divisible by num_heads_kv" + assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + kv_dim = 2 * self.head_dim * self.num_heads_kv + + if self.rotary_emb_dim > 0: + assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet" + assert RotaryEmbedding is not None, "rotary_emb is not installed" + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, + base=rotary_emb_base, + scale_base=rotary_emb_scale_base, + interleaved=rotary_emb_interleaved, + device=device, + ) + + if fused_bias_fc and FusedDense is None: + raise ImportError("fused_dense is not installed") + linear_cls = nn.Linear if not fused_bias_fc else FusedDense + linear_resid_cls = ( + LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True) + ) + wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls + inner_attn_cls = ( + partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) + if use_flash_attn + else SelfAttention + ) + inner_cross_attn_cls = ( + partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size) + if use_flash_attn + else CrossAttention + ) + if not self.cross_attn: + self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) + else: + self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs) + if self.dwconv: + if self.num_heads_kv == self.num_heads: + self.dwconv_qkv = nn.Conv1d( + qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim + ) + else: + self.dwconv_q = nn.Conv1d( + embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim + ) + self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim) + self.inner_attn = inner_attn_cls( + causal=causal, + softmax_scale=softmax_scale, + attention_dropout=dropout, + ) + self.inner_cross_attn = inner_cross_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + dtype = self.out_proj.weight.dtype if dtype is None else dtype + device = self.out_proj.weight.device + return torch.empty( + batch_size, + max_seqlen, + 2, + self.num_heads_kv, + self.head_dim, + dtype=dtype, + device=device, + ) + + def _update_kv_cache(self, kv, inference_params): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + assert not self.dwconv, "Generation does not support dwconv yet" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + return _update_kv_cache(kv, inference_params, self.layer_idx) + + def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): + """ + Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention. + q: (batch_size, seqlen_q, nheads, head_dim) + kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) + """ + assert inference_params is not None and inference_params.seqlen_offset > 0 + assert self.use_flash_attn + if self.rotary_emb_dim > 0: + assert self.rotary_emb.scale is None, "This code path does not support xPos" + self.rotary_emb._update_cos_sin_cache( + inference_params.max_seqlen, device=q.device, dtype=q.dtype + ) + rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached + else: + rotary_cos, rotary_sin = None, None + batch = q.shape[0] + kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) + + context = flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + softmax_scale=self.inner_cross_attn.softmax_scale, + causal=self.inner_cross_attn.causal, + rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, + alibi_slopes=alibi_slopes, + ) + return context + + + def _update_kvcache_attention_triton(self, q, kv, inference_params, batch_head_idx=None): + """ + The rotary embeddings have to be applied before calling this function. The KV cache is update here. + q: (batch_size, seqlen_q, nheads, head_dim) + kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) + """ + if ( + inference_params.seqlen_offset == 0 + or flash_attn_with_kvcache is None + or not self.use_flash_attn + ): + # TODO: this only uses seqlen_offset and not lengths_per_sample. + kv = self._update_kv_cache(kv, inference_params) + return self.inner_cross_attn(q, kv) + else: + batch = q.shape[0] + kv_cache = self._update_kv_cache(kv, inference_params) + + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) + + context = flash_attn_with_kvcache_triton( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + None, # kv[:, :, 0], + None, #kv[:, :, 1], + rotary_cos=None, + rotary_sin=None, + cache_seqlens=cache_seqlens, + softmax_scale=self.inner_cross_attn.softmax_scale, + causal=self.inner_cross_attn.causal, + rotary_interleaved= False, #self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, + alibi_slopes=alibi_slopes, + batch_head_idx=batch_head_idx, + ) + return context + + def _update_kvcache_attention(self, q, kv, inference_params): + """Write kv to inference_params, then do attention""" + if ( + inference_params.seqlen_offset == 0 + or flash_attn_with_kvcache is None + or not self.use_flash_attn + ): + # TODO: this only uses seqlen_offset and not lengths_per_sample. + kv = self._update_kv_cache(kv, inference_params) + return self.inner_cross_attn(q, kv) + else: + batch = q.shape[0] + kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) + return flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + cache_seqlens=cache_seqlens, + softmax_scale=self.inner_cross_attn.softmax_scale, + causal=self.inner_cross_attn.causal, + alibi_slopes=alibi_slopes, + ) + + def forward( + self, + x, + x_kv=None, + key_padding_mask=None, + cu_seqlens=None, + max_seqlen=None, + mixer_subset=None, + inference_params=None, + batch_head_idx=None, + use_triton=True, + **kwargs, + ): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if + cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total + is the is the sum of the sequence lengths in the batch. + x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x. + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into x. Only applicable when using + FlashAttention. + max_seqlen: int. Maximum sequence length in the batch. + key_padding_mask: boolean mask, True means to keep, False means to mask out. + (batch, seqlen). Only applicable when not using FlashAttention. + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. + inference_params: for generation. Adapted from Megatron-LM (and Apex) + batch_head_idx: (batch, num_heads). The index of the heads to be selected. Only applicable for Selective Head/Group Attention. + use_triton: whether to use triton kernels for attention in decode. If False, use the original flash attention implementation. + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 + """ + if cu_seqlens is not None: + assert max_seqlen is not None + assert key_padding_mask is None + assert self.use_flash_attn + assert not self.dwconv + assert self.rotary_emb_dim == 0 + if key_padding_mask is not None: + assert cu_seqlens is None + assert max_seqlen is None + assert not self.use_flash_attn + if inference_params is not None: + assert key_padding_mask is None + assert cu_seqlens is None and max_seqlen is None + assert not self.dwconv + # use_triton = self.use_triton if use_triton is None else use_triton + kwargs = ( + {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs} + if self.use_flash_attn + else {"key_padding_mask": key_padding_mask, **kwargs} + ) + seqlen_offset = ( + 0 + if inference_params is None + else ( + inference_params.lengths_per_sample + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + ) + rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None + batch, seqlen = x.shape[:2] + if not self.cross_attn and self.num_heads_kv == self.num_heads: + assert x_kv is None and mixer_subset is None + if not self.return_residual: + qkv = self.Wqkv(x) + else: + qkv, x = self.Wqkv(x) + if self.dwconv: + qkv = rearrange( + self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" + ).contiguous() + qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim) + if ( + inference_params is None + or inference_params.seqlen_offset == 0 + or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) + or not self.use_flash_attn + ): + # prefill stage + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + if inference_params is None: + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) + else: + if use_triton: + # print("Using the (prefill) triton flash attention implementation") + context = self._update_kvcache_attention_triton( + qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx + ) + else: + # print("Using the (prefill) original flash attention implementation") + context = self._update_kvcache_attention( + qkv[:, :, 0], qkv[:, :, 1:], inference_params + ) + else: + # decode stage + # print("Using triton kernels for attention") + if use_triton: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + context = self._update_kvcache_attention_triton( + qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx + ) + else: + # print("Using the original flash attention implementation") + context = self._apply_rotary_update_kvcache_attention( + qkv[:, :, 0], qkv[:, :, 1:], inference_params + ) + + else: # cross-attention or MQA/GQA + if self.cross_attn: + if not self.return_residual: + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + kv = self.Wkv(x_kv if x_kv is not None else x) + else: + if x_kv is not None: + kv, x_kv = self.Wkv(x_kv) + else: + kv, x = self.Wkv(x) + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + else: + assert self.num_heads_kv != self.num_heads + if not self.return_residual: + qkv = self.Wqkv(x) + else: + qkv, x = self.Wqkv(x) + q = qkv[..., : self.num_heads * self.head_dim] + kv = qkv[..., self.num_heads * self.head_dim :] + q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) + kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) + if self.dwconv: + q = rearrange( + self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d" + ).contiguous() + kv = rearrange( + self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" + ).contiguous() + if ( + inference_params is None + or inference_params.seqlen_offset == 0 + or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) + or not self.use_flash_attn + ): + # prefill + if self.rotary_emb_dim > 0: + q, kv = self.rotary_emb( + q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + if inference_params is None: + if not self.checkpointing: + context = self.inner_cross_attn(q, kv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint( + self.inner_cross_attn, q, kv, **kwargs + ) + else: + if use_triton: + context = self._update_kvcache_attention_triton( + q, kv, inference_params, batch_head_idx + ) + else: + context = self._update_kvcache_attention(q, kv, inference_params) + else: + # decode + # print("Using triton kernels for attention") + if use_triton: + if self.rotary_emb_dim > 0: + q, kv = self.rotary_emb( + q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + context = self._update_kvcache_attention_triton( + q, kv, inference_params, batch_head_idx + ) + else: + # print("Using the original gqa flash attention implementation") + context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) + # print(f"Context.shape: {context.shape}") + # print(f"Context: {context}") + out = self.out_proj(rearrange(context, "... h d -> ... (h d)")) + return out if not self.return_residual else (out, x) + +class ParallelSMHA(nn.Module): + """Multi-head self-attention and cross-attention""" + + def __init__( + self, + embed_dim, + num_heads, + process_group, + num_heads_kv=None, + qkv_proj_bias=True, + out_proj_bias=True, + dropout=0.0, + softmax_scale=None, + causal=False, + layer_idx=None, + rotary_emb_dim=0, + rotary_emb_base=10000.0, + rotary_emb_scale_base=None, + rotary_emb_interleaved=False, + use_alibi=False, + window_size=(-1, -1), + use_flash_attn=False, + checkpointing=False, + sequence_parallel=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.causal = causal + self.layer_idx = layer_idx + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.checkpointing = checkpointing + self.process_group = process_group + self.world_size = process_group.size() + self.local_rank = torch.distributed.get_rank(process_group) + + self.num_heads = num_heads + assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" + + self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads + assert ( + self.num_heads % self.num_heads_kv == 0 + ), "num_heads must be divisible by num_heads_kv" + + self.num_heads_per_rank = get_dim_for_local_rank( + self.num_heads, self.world_size, self.local_rank + ) + self.num_heads_kv_per_rank = get_dim_for_local_rank( + self.num_heads_kv, self.world_size, self.local_rank + ) + self.head_dim = self.embed_dim // num_heads + qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + + if use_alibi: + assert use_flash_attn, "ALiBi code path requires flash_attn" + num_heads_local = math.ceil(self.num_heads / self.world_size) + alibi_slopes = torch.tensor( + get_alibi_slopes(num_heads)[ + self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local + ], + device=device, + ) + else: + alibi_slopes = None + if window_size != (-1, -1): + assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn" + + if self.rotary_emb_dim > 0: + assert RotaryEmbedding is not None, "rotary_emb is not installed" + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, + base=rotary_emb_base, + scale_base=rotary_emb_scale_base, + interleaved=rotary_emb_interleaved, + device=device, + ) + + if ColumnParallelLinear is None or RowParallelLinear is None: + raise ImportError("fused_dense is not installed") + self.Wqkv = ColumnParallelLinear( + embed_dim, + qkv_dim, + process_group, + bias=qkv_proj_bias, + sequence_parallel=sequence_parallel, + multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2), + **factory_kwargs, + ) + inner_attn_cls = ( + partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) + if use_flash_attn + else SelfAttention + ) + inner_cross_attn_cls = ( + partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size) + if use_flash_attn + else CrossAttention + ) + self.inner_attn = inner_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + self.inner_cross_attn = inner_cross_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + process_group, + bias=out_proj_bias, + sequence_parallel=sequence_parallel, + multiple_of=self.head_dim, + **factory_kwargs, + ) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + dtype = self.out_proj.weight.dtype if dtype is None else dtype + device = self.out_proj.weight.device + return torch.empty( + batch_size, + max_seqlen, + 2, + self.num_heads_kv_per_rank, + self.head_dim, + dtype=dtype, + device=device, + ) + + def _update_kv_cache(self, kv, inference_params): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + return _update_kv_cache(kv, inference_params, self.layer_idx) + + def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): + """ + Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention. + q: (batch_size, seqlen_q, nheads, head_dim) + kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) + """ + assert inference_params is not None and inference_params.seqlen_offset > 0 + assert self.use_flash_attn + if self.rotary_emb_dim > 0: + assert self.rotary_emb.scale is None, "This code path does not support xPos" + self.rotary_emb._update_cos_sin_cache( + inference_params.max_seqlen, device=q.device, dtype=q.dtype + ) + rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached + else: + rotary_cos, rotary_sin = None, None + batch = q.shape[0] + kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) + context = flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + softmax_scale=self.inner_cross_attn.softmax_scale, + causal=self.inner_cross_attn.causal, + rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, + alibi_slopes=alibi_slopes, + ) + return context + + def _update_kvcache_attention(self, q, kv, inference_params): + """Write kv to inference_params, then do attention""" + if inference_params.seqlen_offset == 0 or not self.use_flash_attn: + # TODO: this only uses seqlen_offset and not lengths_per_sample. + kv = self._update_kv_cache(kv, inference_params) + return self.inner_cross_attn(q, kv) + else: + batch = q.shape[0] + kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) + context = flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + cache_seqlens=cache_seqlens, + softmax_scale=self.inner_cross_attn.softmax_scale, + causal=self.inner_cross_attn.causal, + alibi_slopes=alibi_slopes, + ) + return context + + def forward(self, x, seqlen=None, inference_params=None, **kwargs): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + """ + qkv = self.Wqkv(x) + if seqlen is not None: + qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen) + seqlen_offset = ( + 0 + if inference_params is None + else ( + inference_params.lengths_per_sample + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + ) + rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None + if self.num_heads_kv == self.num_heads: + qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) + if ( + inference_params is None + or inference_params.seqlen_offset == 0 + or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) + or not self.use_flash_attn + ): + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + if inference_params is None: + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) + else: + context = self._update_kvcache_attention( + qkv[:, :, 0], qkv[:, :, 1:], inference_params + ) + else: + context = self._apply_rotary_update_kvcache_attention( + qkv[:, :, 0], qkv[:, :, 1:], inference_params + ) + else: # GQA/MQA + q = rearrange( + qkv[..., : self.num_heads_per_rank * self.head_dim], + "... (h d) -> ... h d", + d=self.head_dim, + ) + kv = rearrange( + qkv[..., self.num_heads_per_rank * self.head_dim :], + "... (two hkv d) -> ... two hkv d", + two=2, + d=self.head_dim, + ) + if ( + inference_params is None + or inference_params.seqlen_offset == 0 + or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) + or not self.use_flash_attn + ): + if self.rotary_emb_dim > 0: + q, kv = self.rotary_emb( + q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + if inference_params is None: + if not self.checkpointing: + context = self.inner_cross_attn(q, kv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint( + self.inner_cross_attn, q, kv, **kwargs + ) + else: + context = self._update_kvcache_attention(q, kv, inference_params) + else: + context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) + context = rearrange(context, "b s h d -> b s (h d)") + if seqlen is not None: + context = rearrange(context, "b s d -> (b s) d") + out = self.out_proj(context) + return out + +class SelectMHA(nn.Module): + """Multi-head, Group-query self-attention using select attention""" + + def __init__( + self, + embed_dim, + num_heads, + num_heads_kv=None, + cross_attn=False, + qkv_proj_bias=True, + out_proj_bias=True, + dropout=0.0, + softmax_scale=None, + causal=False, + layer_idx=None, + dwconv=False, + rotary_emb_dim=0, + rotary_emb_base=10000.0, + rotary_emb_scale_base=None, + rotary_emb_interleaved=False, + use_alibi=False, + window_size=(-1, -1), + fused_bias_fc=False, + use_flash_attn=True, + return_residual=False, + checkpointing=False, + device=None, + dtype=None, + ) -> None: + """ + num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.cross_attn = cross_attn + self.causal = causal + self.layer_idx = layer_idx + self.dwconv = dwconv + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = True # use_flash_attn + self.return_residual = return_residual + self.checkpointing = checkpointing + if use_alibi: + assert use_flash_attn, "ALiBi code path requires flash_attn" + alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device) + else: + alibi_slopes = None + if window_size != (-1, -1): + assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn" + + self.num_heads = num_heads + self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads + assert ( + self.num_heads % self.num_heads_kv == 0 + ), "num_heads must be divisible by num_heads_kv" + assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + kv_dim = 2 * self.head_dim * self.num_heads_kv + + if self.rotary_emb_dim > 0: + assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet" + assert RotaryEmbedding is not None, "rotary_emb is not installed" + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, + base=rotary_emb_base, + scale_base=rotary_emb_scale_base, + interleaved=rotary_emb_interleaved, + device=device, + ) + + if fused_bias_fc and FusedDense is None: + raise ImportError("fused_dense is not installed") + linear_cls = nn.Linear if not fused_bias_fc else FusedDense + linear_resid_cls = ( + LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True) + ) + wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls + inner_attn_cls = ( + partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) + if use_flash_attn + else SelfAttention + ) + + self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) + + self.inner_attn = inner_attn_cls( + causal=causal, + softmax_scale=softmax_scale, + attention_dropout=dropout, + ) + self.softmax_scale = softmax_scale + self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + dtype = self.out_proj.weight.dtype if dtype is None else dtype + device = self.out_proj.weight.device + return torch.empty( + batch_size, + max_seqlen, + 2, + self.num_heads_kv, + self.head_dim, + dtype=dtype, + device=device, + ) + + def _update_kv_cache(self, kv, inference_params): + """Update kv cache in inference_params.""" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + return _update_kv_cache(kv, inference_params, self.layer_idx) + + def forward( + self, + x, + x_kv=None, + key_padding_mask=None, + cu_seqlens=None, + max_seqlen=None, + mixer_subset=None, + inference_params=None, + batch_head_idx=None, + **kwargs, + ): + """ + Arguments: + x: (batch, seqlen, hidden_dim) + batch_head_idx: Tensor of indices specifying which batch and head indices to select. + Shape: (batch_size, top_k) + inference_params: for generation. + """ + seqlen_offset = ( + 0 + if inference_params is None + else ( + inference_params.lengths_per_sample + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + ) + rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None + batch, seqlen = x.shape[:2] + + if not self.cross_attn and self.num_heads_kv == self.num_heads: + # Self-attention, no MQA/GQA + assert x_kv is None and mixer_subset is None + if not self.return_residual: + qkv = self.Wqkv(x) + else: + qkv, x = self.Wqkv(x) + qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim) + + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + + if inference_params is None or inference_params.seqlen_offset == 0: + # Inference stage without inference_params + if inference_params is not None: + # Update kv cache during prefill + kv = self._update_kv_cache(qkv[:, :, 1:], inference_params) + + context = self.inner_attn(qkv, **kwargs) + + else: + # Generation stage + if batch_head_idx is None: + # Apply select attention without kv cache update + context = self._update_kvcache_attention(q = qkv[:, :, 0], kv = qkv[:, :, 1:], inference_params = inference_params) + else: + # Apply select attention with kv cache update + context = self._update_kvcache_select_attn(q = qkv[:, :, 0], kv = qkv[:, :, 1:], inference_params = inference_params, batch_head_idx = batch_head_idx) + + else: + raise NotImplementedError("SelectMHA currently supports only self-attention without MQA/GQA.") + + out = self.out_proj(rearrange(context, "... h d -> ... (h d)")) + return out if not self.return_residual else (out, x) + + def _update_kvcache_select_attn(self, q, kv, inference_params, batch_head_idx): + """ + Apply select attention during generation stage. + + q: (batch_size, seqlen=1, n_heads, head_dim) + kv: (batch_size, seqlen=1, 2, n_heads, head_dim) + batch_head_idx: Tensor of indices specifying which batch and head indices to select. + Shape: (batch_size, top_k) + + # currently only supports batches with same seqlen + # different seqlen requires a simple update in the select_attn kernel to load the seqlen, future work + """ + # check batch_head_idx shape + # assert batch_head_idx.shape[0] == 2, "batch_head_idx must have shape (N_selected, 2)" + # check batch_head_idx is not None + assert batch_head_idx is not None, "batch_head_idx must not be None" + + # update kv cache + kv_cache = self._update_kv_cache(kv, inference_params) + # inference_params.seqlen_offset += 1 # if seqlen_offset is int + + batch = q.shape[0] + # kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + + # make sure seqlen_offset accounts for the current token + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + 1 # +1 for the current token + ) + + # need to reshape or view keys and value with shape (batch_size, seqlen, 1, n_heads, head_dim) + q = q.unsqueeze(2) + k_cache = kv_cache[:, :, 0].unsqueeze(2) + v_cache = kv_cache[:, :, 1].unsqueeze(2) + + # Call select_attn + context = select_attn( + q, + k_cache, + v_cache, + self.softmax_scale, + batch_head_idx, + cache_seqlens) + + # context: (batch_size, seqlen_q=1, G=1, H, head_dim) + # context = context.squeeze(2) # Remove G dimension + batch_size = batch_head_idx.shape[0] + context = context.view(batch_size, 1, self.num_heads, self.head_dim) + + return context + + def _update_kvcache_attention(self, q, kv, inference_params): + """Write kv to inference_params, then do attention""" + if ( + inference_params.seqlen_offset == 0 + or flash_attn_with_kvcache is None + or not self.use_flash_attn + ): + # TODO: this only uses seqlen_offset and not lengths_per_sample. + kv = self._update_kv_cache(kv, inference_params) + return self.inner_cross_attn(q, kv) + else: + batch = q.shape[0] + kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + # alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) + alibi_slopes = None + return flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + cache_seqlens=cache_seqlens, + softmax_scale=self.inner_attn.softmax_scale, + causal=self.inner_attn.causal, + alibi_slopes=alibi_slopes, + ) + + # dummy function for testing + def _select_attn(self, q, kv, inference_params, batch_head_idx): + """ + Apply select attention during generation stage. + + q: (batch_size, seqlen=1, n_heads, head_dim) + kv: (batch_size, seqlen=1, 2, n_heads, head_dim) + batch_head_idx: Tensor of indices specifying which batch and head indices to select. + Shape: (N_selected, 2) + + # currently only supports batches with same seqlen + # different seqlen requires a simple update in the select_attn kernel to load the seqlen, future work + """ + # check batch_head_idx shape + assert batch_head_idx.shape[1] == 2, "batch_head_idx must have shape (N_selected, 2)" + + # update kv cache + # kv_cache = self._update_kv_cache(kv, inference_params) + # inference_params.seqlen_offset += 1 # if seqlen_offset is int + + batch = q.shape[0] + kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + + # make sure seqlen_offset accounts for the current token + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset # +1 for the current token + ) + + # need to reshape or view keys and value with shape (batch_size, seqlen, 1, n_heads, head_dim) + q = q.unsqueeze(2) + k_cache = kv_cache[:, :, 0].unsqueeze(2) + v_cache = kv_cache[:, :, 1].unsqueeze(2) + + # Call select_attn + context = select_attn( + q, + k_cache, + v_cache, + self.softmax_scale, + batch_head_idx, + cache_seqlens) + + # context: (batch_size, seqlen_q=1, G=1, H, head_dim) + context = context.squeeze(2) # Remove G dimension + + return context + + +# SelectiveGQA: Future work + +class ParallelSelectMHA(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + process_group, + num_heads_kv=None, + qkv_proj_bias=True, + out_proj_bias=True, + dropout=0.0, + softmax_scale=None, + causal=True, + layer_idx=None, + dwconv=False, + rotary_emb_dim=0, + rotary_emb_base=10000.0, + rotary_emb_scale_base=None, + rotary_emb_interleaved=False, + use_alibi=False, + window_size=(-1, -1), + fused_bias_fc=True, + use_flash_attn=True, + return_residual=False, + checkpointing=False, + sequence_parallel=False, + device=None, + dtype=None, + ) -> None: + """ + num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.causal = causal + self.layer_idx = layer_idx + self.dwconv = dwconv + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.return_residual = return_residual + self.checkpointing = checkpointing + self.process_group = process_group + self.world_size = process_group.size() + self.local_rank = torch.distributed.get_rank(process_group) + + self.num_heads = num_heads + assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" + + self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads + assert ( + self.num_heads % self.num_heads_kv == 0 + ), "num_heads must be divisible by num_heads_kv" + + self.num_heads_per_rank = get_dim_for_local_rank( + self.num_heads, self.world_size, self.local_rank + ) + self.num_heads_kv_per_rank = get_dim_for_local_rank( + self.num_heads_kv, self.world_size, self.local_rank + ) + self.head_dim = self.embed_dim // num_heads + qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + + if use_alibi: + assert use_flash_attn, "ALiBi code path requires flash_attn" + num_heads_local = math.ceil(self.num_heads / self.world_size) + alibi_slopes = torch.tensor( + get_alibi_slopes(num_heads)[ + self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local + ], + device=device, + ) + else: + alibi_slopes = None + if window_size != (-1, -1): + assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn" + + if self.rotary_emb_dim > 0: + assert RotaryEmbedding is not None, "rotary_emb is not installed" + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, + base=rotary_emb_base, + scale_base=rotary_emb_scale_base, + interleaved=rotary_emb_interleaved, + device=device, + ) + + if ColumnParallelLinear is None or RowParallelLinear is None: + raise ImportError("fused_dense is not installed") + self.Wqkv = ColumnParallelLinear( + embed_dim, + qkv_dim, + process_group, + bias=qkv_proj_bias, + sequence_parallel=sequence_parallel, + multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2), + **factory_kwargs, + ) + + inner_attn_cls = ( + partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) + if use_flash_attn + else SelfAttention + ) + + self.inner_attn = inner_attn_cls( + causal=causal, + softmax_scale=softmax_scale, + attention_dropout=dropout, + ) + self.softmax_scale = softmax_scale + + # replace this with no reduce + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + process_group, + bias=out_proj_bias, + sequence_parallel=sequence_parallel, + multiple_of=self.head_dim, + **factory_kwargs, + ) + + self.mha_router = None + self.mlp_router = None + # We'll use an extra stream for concurrency + self.current_stream = None + self.sparse_stream = torch.cuda.Stream(device="cuda", priority=0) + self.main_stream = torch.cuda.Stream(device="cuda", priority=-5) + self.mha_router_event = torch.cuda.Event(enable_timing=False, blocking=False) + self.mlp_router_event = torch.cuda.Event(enable_timing=False, blocking=False) + self.main_event = torch.cuda.Event(enable_timing=False, blocking=False) + + # self.local_head_idx = generate_random_BH_index(1, self.num_heads_per_rank,self.num_heads_per_rank) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + dtype = self.out_proj.weight.dtype if dtype is None else dtype + device = self.out_proj.weight.device + return torch.empty( + batch_size, + max_seqlen, + 2, + self.num_heads_kv_per_rank, + self.head_dim, + dtype=dtype, + device=device, + ) + + def _update_kv_cache(self, kv, inference_params): + """Update kv cache in inference_params.""" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + return _update_kv_cache(kv, inference_params, self.layer_idx) + + def forward( + self, + x, + seqlen=None, + inference_params=None, + batch_head_idx=None, + **kwargs, + ): + """ + Arguments: + x: (batch, seqlen, hidden_dim) + batch_head_idx: Tensor of indices specifying which batch and head indices to select. + Shape: (N_selected,) + inference_params: for generation. + """ + + router_inputs = x.squeeze(1) + self.current_stream = torch.cuda.current_stream() + self.main_stream.wait_stream(self.current_stream ) + self.sparse_stream.wait_stream(self.current_stream ) + + is_decode = inference_params is not None and inference_params.seqlen_offset > 0 + + # if self.mha_router and is_decode: + # with torch.cuda.stream(self.sparse_stream): + # batch_head_idx = self.mha_router._select_heads(router_inputs) + # self.sparse_stream.record_event(self.mha_router_event) + + + qkv = self.Wqkv(x) + if seqlen is not None: + qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen) + + seqlen_offset = ( + 0 + if inference_params is None + else ( + inference_params.lengths_per_sample + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + ) + rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None + batch, seqlen = x.shape[:2] + + if self.num_heads_kv == self.num_heads: + # Self-attention, no MQA/GQA + qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) + + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + if inference_params is None or inference_params.seqlen_offset == 0: + # Inference stage without inference_params, prefill stage + if inference_params is not None: + # Update kv cache during prefill + kv = self._update_kv_cache(qkv[:, :, 1:], inference_params) + + context = self.inner_attn(qkv, **kwargs) + else: + # Generation stage + + # apply rotary embeddings + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + + # Apply select attention with kv cache update + context = self._update_kvcache_select_attn(qkv[:, :, 0], qkv[:, :, 1:], inference_params, batch_head_idx) + else: # cross-attention, MQA/GQA + raise NotImplementedError("SelectMHA currently supports only self-attention without MQA/GQA.") + + context = rearrange(context, "b s h d -> b s (h d)") + if seqlen is not None: + context = rearrange(context, "b s d -> (b s) d") + # out = self.out_proj(rearrange(context, "... h d -> ... (h d)")) + # out = self.out_proj(context) + + out = fused_dense_func(context, self.out_proj.weight, self.out_proj.bias) + + # if is_decode: + # if self.mlp_router: + # with torch.cuda.stream(self.sparse_stream): + # index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk) + # self.sparse_stream.record_event(self.mlp_router_event) + + # with torch.cuda.stream(self.main_stream): + # out = all_reduce(out, self.process_group) + # self.main_stream.record_event(self.main_event) + + # self.current_stream.wait_event(self.mlp_router_event) + # self.current_stream.wait_event(self.main_event) + + # # index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk = self.mlp_topk) + # # out = all_reduce(out, self.process_group) + + # return out, index_vec + # else: + # out = all_reduce(out, self.process_group) + out = all_reduce(out, self.process_group) + return out if not self.return_residual else (out, x) + # return out + + def _update_kvcache_select_attn(self, q, kv, inference_params, batch_head_idx = None): + """ + Apply select attention during generation stage. + + q: (batch_size, seqlen=1, n_heads, head_dim) + kv: (batch_size, seqlen=1, 2, n_heads, head_dim) + batch_head_idx: Tensor of indices specifying which batch and head indices to select. + Shape: (batch_size, top_k) + """ + # check batch_head_idx shape + # assert batch_head_idx.shape[1] == 2, "batch_head_idx must have shape (N_selected, 2)" + + # if batch_head_idx is None: + # batch_head_idx = self.local_head_idx + # print("Using local_head_idx, router not used.") + # batch_head_idx = self.local_head_idx + # print("Using local_head_idx, router not used.") + + # update kv cache + kv_cache = self._update_kv_cache(kv, inference_params) + + batch = q.shape[0] + # kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + 1 # +1 for the current token + ) + # need to reshape or view keys and value with shape (batch_size, seqlen, 1, n_heads, head_dim) + q = q.unsqueeze(2) + k_cache = kv_cache[:, :, 0].unsqueeze(2) + v_cache = kv_cache[:, :, 1].unsqueeze(2) + + self.current_stream.wait_event(self.mha_router_event) + + assert batch_head_idx is not None, "batch_head_idx must not be None" + # Call select_attn + context = select_attn( + q, + k_cache, + v_cache, + self.softmax_scale, + batch_head_idx, + cache_seqlens + ) + + # context: (batch_size, seqlen_q=1, G=1, H, head_dim) + # context = context.squeeze(2) # Remove G dimension + context = context.view(batch, 1, self.num_heads_kv_per_rank, self.head_dim) + return context + +''' +PYTHONWARNINGS="ignore" python -m HybridTensor.modules.SelectiveMHA --batch_size 8 --in_features 8192 --seq_len 512 --head_density 0.25 +''' + + +if __name__ == "__main__": + args = arg_parser() + + max_seqlen = args.seq_len + 128 + max_batch_size = args.batch_size + device = torch.device(f"cuda:{args.device}") + + # simulates SelectiveMHA inference generation stage + inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size) + nheads = args.in_features // 128 + softmax_scale = 1 / (128 ** 0.5) + rotary_emb_dim = 0 + + # build SelectiveMHA + select_mha = SelectMHA( + embed_dim=args.in_features, + num_heads=nheads, + num_heads_kv=None, + causal=True, + layer_idx=0, + use_flash_attn=True, + softmax_scale=softmax_scale, + return_residual=False, + rotary_emb_dim=rotary_emb_dim, + device=device, + dtype=torch.float16, + ) + + standard_mha = SMHA( + embed_dim=args.in_features, + num_heads=nheads, + num_heads_kv=None, + causal=True, + layer_idx=0, + use_flash_attn=True, + softmax_scale=softmax_scale, + return_residual=False, + rotary_emb_dim=rotary_emb_dim, + device=device, + dtype=torch.float16, + ) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + with torch.no_grad(): + # prefill stage to generate kv cache for all batches + og_x = torch.randn(args.batch_size, args.seq_len, args.in_features, device=device, dtype=torch.float16, requires_grad=False) + + # out, time_ms = cuda_profiler(select_mha, og_x, inference_params=inference_params) + # print(f"MHA Prefill time: {time_ms:.3f} ms") + # out = select_mha(og_x, inference_params=inference_params) + + # simulate kv cache, bug in flash_attn for larger batches + kv = torch.randn(args.batch_size, args.seq_len, 2, nheads, 128, device=device, dtype=torch.float16, requires_grad=False) + _ = _update_kv_cache(kv, inference_params, 0) + + # increment the sequence length to move to the generation stage + inference_params.seqlen_offset += args.seq_len + + input_x = torch.randn(args.batch_size, 1, args.in_features, device=device, dtype=torch.float16, requires_grad=False) + selected_heads = math.ceil(nheads * args.head_density) + + # generate batch_head_idx for SelectiveMHA + # batch_head_index = generate_BH_index(args.batch_size, nheads, selected_heads, device=device) + batch_head_index = generate_random_BH_index(args.batch_size, nheads, selected_heads, device=device) + # generatation stage Standard MHA + out, standard_time_ms = cuda_profiler(standard_mha, input_x, inference_params=inference_params) + print(f"Standard MHA time: {standard_time_ms:.3f} ms") + + # generatation stage SelectiveMHA + out, select_time_ms = cuda_profiler(select_mha, input_x, inference_params=inference_params, batch_head_idx=batch_head_index) + print(f"SelectMHA time: {select_time_ms:.3f} ms") + + speedup = standard_time_ms / select_time_ms + print(f"Speedup: {speedup:.3f}") + diff --git a/HybridTensor/modules/SelectiveMLP.py b/HybridTensor/modules/SelectiveMLP.py new file mode 100644 index 0000000000000000000000000000000000000000..e54573f16fb9183ee2a8cf0fe32d3813710ee958 --- /dev/null +++ b/HybridTensor/modules/SelectiveMLP.py @@ -0,0 +1,580 @@ +# python -m HybridTensor.modules.SelectiveMLP --batch_size 8 --index_size 512 +from typing import Optional +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +import torch.distributed as dist + +# import fused_dense_cuda # from apex + +import fused_dense_lib as fused_dense_cuda + +from flash_attn.utils.distributed import reduce_scatter, all_reduce +from einops import rearrange + +# from HybridTensor.modules.MLP import SelectiveMLPFunc +from HybridTensor.modules.references.fused_dense import ColumnParallelLinear, RowParallelLinear, fused_mlp_func +from HybridTensor.modules.references.MLP import SelectiveMLPTriton +from HybridTensor.utils.utils import arg_parser, sparse_index +from HybridTensor.utils.profiling import cuda_profiler + +# compiles the kernels for the first time, takes time +from HybridTensor.triton.gather_gemm_col import gather_matmul_col +from HybridTensor.triton.gather_gemm_row import gather_matmul_row + +# needs to be compiled before running +from HybridTensor.triton.heuristics.gather_gemm_col_h import gather_matmul_col as gather_matmul_col_h +from HybridTensor.triton.heuristics.gather_gemm_row_h import gather_matmul_row as gather_matmul_row_h + +# from HybridTensor.triton.cg_safe.gather_gemm_col_cg import gather_matmul_col +# from HybridTensor.triton.cg_safe.gather_gemm_row_cg import gather_matmul_row + + +def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, bias1 = None, bias2 = None, activation='relu', use_heuristic=True): + if use_heuristic: + out = gather_matmul_col_h(x, fc1_w, index_vec, bias = bias1, activations=activation) + out = gather_matmul_row_h(out, fc2_w, index_vec, bias = bias2) + else: + out = gather_matmul_col(x, fc1_w, index_vec, bias = bias1, activations=activation) + out = gather_matmul_row(out, fc2_w, index_vec, bias = bias2) + return out + + +# cg safe version +# def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, index_size, bias1 = None, bias2 = None, activation='relu', use_heuristic=True): +# out = gather_matmul_col(x, fc1_w, index_vec, index_size, bias = bias1, activations=activation) +# out = gather_matmul_row(out, fc2_w, index_vec, index_size, bias = bias2) +# return out + +class MLPRouter(nn.Module): + def __init__(self, embed_dim, low_rank_dim, out_dim, act_th, device=None, dtype=None): + """ + Initializes the MHARouter class. + + Args: + embed_dim (int): Dimensionality of the input embeddings. + low_rank_dim (int): Dimensionality of the intermediate layer. + out_dim (int): Number of neurons. + """ + super(MLPRouter, self).__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.fc1 = nn.Linear(embed_dim, low_rank_dim, bias=False, **factory_kwargs) + self.fc2 = nn.Linear(low_rank_dim, out_dim, bias=False, **factory_kwargs) + self.act_th = act_th + self.num_neurons = out_dim + self.largest = self.num_neurons + 1 + + def forward(self, x): + """ + Forward pass of the MHARouter. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, embed_dim). + + Returns: + torch.Tensor: Output tensor of shape (batch_size, num_heads). + """ + x = self.fc1(x) + x = self.fc2(x) + return x + + def _select_neurons_topk(self, x, topk=None): + neurons = self.forward(x) + + neurons_nonzero = torch.nn.ReLU()(neurons) + _, index_vec = neurons_nonzero.sum(dim=0).topk(topk, dim=0, sorted=False) + # index_vec, _ = index_vec.sort() + return index_vec + + def _select_neurons(self, x, th=None): + ''' + Threshold based selection of neurons, not CG safe + ''' + if th is None: + th = self.act_th + + neurons = self.forward(x) + activated = (neurons > th).sum(dim=0) + index_vec = activated.nonzero().flatten() + return index_vec + + def _select_neurons_cuda_safe(self, x, th=None): + ''' + This function is used with threshold and is used for CG safe version of the code + ''' + if th is None: + th = self.act_th + neurons = self.forward(x) + activated = (neurons > th).sum(dim=0) + + indices = torch.arange(self.num_neurons, device=activated.device) + selected = torch.where(activated > th, indices, torch.full_like(indices, self.largest)) + + index_vec, _ = torch.sort(selected) + index_size = ((index_vec < self.largest).sum()).to(torch.int32) + + return index_size, index_vec + + + +class ParallelMLPRouter(nn.Module): + """ + Parallel Sparse Predictor for MHA layer. + """ + + def __init__( + self, + embed_dim, + low_rank_dim, + out_dim, + act_th, + process_group, + sequence_parallel=False, + device=None, + dtype=None, + ): + """ + Initializes the ParallelMHARouter class. + + Args: + embed_dim (int): Dimensionality of the input embeddings. + low_rank_dim (int): Dimensionality of the intermediate layer. + out_dim (int): Output dimensionality (typically number of neurons). + process_group (torch.distributed.ProcessGroup): Process group for parallelism. + sequence_parallel (bool, optional): Whether to use sequence parallelism. Defaults to False. + device (torch.device, optional): Device to run the module on. Defaults to None. + dtype (torch.dtype, optional): Data type of the module parameters. Defaults to None. + """ + super(ParallelMLPRouter, self).__init__() + assert process_group is not None, "ParallelMHARouter requires a process group." + + factory_kwargs = {"device": device, "dtype": dtype} + self.process_group = process_group + self.embed_dim = embed_dim + self.act_th = act_th + + self.fc1 = nn.Linear( + embed_dim, low_rank_dim, bias=False, **factory_kwargs + ) + self.fc2 = ColumnParallelLinear( + low_rank_dim, + out_dim, + process_group, + bias=False, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + + # def _select_neurons(self, neurons, th=None): + # if th is None: + # th = self.act_th + # activated = (neurons > th).sum(dim=0) + # index_vec = activated.nonzero().flatten() + # return index_vec + + def forward(self, x): + """ + Forward pass of the ParallelMHARouter. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, embed_dim). + + Returns: + torch.Tensor: Output tensor of shape (batch_size, seq_len, out_dim). + """ + x = self.fc1(x) + x = self.fc2(x) + return x + + def _select_neurons(self, x, th=None): + if th is None: + th = self.act_th + + neurons = self.forward(x) + activated = (neurons > th).sum(dim=0) + index_vec = activated.nonzero().flatten() + return index_vec + + def _select_neurons_topk(self, x, topk=None): + neurons = self.forward(x) + + neurons_nonzero = torch.nn.ReLU()(neurons) #.squeeze(1) + # print(f"neurons_nonzero shape: {neurons_nonzero.shape}") + # print(f"Top k neurons: {topk}") + _, index_vec = neurons_nonzero.sum(dim=0).topk(topk, dim=0, sorted=False) + # index_vec, _ = index_vec.sort() + return index_vec + +class SelectiveMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation='relu', + layer_idx=None, + bias1=True, + bias2=True, + return_residual=False, + checkpoint_lvl=0, + use_heuristic=True, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features if out_features is not None else in_features + hidden_features = hidden_features if hidden_features is not None else in_features * 4 + self.return_residual = return_residual + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) + self.activation = activation + self.activation_fn = nn.ReLU() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) + # self.fc2_weight_t = self.fc2.weight.t().contiguous() + self.fc2_weight_t = None + self.use_heuristic = use_heuristic + + def _init_weights(self): + # if weights are updated, we need to update the transpose + self.fc2_weight_t = self.fc2.weight.t().contiguous() + + def forward(self, x, index_vec=None, index_size=None): + + if index_vec is not None: + # sparse forward, + + # update on first run + if self.fc2_weight_t is None: + self.fc2_weight_t = self.fc2.weight.t().contiguous() + + # Remove the original parameter to free memory. + self.fc2.weight = None + del self.fc2._parameters['weight'] + + x = x.view(-1, x.size(-1)) + # x = x.squeeze(1) + y = SelectiveMLPFunc(x = x, fc1_w = self.fc1.weight, + fc2_w = self.fc2_weight_t, index_vec = index_vec, + bias1 = self.fc1.bias, bias2 = self.fc2.bias, + activation=self.activation, use_heuristic=self.use_heuristic) + + else: + # dense forward + + y = self.fc1(x) + y = self.activation_fn(y) + + if self.fc2_weight_t is not None: + y = torch.matmul(y, self.fc2_weight_t) + else: + y = self.fc2(y) + + return y if not self.return_residual else (y, x) + +class ParallelSelectiveMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features, + out_features=None, + activation="relu", + layer_idx=None, + process_group: ProcessGroup = None, + bias1=True, + bias2=True, + return_residual=False, + sequence_parallel=False, + use_heuristic=True, + checkpoint_lvl=0, + heuristic="auto", + device=None, + dtype=None, + ): + """ + process_group is required. We're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul, gelu, then matmul. + Finally we do a reduce_scatter of the output. + + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute pre_act and gelu_out in the bwd + heuristic: + -1: don't fuse gemm + gelu (separate kernel) + 0..4: use this heuristic for the algo section in the fused gemm + gelu + 'auto': heuristic will be picked automatically: + For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. + """ + assert checkpoint_lvl in [0, 1, 2] + assert activation in ["gelu_approx", "relu"] + assert process_group is not None + # assert sp_kwargs != None, "sparse predictor parameters are not passed in." + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if out_features is None: + out_features = in_features + self.activation = activation + self.process_group = process_group + self.sequence_parallel = sequence_parallel + self.checkpoint_lvl = checkpoint_lvl + self.heuristic = heuristic + self.fc1 = ColumnParallelLinear( + in_features, hidden_features, process_group, bias=bias1, **factory_kwargs + ) + self.fc2 = RowParallelLinear( + hidden_features, out_features, process_group, bias=bias2, **factory_kwargs + ) + self.layer_idx = layer_idx + + self.fc2_weight_t = self.register_buffer("fc2_weigth_t", None) + self.return_residual = return_residual + self.fc2_weight_t = None + self.use_heuristic = use_heuristic + self.reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + + # self._init_weights() + + def _init_weights(self): + # ffn2 weights needs to be in row major format to select from rows + self.fc2_weight_t = self.fc2.weight.t().contiguous() + + def forward(self, x, residual = None, index_vec = None): + + # do_token_generation = x.size(1) == 1 + # index_vec = None + # with torch.cuda.stream(self.curr_stream): + if index_vec is not None: + # assert x.size(1) == 1 + if self.fc2_weight_t is None: + self.fc2_weight_t = self.fc2.weight.t().contiguous() + + x = x.view(-1, x.size(-1)) + # x = rearrange(x, "b 1 d -> b d") # slightly more expensive to use rearrange + + out = SelectiveMLPFunc(x = x, fc1_w = self.fc1.weight, + fc2_w = self.fc2_weight_t, index_vec = index_vec, + bias1 = self.fc1.bias, bias2 = self.fc2.bias, + activation=self.activation, use_heuristic=self.use_heuristic) + # out = rearrange(out, "b d -> b 1 d") + # out = out.view(-1, 1, out.size(-1)) + + else: # normal mlp + if self.heuristic == "auto": + dtype = ( + x.dtype + if not torch.is_autocast_enabled() + else torch.get_autocast_gpu_dtype() + ) + if self.activation == "gelu_approx": + cuda_ver = tuple(map(int, torch.version.cuda.split("."))) + heuristic = ( + 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) + ) + else: + heuristic = 0 + else: + heuristic = self.heuristic + out = fused_mlp_func( + x, + self.fc1.weight, + self.fc2.weight, + self.fc1.bias, + self.fc2.bias, + activation=self.activation, + save_pre_act=self.training, + checkpoint_lvl=self.checkpoint_lvl, + heuristic=heuristic, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + ) + + if self.process_group.size() > 1: + # out = self.reduce_fn(out, self.process_group) # has some overhead, + dist.all_reduce(out, op=dist.ReduceOp.SUM, group=self.process_group) + + return out if not self.return_residual else (out, x) + # return out + + def sp_forward(self, x, residual = None, index_vec = None): + if self.heuristic == "auto": + dtype = ( + x.dtype + if not torch.is_autocast_enabled() + else torch.get_autocast_gpu_dtype() + ) + if self.activation == "gelu_approx": + cuda_ver = tuple(map(int, torch.version.cuda.split("."))) + heuristic = ( + 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) + ) + else: + heuristic = 0 + else: + heuristic = self.heuristic + curr_stream = torch.cuda.current_stream() + do_token_generation = x.size(1) == 1 + # mlp_logit = None + + # with torch.cuda.stream(self.curr_stream): + if index_vec != None: + assert x.size(1) == 1 + + if self.fc2_weight_t is None: + self.fc2_weight_t = self.fc2.weight.t().contiguous() + + out = SelectiveMLPFunc( + rearrange(x, "b 1 d -> b d"), + self.fc1.weight, + self.fc2_weight_t, + index_vec, + self.fc1.bias, + self.fc2.bias, + activation=self.activation, + ) + out = rearrange(out, "b d -> b 1 d") + else: + out = fused_mlp_func( + x, + self.fc1.weight, + self.fc2.weight, + self.fc1.bias, + self.fc2.bias, + activation=self.activation, + save_pre_act=self.training, + checkpoint_lvl=self.checkpoint_lvl, + heuristic=heuristic, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + ) + + + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + if self.sp_router: + curr_stream.record_event(self.event_mlp) + + # handle = torch.distributed.all_reduce(out, op=torch.distributed.ReduceOp.SUM, group=self.process_group, async_op=True) + out = reduce_fn(out, self.process_group) + + + if self.sp_router: + with torch.cuda.stream(self.sp_stream): + self.sp_stream.wait_event(self.event_mlp) + if do_token_generation: + mlp_logit = self.sp(rearrange(residual, "b 1 d -> b d")) + self.sp_stream.record_event(self.event_mlp_sp) + + # check this again, we might not have to synchronize here, we can synchronize in the next layer + curr_stream.wait_event(self.event_mlp_sp) + + return out + +class SimpleMLP(nn.Module): + def __init__(self, in_features, hidden_features, out_features, bias=False, activation="relu"): + super().__init__() + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.activation = activation + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return x + +if __name__ == "__main__": + args = arg_parser() + + bias = True if args.bias > 0 else False + x = torch.randn(args.batch_size, args.in_features, device="cuda", dtype=torch.float16) + index_vec, _ = sparse_index(args.index_size, args.in_features*4) + + ''' + selective_mlp = SelectiveMLPTriton(args.in_features, args.hidden_features, bias=bias, device="cuda", dtype=torch.float16, activation="relu") + + out, mlp_time = cuda_profiler(selective_mlp, x, index_vec) + + out_col, col_time = cuda_profiler(gather_matmul_col, x, selective_mlp.fc1_w, index_vec, activations=selective_mlp.activation) + out_row, row_time = cuda_profiler(gather_matmul_row, out_col, selective_mlp.fc2_w, index_vec) + sum_time = col_time + row_time + + print(f"Index size {args.index_size}, Activated {args.index_size/(args.in_features * 4)*100}% neurons") + + print(f"Gather Col Time: {col_time} ms") + print(f"Gather Row Time: {row_time} ms") + # print(f"Sum Time: {sum_time} ms") + + print(f"SelectiveMLP Time: {mlp_time} ms") + ''' + + in_features = args.in_features + hidden_features = in_features * 4 + out_features = in_features + device = torch.device("cuda") + + model = SelectiveMLP( + in_features, hidden_features, out_features, device=device, dtype=torch.float16, activation="relu", use_heuristic=True + ).to(device) + + router = MLPRouter(in_features, 1024, hidden_features, act_th = 0.5, device=device, dtype=torch.float16).to(device) + + # Warm-up GPU + def warmup(): + for _ in range(10): + _ = model(x, index_vec) + _ = model(x, None) + _ = router._select_neurons_topk(x, args.index_size) + + warmup() + + # Measure SelectiveMLPFunc speed + _, router_time = cuda_profiler(router._select_neurons_topk, x, args.index_size) + _, selective_time = cuda_profiler(model, x, index_vec) + # Measure dense forward speed + _, dense_time = cuda_profiler(model, x, None) + + print(f"Router time per run: {router_time:.6f} ms") + print(f"SelectiveMLPFunc time per run: {selective_time:.6f} ms") + print(f"Dense forward time per run: {dense_time:.6f} ms") + print(f"Speedup: {dense_time / selective_time:.2f}x") + router_selective_time = router_time + selective_time + print(f"Router + SelectiveMLPFunc time per run: {router_selective_time:.6f} ms") + print(f"Speedup: {dense_time / router_selective_time:.2f}x") + ############################################ + # CUDA Graph capture tests for the MLP model + ############################################ + print("\n=== CUDA Graph Tests ===") + # --- Selective forward (sparse mode) --- + print("Testing CUDA Graph for Selective forward (with index_vec)...") + static_x = x.clone() + static_index_vec = index_vec.clone() + # Warm-up run to allocate memory + static_out_sel = model(static_x, index_vec=static_index_vec) + torch.cuda.synchronize() + + # Capture on a non-default stream + capture_stream = torch.cuda.Stream() + with torch.cuda.stream(capture_stream): + g_sel = torch.cuda.CUDAGraph() + g_sel.capture_begin() + static_out_sel = model(static_x, index_vec=static_index_vec) + g_sel.capture_end() + torch.cuda.synchronize() + + # Replay and check accuracy + g_sel.replay() + torch.cuda.synchronize() + cuda_sel_out = static_out_sel.clone() + regular_sel_out = model(x, index_vec=index_vec) + if torch.allclose(cuda_sel_out, regular_sel_out, atol=1e-3): + print("Selective forward CUDA Graph output matches regular output") + else: + print("Selective forward CUDA Graph output does NOT match regular output") + + def replay_sel(): + g_sel.replay() + _, selective_time_cuda = cuda_profiler(replay_sel) + print(f"Selective forward CUDA Graph time per run: {selective_time_cuda:.6f} ms") \ No newline at end of file diff --git a/HybridTensor/modules/SelectiveRouters.py b/HybridTensor/modules/SelectiveRouters.py new file mode 100644 index 0000000000000000000000000000000000000000..4f3c9d4aaedfdcd2984d8d6069c09789812ee181 --- /dev/null +++ b/HybridTensor/modules/SelectiveRouters.py @@ -0,0 +1,136 @@ +import os +import re +import torch +from collections import OrderedDict + +def create_mlp_router_state_dict(router_files_dir): + """ + Loads all mlp_router weight files from the specified directory and creates a router_state_dict + with keys formatted as 'transformer.layers.{layer_num}.mlp_router.{param_name}'. + + Args: + router_files_dir (str): Path to the directory containing mlp_router_*.pt files. + + Returns: + OrderedDict: A state dictionary suitable for loading into a transformer model. + """ + # Regular expression to extract layer number from filename + router_file_pattern = re.compile(r'mlp_router_(\d+)-[\d.]+-[\d.]+-[\d.]+\.pt$') + + router_state_dict = OrderedDict() + + # List all files in the directory + try: + all_files = os.listdir(router_files_dir) + except FileNotFoundError: + print(f"Error: Directory '{router_files_dir}' does not exist.") + return None + + # Filter files matching the pattern + router_files = [f for f in all_files if router_file_pattern.match(f)] + + if not router_files: + print(f"No router files found in directory '{router_files_dir}'.") + return None + + for file_name in sorted(router_files, key=lambda x: int(router_file_pattern.match(x).group(1))): + match = router_file_pattern.match(file_name) + if not match: + print(f"Skipping file '{file_name}' as it does not match the pattern.") + continue + + layer_num = int(match.group(1)) + file_path = os.path.join(router_files_dir, file_name) + + try: + # Load the router's state dict + router_weights = torch.load(file_path, map_location='cpu') + if not isinstance(router_weights, dict): + print(f"Warning: The file '{file_path}' does not contain a state dictionary. Skipping.") + continue + except Exception as e: + print(f"Error loading '{file_path}': {e}") + continue + + # Iterate through each parameter in the router's state dict + for param_name, param_tensor in router_weights.items(): + # Construct the new key + new_key = f"transformer.layers.{layer_num}.mlp_router.{param_name}" + router_state_dict[new_key] = param_tensor + + # print(f"Loaded router for layer {layer_num} from '{file_name}'.") + + print(f"Total routers loaded: {len(router_state_dict) // 2}") # Assuming 4 params per router (weight & bias for 2 layers) + return router_state_dict + + +def create_attn_router_state_dict(router_files_dir): + """ + Loads all attn_router weight files from the specified directory and creates a router_state_dict + with keys formatted as 'transformer.layers.{layer_num}.mha_router.{param_name}'. + + Args: + router_files_dir (str): Path to the directory containing attn_router_*.pt files. + + Returns: + OrderedDict: A state dictionary suitable for loading into a transformer model. + """ + # Regular expression to extract layer number from filename + # Pattern: attn_router_{layer_num}-{value1}-{value2}.pt + router_file_pattern = re.compile(r'attn_router_(\d+)-[\d.]+-[\d.]+\.pt$') + + router_state_dict = OrderedDict() + + # List all files in the directory + try: + all_files = os.listdir(router_files_dir) + except FileNotFoundError: + print(f"Error: Directory '{router_files_dir}' does not exist.") + return None + + # Filter files matching the pattern + router_files = [f for f in all_files if router_file_pattern.match(f)] + + if not router_files: + print(f"No attn_router files found in directory '{router_files_dir}'.") + return None + + # To handle potential duplicates, keep track of loaded layer numbers + loaded_layers = set() + + for file_name in sorted(router_files, key=lambda x: int(router_file_pattern.match(x).group(1))): + match = router_file_pattern.match(file_name) + if not match: + print(f"Skipping file '{file_name}' as it does not match the pattern.") + continue + + layer_num = int(match.group(1)) + if layer_num in loaded_layers: + print(f"Warning: Multiple router files found for layer {layer_num}. Skipping '{file_name}'.") + continue # Skip duplicate layers + + file_path = os.path.join(router_files_dir, file_name) + + try: + # Load the router's state dict + router_weights = torch.load(file_path, map_location='cpu') + if not isinstance(router_weights, dict): + print(f"Warning: The file '{file_path}' does not contain a state dictionary. Skipping.") + continue + except Exception as e: + print(f"Error loading '{file_path}': {e}") + continue + + # Iterate through each parameter in the router's state dict + for param_name, param_tensor in router_weights.items(): + # Construct the new key + new_key = f"transformer.layers.{layer_num}.mha_router.{param_name}" + router_state_dict[new_key] = param_tensor + + loaded_layers.add(layer_num) + # print(f"Loaded MHA router for layer {layer_num} from '{file_name}'.") + + print(f"Total MHA routers loaded: {len(loaded_layers)}") + return router_state_dict + + diff --git a/HybridTensor/modules/__init__.py b/HybridTensor/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/HybridTensor/modules/__pycache__/MLP.cpython-39.pyc b/HybridTensor/modules/__pycache__/MLP.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b394bcbe041d32e0fb8b4164a21d8d3baecc0a2f Binary files /dev/null and b/HybridTensor/modules/__pycache__/MLP.cpython-39.pyc differ diff --git a/HybridTensor/modules/__pycache__/ParallelMLP.cpython-39.pyc b/HybridTensor/modules/__pycache__/ParallelMLP.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..012d570c0a823b7541707e80762fbec89845c144 Binary files /dev/null and b/HybridTensor/modules/__pycache__/ParallelMLP.cpython-39.pyc differ diff --git a/HybridTensor/modules/__pycache__/SelectiveBlock.cpython-39.pyc b/HybridTensor/modules/__pycache__/SelectiveBlock.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6306ddb338be53332d1cf1cf5c8cf13e202efaa0 Binary files /dev/null and b/HybridTensor/modules/__pycache__/SelectiveBlock.cpython-39.pyc differ diff --git a/HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-310.pyc b/HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff84fbdd5e51c53b9728de3d0cf18bd00c752e7c Binary files /dev/null and b/HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-310.pyc differ diff --git a/HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-39.pyc b/HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78171b89d0ced2ab94b2e7e9958a119dd7935403 Binary files /dev/null and b/HybridTensor/modules/__pycache__/SelectiveBlock_v1.cpython-39.pyc differ diff --git a/HybridTensor/modules/__pycache__/SelectiveMHA.cpython-310.pyc b/HybridTensor/modules/__pycache__/SelectiveMHA.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32ee46f1096cd47cc0566b058c9bdf078a3174be Binary files /dev/null and b/HybridTensor/modules/__pycache__/SelectiveMHA.cpython-310.pyc differ diff --git a/HybridTensor/modules/__pycache__/SelectiveMHA.cpython-39.pyc b/HybridTensor/modules/__pycache__/SelectiveMHA.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed69b65d486fb9e2f86d394bef31bee45a31a47d Binary files /dev/null and b/HybridTensor/modules/__pycache__/SelectiveMHA.cpython-39.pyc differ diff --git a/HybridTensor/modules/__pycache__/SelectiveMLP.cpython-310.pyc b/HybridTensor/modules/__pycache__/SelectiveMLP.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a137d1981a4c4f6f01bea2073b33b6d2fb0ad862 Binary files /dev/null and b/HybridTensor/modules/__pycache__/SelectiveMLP.cpython-310.pyc differ diff --git a/HybridTensor/modules/__pycache__/SelectiveMLP.cpython-39.pyc b/HybridTensor/modules/__pycache__/SelectiveMLP.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f687bda6df48e12e30435f6e75f3c1ad7e76e4b Binary files /dev/null and b/HybridTensor/modules/__pycache__/SelectiveMLP.cpython-39.pyc differ diff --git a/HybridTensor/modules/__pycache__/SelectiveRouters.cpython-310.pyc b/HybridTensor/modules/__pycache__/SelectiveRouters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e479eff939081864e4cf096ec68d24fe952aea7 Binary files /dev/null and b/HybridTensor/modules/__pycache__/SelectiveRouters.cpython-310.pyc differ diff --git a/HybridTensor/modules/__pycache__/SelectiveRouters.cpython-39.pyc b/HybridTensor/modules/__pycache__/SelectiveRouters.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77e5f5c7afe64db844c128fc4421c58f3addaccd Binary files /dev/null and b/HybridTensor/modules/__pycache__/SelectiveRouters.cpython-39.pyc differ diff --git a/HybridTensor/modules/__pycache__/__init__.cpython-310.pyc b/HybridTensor/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6621013bb1717391971bf33ed18717738b91e01 Binary files /dev/null and b/HybridTensor/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/HybridTensor/modules/__pycache__/__init__.cpython-39.pyc b/HybridTensor/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..716149c7226b39190830b7d4a89c236f9e2a7e0b Binary files /dev/null and b/HybridTensor/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/HybridTensor/modules/__pycache__/fused_dense.cpython-39.pyc b/HybridTensor/modules/__pycache__/fused_dense.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6f47cc5449d452ba3980338252de3e003eaee85 Binary files /dev/null and b/HybridTensor/modules/__pycache__/fused_dense.cpython-39.pyc differ diff --git a/HybridTensor/modules/references/MLP.py b/HybridTensor/modules/references/MLP.py new file mode 100644 index 0000000000000000000000000000000000000000..27669e55f99209b321a336c0dbac83e1b18ec9a1 --- /dev/null +++ b/HybridTensor/modules/references/MLP.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn +try: + import AmpSelGemm +except ImportError: + AmpSelGemm = None + +from HybridTensor.triton.gather_gemm_col import gather_matmul_col +from HybridTensor.triton.gather_gemm_row import gather_matmul_row +from HybridTensor.utils.utils import arg_parser, sparse_index, create_results_directory +from HybridTensor.utils.profiling import benchmark_mlp_fwd, generate_index_sizes, save_results_to_csv, plot_results + +import pandas as pd +import matplotlib.pyplot as plt + +from tqdm import tqdm # For progress bars + +# implement standard MLP block with ReLU activation +class StandardMLPBlock(nn.Module): + def __init__(self, in_features, hidden_features=None, bias=False, device='cuda'): + super(StandardMLPBlock, self).__init__() + hidden_features = hidden_features or in_features*4 + + # this is stored in correct order; don't need to transpose for the CUTLASS kernel + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, device=device, dtype=torch.float16) + + # this is stored in row major; need to transpose for the CUTLASS kernel + self.fc2 = nn.Linear(hidden_features, in_features, bias=bias, device=device, dtype=torch.float16) + self.relu = nn.ReLU() + + def forward(self, x): + # fc1 : d x (4d) + out = self.fc1(x) + # B x (4d) + out = self.relu(out) + + #fc2: (4d) x d + out = self.fc2(out) + return out + + def sparsify(self, zero_index): + self.fc1.weight.data[zero_index, :] = 0.0 + self.fc2.weight.data[:, zero_index] = 0.0 + +class SelectiveMLP(nn.Module): + def __init__(self, in_features, hidden_features=None, bias=False, requires_grad = False, device='cuda', dtype=torch.float16, activation='relu'): + super(SelectiveMLP, self).__init__() + if hidden_features is None: + hidden_features = in_features*4 + + factory_kwargs = {'device': torch.device(device), 'dtype': dtype} + self.fc1_w = torch.empty((hidden_features, in_features), requires_grad=requires_grad, **factory_kwargs) + self.fc2_w = torch.empty((hidden_features, in_features), requires_grad=requires_grad, **factory_kwargs) + self.act = nn.ReLU() + + def forward(self, x, index_vec): + index_size = index_vec.size(0) + #AmpSelGemm.run(A=A, B=B_col_major, index_vec= index_vec, M= M, N=index_size, K=K, index_size=index_size) + out = AmpSelGemm.run_col(A=x, B=self.fc1_w, index_vec= index_vec, M = x.size(0), N = index_size, K = self.fc1_w.size(1), index_size=index_size) + out = self.act(out) # need to fuse this with fc1 in the next iteration + out = AmpSelGemm.run_row1(A=out, B=self.fc2_w, index_vec= index_vec, M = x.size(0), N = self.fc2_w.size(1), K = index_size, index_size=index_size) + return out + + def load_from_MLP(self, mlp): + self.fc1_w = mlp.fc1.weight + self.fc2_w = mlp.fc2.weight.t().contiguous() + return self + +class SelectiveMLPTriton(SelectiveMLP): + def __init__(self, in_features, hidden_features=None, bias=False, requires_grad = False, device='cuda', dtype=torch.float16, activation='relu'): + super(SelectiveMLPTriton, self).__init__(in_features, hidden_features, bias, requires_grad, device, dtype, activation) + self.activation = activation + + def forward(self, x, index_vec): + out = gather_matmul_col(x, self.fc1_w, index_vec, activations=self.activation) + out = gather_matmul_row(out, self.fc2_w, index_vec) + return out + +def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, bias1 = None, bias2 = None, activation='relu'): + out = gather_matmul_col(x, fc1_w, index_vec, bias = bias1, activations=activation) + out = gather_matmul_row(out, fc2_w, index_vec, bias = bias2) + return out + +def profile_mlps(args, index_size): + # Create standard MLP block + standardmlp = StandardMLPBlock(args.in_features, args.hidden_features) + if AmpSelGemm is not None: + selectiveMLP = SelectiveMLP(args.in_features, args.hidden_features).load_from_MLP(standardmlp) + selectiveMLPTriton = SelectiveMLPTriton(args.in_features, args.hidden_features).load_from_MLP(standardmlp) + + # Test input + x = torch.randn(args.batch_size, args.in_features, dtype=torch.float16, device='cuda') + index_vec, zero_index = sparse_index(index_size, args.hidden_features or args.in_features*4) + standardmlp.sparsify(zero_index) + + # Measure execution time + std_out, standardmlp_time = benchmark_mlp_fwd(x, standardmlp, index_vec=None, iterations=args.iterations, print_result=False) + if AmpSelGemm is not None: + cutlass_out, selectiveMLP_time = benchmark_mlp_fwd(x, selectiveMLP, index_vec=index_vec, iterations=args.iterations, print_result=False) + triton_out, selectiveMLPTriton_time = benchmark_mlp_fwd(x, selectiveMLPTriton, index_vec=index_vec, iterations=args.iterations, print_result=False) + + # Optionally check results + if args.check_results: + print('Standard MLP output:', std_out[0]) + if AmpSelGemm is not None: + print('Selective MLP Cutlass output:', cutlass_out) + print('Selective MLP Triton output:', triton_out) + + # Calculate speedups + triton_speedup = standardmlp_time / selectiveMLPTriton_time if selectiveMLPTriton_time > 0 else float('inf') + if AmpSelGemm is not None: + cutlass_speedup = standardmlp_time / selectiveMLP_time if selectiveMLP_time > 0 else float('inf') + + return { + 'index_size': index_size, + 'standard_time': standardmlp_time, + 'selective_cutlass_time': selectiveMLP_time, + 'selective_triton_time': selectiveMLPTriton_time, + 'cutlass_speedup': cutlass_speedup, + 'triton_speedup': triton_speedup + } + +def run_profiling_over_index_sizes(args, index_sizes): + results = [] + for size in tqdm(index_sizes, desc="Profiling MLPs"): + result = profile_mlps(args, size) + results.append(result) + return pd.DataFrame(results) + + +''' +if __name__ == '__main__': + args = arg_parser() + + index_sizes = generate_index_sizes(args.hidden_features) + + # create standard MLP block + standardmlp = StandardMLPBlock(args.in_features, args.hidden_features) + selectiveMLP = SelectiveMLP(args.in_features, args.hidden_features).load_from_MLP(standardmlp) + selectiveMLPTriton = SelectiveMLPTriton(args.in_features, args.hidden_features).load_from_MLP(standardmlp) + + + # test input + x = torch.randn(args.batch_size, args.in_features, dtype=torch.float16, device='cuda') + index_vec, zero_index = sparse_index(args.index_size, args.hidden_features or args.in_features*4) + standardmlp.sparsify(zero_index) + + # measure execution time + std_out, standardmlp_time = benchmark_mlp_fwd(x, standardmlp, index_vec= None, iterations=args.iterations, print_result=True) + cutlass_out, selectiveMLP_time = benchmark_mlp_fwd(x, selectiveMLP, index_vec= index_vec, iterations=args.iterations, print_result=True) + triton_out, selectiveMLPTriton_time = benchmark_mlp_fwd(x, selectiveMLPTriton, index_vec= index_vec, iterations=args.iterations, print_result=True) + + if args.check_results: + print('Standard MLP output:', std_out[0]) + print('Selective MLP Cutlass output:', cutlass_out) + print('Selective MLP Triton output:', triton_out) + + triton_speedup = standardmlp_time/selectiveMLPTriton_time + cutlass_speedup = standardmlp_time/selectiveMLP_time + + print(f"Speedup of Cutlass implementation over standard MLP: {cutlass_speedup}") + print(f"Speedup of Triton implementation over standard MLP: {triton_speedup}") + +''' + +if __name__ == '__main__': + args = arg_parser() + + print(f"Profiling MLPs") + # Results directory + results_dir = create_results_directory(args.results_dir) + + # Define the range of index sizes you want to profile + index_sizes = generate_index_sizes(args.hidden_features) + + # Run profiling over different index sizes + profiling_results = run_profiling_over_index_sizes(args, index_sizes) + + # Save the results to a CSV file + save_results_to_csv(profiling_results, filename_prefix='mlp_profiling_results', results_dir=results_dir) + + # Plot the results + plot_results(profiling_results, output_prefix='mlp_profiling', results_dir=results_dir) + + # Optionally, print the DataFrame + if args.check_results: + print(profiling_results) \ No newline at end of file diff --git a/HybridTensor/modules/references/__pycache__/MLP.cpython-310.pyc b/HybridTensor/modules/references/__pycache__/MLP.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18e7d5381bf54e7b9d89cd107981cdbf3a896502 Binary files /dev/null and b/HybridTensor/modules/references/__pycache__/MLP.cpython-310.pyc differ diff --git a/HybridTensor/modules/references/__pycache__/MLP.cpython-39.pyc b/HybridTensor/modules/references/__pycache__/MLP.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80b785e47ee7e8f385519cc2379612037b467ac2 Binary files /dev/null and b/HybridTensor/modules/references/__pycache__/MLP.cpython-39.pyc differ diff --git a/HybridTensor/modules/references/__pycache__/fused_dense.cpython-310.pyc b/HybridTensor/modules/references/__pycache__/fused_dense.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f11a33b2add680dbdcae91c80d141f8f66b389 Binary files /dev/null and b/HybridTensor/modules/references/__pycache__/fused_dense.cpython-310.pyc differ diff --git a/HybridTensor/modules/references/__pycache__/fused_dense.cpython-39.pyc b/HybridTensor/modules/references/__pycache__/fused_dense.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54568e48e41ad993e29ce2bb547b15a765329140 Binary files /dev/null and b/HybridTensor/modules/references/__pycache__/fused_dense.cpython-39.pyc differ diff --git a/HybridTensor/modules/references/__pycache__/mha_dejavu.cpython-310.pyc b/HybridTensor/modules/references/__pycache__/mha_dejavu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0ea3444486df38116325bd74bf9d9bf3bebddfc Binary files /dev/null and b/HybridTensor/modules/references/__pycache__/mha_dejavu.cpython-310.pyc differ diff --git a/HybridTensor/modules/references/__pycache__/mha_dejavu.cpython-39.pyc b/HybridTensor/modules/references/__pycache__/mha_dejavu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec79e604b6e119bb2385e589e5ab118ace726513 Binary files /dev/null and b/HybridTensor/modules/references/__pycache__/mha_dejavu.cpython-39.pyc differ diff --git a/HybridTensor/modules/references/block_dejavu.py b/HybridTensor/modules/references/block_dejavu.py new file mode 100644 index 0000000000000000000000000000000000000000..a10b3bae9b7e73ae0750487faa6a5f9af08c854e --- /dev/null +++ b/HybridTensor/modules/references/block_dejavu.py @@ -0,0 +1,661 @@ +from typing import Optional +from functools import partial + +import torch +import torch.nn as nn +from torch import Tensor + +from torchvision.ops import StochasticDepth + +from flash_attn.modules.mha import MHA +from flash_attn.modules.mlp import Mlp + +try: + from flash_attn.ops.layer_norm import dropout_add_layer_norm +except ImportError: + dropout_add_layer_norm = None + + +class Block(nn.Module): + def __init__( + self, + dim, + mixer_cls=None, + mlp_cls=None, + norm_cls=nn.LayerNorm, + dropout_cls=nn.Dropout, + prenorm=True, + resid_dropout1=0.0, + resid_dropout2=0.0, + drop_path1=0.0, + drop_path2=0.0, + fused_dropout_add_ln=False, + return_residual=False, + residual_in_fp32=False, + sequence_parallel=False, + mark_shared_params=False, + ): + """ + For prenorm=True, this Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both + the hidden_states (output of the MLP) and the residual. + This is for performance reasons, as we can fuse the dropout, add and LayerNorm. + The residual needs to be provided (except for the very first block). + + For prenorm=False, this Block has the same structure as a regular postnorm Transformer + block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN. + + return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. + This is for performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + super().__init__() + self.prenorm = prenorm + self.fused_dropout_add_ln = fused_dropout_add_ln + self.return_residual = return_residual + self.residual_in_fp32 = residual_in_fp32 + if self.residual_in_fp32: + assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True" + if mixer_cls is None: + mixer_cls = partial(MHA, num_heads=dim // 64) + if mlp_cls is None: + mlp_cls = partial(Mlp, hidden_features=4 * dim) + self.mixer = mixer_cls(dim) + self.dropout1 = dropout_cls(resid_dropout1) + self.drop_path1 = StochasticDepth(drop_path1, mode="row") + self.norm1 = norm_cls(dim) + self.mlp = mlp_cls(dim) + if not isinstance(self.mlp, nn.Identity): + self.dropout2 = dropout_cls(resid_dropout2) + self.drop_path2 = StochasticDepth(drop_path2, mode="row") + self.norm2 = norm_cls(dim) + + if self.fused_dropout_add_ln: + assert dropout_add_layer_norm is not None, "dropout_add_ln is not installed" + assert isinstance(self.norm1, nn.LayerNorm) and isinstance( + self.dropout1, nn.Dropout + ) + + # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, + # then the input to each worker in the tensor parallel group will be different. + # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. + # For now this is not an issue because we always use sequence_parallel=True during training + # and only use sequence_parallel=False during inference. + + # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. + if sequence_parallel: + for p in self.norm1.parameters(): + p._sequence_parallel = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._sequence_parallel = True + # Mark the norm parameters as "shared_params" so that we sync their values at init. + if mark_shared_params: + for p in self.norm1.parameters(): + p._shared_params = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._shared_params = True + + def forward( + self, + hidden_states: Tensor, + residual: Optional[Tensor] = None, + mixer_subset=None, + mixer_kwargs=None, + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual)) + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. + """ + + if self.prenorm: + if not self.fused_dropout_add_ln: + dropped = self.drop_path1(self.dropout1(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path1.p == 0 or not self.training: + rowscale1 = None + else: + rowscale1 = self.drop_path1( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + hidden_states, residual = dropout_add_layer_norm( + hidden_states, + residual, + self.norm1.weight, + self.norm1.bias, + self.dropout1.p if self.training else 0.0, + self.norm1.eps, + rowscale=rowscale1, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + ) + + if mixer_kwargs is None: + mixer_kwargs = {} + if mixer_subset is not None: + mixer_kwargs["mixer_subset"] = mixer_subset + + hidden_states = self.mixer(hidden_states, **mixer_kwargs) + + if mixer_subset is not None: + residual = residual[:, mixer_subset] + + if not isinstance(self.mlp, nn.Identity): + if not self.fused_dropout_add_ln: + dropped = self.drop_path2(self.dropout2(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm2( + residual.to(dtype=self.norm2.weight.dtype) + ) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + hidden_states, residual = dropout_add_layer_norm( + hidden_states, + residual, + self.norm2.weight, + self.norm2.bias, + self.dropout2.p if self.training else 0.0, + self.norm2.eps, + rowscale=rowscale2, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + ) + + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + else: + assert residual is None + mixer_out = self.mixer( + hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {}) + ) + if self.return_residual: # mixer out is actually a pair here + mixer_out, hidden_states = mixer_out + if not self.fused_dropout_add_ln: + hidden_states = self.norm1( + (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to( + dtype=self.norm1.weight.dtype + ) + ) + else: + if self.drop_path1.p == 0 or not self.training: + rowscale1 = None + else: + rowscale1 = self.drop_path1( + torch.ones( + mixer_out.shape[:-1], + device=mixer_out.device, + dtype=mixer_out.dtype, + ) + ) + hidden_states = dropout_add_layer_norm( + mixer_out, + hidden_states, + self.norm1.weight, + self.norm1.bias, + self.dropout1.p if self.training else 0.0, + self.norm1.eps, + rowscale=rowscale1, + prenorm=False, + ) + if not isinstance(self.mlp, nn.Identity): + mlp_out = self.mlp(hidden_states) + if self.return_residual: # mlp out is actually a pair here + mlp_out, hidden_states = mlp_out + if not self.fused_dropout_add_ln: + hidden_states = self.norm2( + (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to( + dtype=self.norm2.weight.dtype + ) + ) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2( + torch.ones( + mlp_out.shape[:-1], + device=mlp_out.device, + dtype=mlp_out.dtype, + ) + ) + hidden_states = dropout_add_layer_norm( + mlp_out, + hidden_states, + self.norm2.weight, + self.norm2.bias, + self.dropout2.p if self.training else 0.0, + self.norm2.eps, + rowscale=rowscale2, + prenorm=False, + ) + return hidden_states + + +class BlockDejavu(nn.Module): + def __init__( + self, + dim, + mixer_cls=None, + mlp_cls=None, + norm_cls=nn.LayerNorm, + dropout_cls=nn.Dropout, + prenorm=True, + resid_dropout1=0.0, + resid_dropout2=0.0, + drop_path1=0.0, + drop_path2=0.0, + fused_dropout_add_ln=False, + return_residual=False, + residual_in_fp32=False, + sequence_parallel=False, + mark_shared_params=False, + ): + """ + For prenorm=True, this Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both + the hidden_states (output of the MLP) and the residual (input of the MLP). + This is for performance reasons, as we can fuse the dropout, add and LayerNorm. + The residual needs to be provided (except for the very first block). + + For prenorm=False, this Block has the same structure as a regular postnorm Transformer + block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN. + + return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. + This is for performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + super().__init__() + assert prenorm == True, "Dejavu only support prenorm for now" + self.prenorm = prenorm + self.fused_dropout_add_ln = fused_dropout_add_ln + self.return_residual = return_residual + self.residual_in_fp32 = residual_in_fp32 + if self.residual_in_fp32: + assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True" + if mixer_cls is None: + mixer_cls = partial(MHA, num_heads=dim // 64) + if mlp_cls is None: + mlp_cls = partial(Mlp, hidden_features=4 * dim) + self.mixer = mixer_cls(dim) + self.dropout1 = dropout_cls(resid_dropout1) + self.drop_path1 = StochasticDepth(drop_path1, mode="row") + self.norm1 = norm_cls(dim) + self.mlp = mlp_cls(dim) + if not isinstance(self.mlp, nn.Identity): + self.dropout2 = dropout_cls(resid_dropout2) + self.drop_path2 = StochasticDepth(drop_path2, mode="row") + self.norm2 = norm_cls(dim) + + if self.fused_dropout_add_ln: + assert dropout_add_layer_norm is not None, "dropout_add_ln is not installed" + assert isinstance(self.norm1, nn.LayerNorm) and isinstance( + self.dropout1, nn.Dropout + ) + + # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, + # then the input to each worker in the tensor parallel group will be different. + # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. + # For now this is not an issue because we always use sequence_parallel=True during training + # and only use sequence_parallel=False during inference. + + # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. + if sequence_parallel: + for p in self.norm1.parameters(): + p._sequence_parallel = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._sequence_parallel = True + # Mark the norm parameters as "shared_params" so that we sync their values at init. + if mark_shared_params: + for p in self.norm1.parameters(): + p._shared_params = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._shared_params = True + + # sp related + self.mixer.mlp_k = self.mlp.sp.K + + def forward( + self, + hidden_states: Tensor, + residual: Optional[Tensor] = None, + mixer_subset=None, + mixer_kwargs=None, + mlp_sp_logit=None, + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: If prenorm, residual is MLP block input before layer norm. + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. + """ + + if self.prenorm: + if not self.fused_dropout_add_ln: + dropped = self.drop_path1(self.dropout1(hidden_states)) + residual = ( + (dropped + residual) if residual is not None else dropped + ) # residual is output of MLP and input of MLP -> input of attention block + hidden_states = self.norm1( + residual.to(dtype=self.norm1.weight.dtype) + ) # hidden states is layer norm of attention block input + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path1.p == 0 or not self.training: + rowscale1 = None + else: + rowscale1 = self.drop_path1( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + hidden_states, residual = dropout_add_layer_norm( + hidden_states, + residual, + self.norm1.weight, + self.norm1.bias, + self.dropout1.p if self.training else 0.0, + self.norm1.eps, + rowscale=rowscale1, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + ) + + if mixer_kwargs is None: + mixer_kwargs = {} + if mixer_subset is not None: + mixer_kwargs["mixer_subset"] = mixer_subset + + hidden_states, mlp_idx = self.mixer( + hidden_states, # hidden states is after layer norm + mlp_sp_logit=mlp_sp_logit, + **mixer_kwargs + ) + + if mixer_subset is not None: + residual = residual[:, mixer_subset] + + if not isinstance(self.mlp, nn.Identity): + if not self.fused_dropout_add_ln: + dropped = self.drop_path2( + self.dropout2(hidden_states) + ) # hidden states is output of MHA + residual = ( + (dropped + residual) if residual is not None else dropped + ) # residual is output of MHA + input of MHA + hidden_states = self.norm2( + residual.to(dtype=self.norm2.weight.dtype) + ) # hidden_states is layer norm of MLP input + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + hidden_states, residual = dropout_add_layer_norm( + hidden_states, + residual, + self.norm2.weight, + self.norm2.bias, + self.dropout2.p if self.training else 0.0, + self.norm2.eps, + rowscale=rowscale2, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + ) + # hidden states is after layer norm, residual is before layer norm + hidden_states, next_mlp_sp_logit = self.mlp( + hidden_states, residual, mlp_idx + ) + return ( + hidden_states, + residual, + next_mlp_sp_logit, + ) + + +class BlockMlpAttSparse(nn.Module): + def __init__( + self, + dim, + mixer_cls=None, + mlp_cls=None, + norm_cls=nn.LayerNorm, + dropout_cls=nn.Dropout, + prenorm=True, + resid_dropout1=0.0, + resid_dropout2=0.0, + drop_path1=0.0, + drop_path2=0.0, + fused_dropout_add_ln=False, + return_residual=False, + residual_in_fp32=False, + sequence_parallel=False, + mark_shared_params=False, + ): + """ + For prenorm=True, this Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both + the hidden_states (output of the MLP) and the residual (input of the MLP). + This is for performance reasons, as we can fuse the dropout, add and LayerNorm. + The residual needs to be provided (except for the very first block). + + For prenorm=False, this Block has the same structure as a regular postnorm Transformer + block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN. + + return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. + This is for performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + super().__init__() + assert prenorm == True, "Dejavu only support prenorm for now" + self.prenorm = prenorm + self.fused_dropout_add_ln = fused_dropout_add_ln + self.return_residual = return_residual + self.residual_in_fp32 = residual_in_fp32 + if self.residual_in_fp32: + assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True" + if mixer_cls is None: + mixer_cls = partial(MHA, num_heads=dim // 64) + if mlp_cls is None: + mlp_cls = partial(Mlp, hidden_features=4 * dim) + self.mixer = mixer_cls(dim) + self.dropout1 = dropout_cls(resid_dropout1) + self.drop_path1 = StochasticDepth(drop_path1, mode="row") + self.norm1 = norm_cls(dim) + self.mlp = mlp_cls(dim) + if not isinstance(self.mlp, nn.Identity): + self.dropout2 = dropout_cls(resid_dropout2) + self.drop_path2 = StochasticDepth(drop_path2, mode="row") + self.norm2 = norm_cls(dim) + + if self.fused_dropout_add_ln: + assert dropout_add_layer_norm is not None, "dropout_add_ln is not installed" + assert isinstance(self.norm1, nn.LayerNorm) and isinstance( + self.dropout1, nn.Dropout + ) + + # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, + # then the input to each worker in the tensor parallel group will be different. + # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. + # For now this is not an issue because we always use sequence_parallel=True during training + # and only use sequence_parallel=False during inference. + + # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. + if sequence_parallel: + for p in self.norm1.parameters(): + p._sequence_parallel = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._sequence_parallel = True + # Mark the norm parameters as "shared_params" so that we sync their values at init. + if mark_shared_params: + for p in self.norm1.parameters(): + p._shared_params = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._shared_params = True + + # sp related + if hasattr(self.mlp, "sp"): + self.mixer.mlp_k = self.mlp.sp.K + if hasattr(self.mixer, "sp"): + self.mixer.att_k = self.mixer.sp.K + + def forward( + self, + hidden_states: Tensor, + residual: Optional[Tensor] = None, + mixer_subset=None, + mixer_kwargs=None, + head_idx=None, + mlp_sp_logit=None, + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: If prenorm, residual is MLP block input before layer norm. + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. + """ + + if self.prenorm: + if not self.fused_dropout_add_ln: + dropped = self.drop_path1(self.dropout1(hidden_states)) + residual = ( + (dropped + residual) if residual is not None else dropped + ) # residual is output of MLP and input of MLP -> input of attention block + hidden_states = self.norm1( + residual.to(dtype=self.norm1.weight.dtype) + ) # hidden states is layer norm of attention block input + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path1.p == 0 or not self.training: + rowscale1 = None + else: + rowscale1 = self.drop_path1( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + hidden_states, residual = dropout_add_layer_norm( + hidden_states, + residual, + self.norm1.weight, + self.norm1.bias, + self.dropout1.p if self.training else 0.0, + self.norm1.eps, + rowscale=rowscale1, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + ) + + if mixer_kwargs is None: + mixer_kwargs = {} + if mixer_subset is not None: + mixer_kwargs["mixer_subset"] = mixer_subset + + hidden_states, mlp_idx, next_att_idx = self.mixer( + hidden_states, # hidden states is after layer norm + residual=residual, + head_idx=head_idx, + mlp_sp_logit=mlp_sp_logit, + **mixer_kwargs + ) + + if mixer_subset is not None: + residual = residual[:, mixer_subset] + + if not isinstance(self.mlp, nn.Identity): + if not self.fused_dropout_add_ln: + dropped = self.drop_path2( + self.dropout2(hidden_states) + ) # hidden states is output of MHA + residual = ( + (dropped + residual) if residual is not None else dropped + ) # residual is output of MHA + input of MHA + hidden_states = self.norm2( + residual.to(dtype=self.norm2.weight.dtype) + ) # hidden_states is layer norm of MLP input + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + hidden_states, residual = dropout_add_layer_norm( + hidden_states, + residual, + self.norm2.weight, + self.norm2.bias, + self.dropout2.p if self.training else 0.0, + self.norm2.eps, + rowscale=rowscale2, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + ) + # hidden states is after layer norm, residual is before layer norm + hidden_states, next_mlp_sp_logit = self.mlp( + hidden_states, residual, mlp_idx + ) + return (hidden_states, residual, next_mlp_sp_logit, next_att_idx) diff --git a/HybridTensor/modules/references/fused_dense.py b/HybridTensor/modules/references/fused_dense.py new file mode 100644 index 0000000000000000000000000000000000000000..15f79c78df43bddb590a671fdf76c33cf34abe99 --- /dev/null +++ b/HybridTensor/modules/references/fused_dense.py @@ -0,0 +1,904 @@ +# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py + +from typing import Optional +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.cuda.amp import custom_bwd, custom_fwd + +# import fused_dense_cuda # from apex + +import fused_dense_lib as fused_dense_cuda + +# from flash_attn.ops.gelu_activation import gelu_bwd +from flash_attn.utils.distributed import ( + all_gather_raw, + reduce_scatter_raw, + all_reduce_raw, +) +from flash_attn.utils.distributed import reduce_scatter, all_reduce +from einops import rearrange + + +class FusedDenseFunc(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + weight, + bias, + return_residual=False, + process_group=None, + sequence_parallel=True, + ): + """ + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather_raw of x before doing the matmul. + """ + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + if process_group is not None and sequence_parallel: + # We want to kick off the all_gather early, before weight dtype conversion + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + + if torch.is_autocast_enabled(): + weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) + bias = ( + bias.to(dtype=torch.get_autocast_gpu_dtype()) + if bias is not None + else None + ) + weight = weight.contiguous() + if process_group is not None and sequence_parallel: + handle_x.wait() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *weight.shape) > 65535 * 32: + raise RuntimeError("fused_dense only supports matrix dims <= 2M") + output = F.linear(total_x, weight, bias) + if ctx.compute_weight_gradient: + ctx.save_for_backward(x, weight) + else: + ctx.save_for_backward(weight) + return output if not return_residual else (output, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + grad_output = grad_output.contiguous() + if ctx.return_residual: + (grad_input,) = args + grad_input = grad_input.contiguous() + process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel + if ctx.compute_weight_gradient: + x, weight = ctx.saved_tensors + if process_group is not None and sequence_parallel: + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + else: + (weight,) = ctx.saved_tensors + total_x = None + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = F.linear(grad_output, weight.t()) + else: + grad_input = torch.addmm( + grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_output, + weight, + ) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = reduce_fn( + grad_input, process_group, async_op=True + ) + else: + grad_input = None + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + if process_group is not None and sequence_parallel: + handle_x.wait() + grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( + total_x.reshape(batch_dim, total_x.shape[-1]), + grad_output, + ctx.needs_input_grad[2], + ) + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return grad_input, grad_weight, grad_bias, None, None, None + + +def fused_dense_func( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + return_residual: bool = False, + process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True, +): + dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( + x.dtype == torch.float32 and torch.is_autocast_enabled() + ) + if ( + x.is_cuda + and weight.is_cuda + and (bias is None or bias.is_cuda) + and dtype_eligible + ): + return FusedDenseFunc.apply( + x, weight, bias, return_residual, process_group, sequence_parallel + ) + else: + assert process_group is None + out = F.linear(x, weight, bias) + return out if not return_residual else (out, x) + + +class FusedDense(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + return_residual: bool = False, + device=None, + dtype=None, + ) -> None: + super().__init__( + in_features, out_features, bias=bias, device=device, dtype=dtype + ) + self.return_residual = return_residual + + def forward(self, x, process_group=None): + """ + If process_group is not None, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul. + """ + return fused_dense_func( + x, + self.weight, + self.bias, + return_residual=self.return_residual, + process_group=process_group, + ) + + +class ColumnParallelLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + device=None, + dtype=None, + ) -> None: + world_size = torch.distributed.get_world_size(process_group) + if out_features % world_size != 0: + raise ValueError( + f"out_features ({out_features}) must be divisible by " + f"world_size ({world_size})" + ) + super().__init__( + in_features, + out_features // world_size, + bias=bias, + device=device, + dtype=dtype, + ) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + + def forward(self, x): + # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + # we do an all_gather of x before doing the matmul. + # If not, then the input is already gathered. + return fused_dense_func( + x, + self.weight, + self.bias, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + ) + + +class RowParallelLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + device=None, + dtype=None, + ) -> None: + world_size = torch.distributed.get_world_size(process_group) + rank = torch.distributed.get_rank(process_group) + if in_features % world_size != 0: + raise ValueError( + f"in_features ({in_features}) must be divisible by " + f"world_size ({world_size})" + ) + # Only rank 0 will have bias + super().__init__( + in_features // world_size, + out_features, + bias=bias and rank == 0, + device=device, + dtype=dtype, + ) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + + def forward(self, x): + """ + We're doing Tensor Parallel with sequence parallelism: we do the matmul and then + a reduce_scatter of the result. + """ + out = fused_dense_func(x, self.weight, self.bias) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return reduce_fn(out, self.process_group) + + +class FusedMLPFunc(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + weight1, + bias1, + weight2, + bias2, + activation="gelu_approx", + save_pre_act=True, + return_residual=False, + checkpoint_lvl=0, + heuristic=0, + process_group=None, + sequence_parallel=True, + ): + """ + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather of x before doing the matmul. + If sequence_parallel=False, then the input is already gathered. + + checkpoint_lvl: + 0: no recomputation in the bwd + 1: recompute gelu_out / relu_out in the bwd + 2: recompute pre_act and gelu_out / relu_out in the bwd + """ + assert -1 <= heuristic <= 4 + assert activation in ["gelu_approx", "relu"] + if not save_pre_act: + checkpoint_lvl = 2 + assert checkpoint_lvl in [0, 1, 2] + ctx.return_residual = return_residual + ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel + ctx.checkpoint_lvl = checkpoint_lvl + ctx.activation = activation + ctx.heuristic = heuristic + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + if process_group is not None and sequence_parallel: + # We want to kick off the all_gather early, before weight dtype conversion + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]] + bias1 = bias1.to(dtype=dtype) if bias1 is not None else None + bias2 = bias2.to(dtype=dtype) if bias2 is not None else None + weight1 = weight1.contiguous() + bias1 = bias1.contiguous() if bias1 is not None else None + weight2 = weight2.contiguous() + bias2 = bias2.contiguous() if bias2 is not None else None + if process_group is not None and sequence_parallel: + handle_x.wait() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32: + raise RuntimeError("fused_dense only supports matrix dims <= 2M") + if heuristic == -1: + pre_act = F.linear(total_x, weight1, bias1) + activation_fn = ( + partial(F.gelu, approximate="tanh") + if activation == "gelu_approx" + else F.relu + ) + output1 = activation_fn(pre_act) + # This is before adding bias1 + # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1) + # with torch.jit.fuser('fuser2'): + # output1 = bias_gelu(pre_act, bias1) + else: + is_gelu = activation == "gelu_approx" + output1, *rest = fused_dense_cuda.linear_act_forward( + total_x.reshape(batch_dim, n), + weight1, + bias1, + is_gelu, + save_pre_act, + heuristic, + ) + if save_pre_act: + pre_act = rest[0] + output2 = F.linear(output1, weight2, bias2) + if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): + # For RELU the pre_act is very small (just a bit-mask) so we just save it + ctx.save_for_backward(x, weight1, weight2, pre_act, output1) + elif checkpoint_lvl == 1: + ctx.save_for_backward(x, weight1, weight2, pre_act) + elif checkpoint_lvl == 2: + ctx.save_for_backward(x, weight1, weight2, bias1) + output2 = output2.reshape(*batch_shape, output2.shape[-1]) + return output2 if not return_residual else (output2, x) + + + ''' + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + grad_output = grad_output.contiguous() + checkpoint_lvl = ctx.checkpoint_lvl + activation = ctx.activation + activation_fn = ( + partial(F.gelu, approximate="tanh") + if activation == "gelu_approx" + else F.relu + ) + if ctx.return_residual: + (grad_input,) = args + grad_input = grad_input.contiguous() + process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel + x, weight1, weight2, *rest = ctx.saved_tensors + if process_group is None or not sequence_parallel: + total_x = x + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + if checkpoint_lvl in [0, 1]: + if process_group is not None and sequence_parallel: + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): + pre_act, output1 = rest + elif checkpoint_lvl == 1: + (pre_act,) = rest + output1 = activation_fn(pre_act) + elif checkpoint_lvl == 2: + (bias1,) = rest + if process_group is not None and sequence_parallel: + total_x, _ = all_gather_raw(x, process_group) + if ctx.heuristic == -1: + pre_act = F.linear(total_x, weight1, bias1) + output1 = activation_fn(pre_act) + else: + output1, pre_act = fused_dense_cuda.linear_act_forward( + total_x.reshape(batch_dim, total_x.shape[-1]), + weight1, + bias1, + activation == "gelu_approx", + True, + ctx.heuristic, + ) + + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + output1 = output1.reshape(batch_dim, output1.shape[-1]) + pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1]) + if ctx.needs_input_grad[3]: + grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad( + output1, grad_output, ctx.needs_input_grad[4] + ) + else: + grad_weight2 = None + grad_bias2 = grad_output if ctx.needs_input_grad[4] else None + if ctx.heuristic == -1: + # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act) + grad_output1 = F.linear(grad_output, weight2.t()) + with torch.jit.fuser("fuser2"): + activation_grad_fn = ( + gelu_bwd if activation == "gelu_approx" else relu_bwd + ) + grad_pre_act = activation_grad_fn(grad_output1, pre_act) + else: + # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't + # just compute gelu/relu grad + grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad( + weight2, + grad_output, + pre_act, + activation == "gelu_approx", + ctx.heuristic, + ) + if not ctx.needs_input_grad[2]: + grad_bias1 = None + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = F.linear(grad_pre_act, weight1.t()) + else: + grad_input = torch.addmm( + grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_pre_act, + weight1, + ) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = reduce_fn( + grad_input, process_group, async_op=True + ) + else: + grad_input = None + if ctx.heuristic == -1: + if ctx.needs_input_grad[1]: + if process_group is not None and sequence_parallel: + handle_x.wait() + grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( + total_x.reshape(batch_dim, total_x.shape[-1]), + grad_pre_act, + ctx.needs_input_grad[2], + ) + else: + grad_weight1 = None + grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None + else: + if ctx.needs_input_grad[1]: + if process_group is not None and sequence_parallel: + handle_x.wait() + grad_weight1 = F.linear( + grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t() + ) + else: + grad_weight1 = None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return ( + grad_input, + grad_weight1, + grad_bias1, + grad_weight2, + grad_bias2, + None, + None, + None, + None, + None, + None, + None, + ) + + ''' + + +def fused_mlp_func( + x: Tensor, + weight1: Tensor, + weight2: Tensor, + bias1: Optional[Tensor] = None, + bias2: Optional[Tensor] = None, + activation: str = "gelu_approx", + save_pre_act: bool = True, + return_residual: bool = False, + checkpoint_lvl: int = 0, + heuristic: int = 0, + process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True, +): + assert activation in ["gelu_approx", "relu"] + dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( + x.dtype == torch.float32 and torch.is_autocast_enabled() + ) + # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu) + dim_eligible = not save_pre_act or ( + x.shape[-1] % (128 if activation == "relu" else 8) == 0 + ) + if ( + x.is_cuda + and weight1.is_cuda + and weight2.is_cuda + and (bias1 is None or bias1.is_cuda) + and (bias2 is None or bias2.is_cuda) + and dtype_eligible + and dim_eligible + ): + return FusedMLPFunc.apply( + x, + weight1, + bias1, + weight2, + bias2, + activation, + save_pre_act, + return_residual, + checkpoint_lvl, + heuristic, + process_group, + sequence_parallel, + ) + else: + assert process_group is None + pre_act = F.linear(x, weight1, bias1) + activation_fn = ( + partial(F.gelu, approximate="tanh") + if activation == "gelu_approx" + else partial(F.relu, inplace=True) + ) + output1 = activation_fn(pre_act) + output2 = F.linear(output1, weight2, bias2) + return output2 if not return_residual else (output2, x) + + +class FusedMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features, + out_features=None, + bias1=True, + bias2=True, + activation="gelu_approx", + return_residual=False, + checkpoint_lvl=0, + heuristic="auto", + layer_idx=None, + device=None, + dtype=None, + ): + """ + If process_group is not None, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul, gelu, then matmul. + Finally we do a reduce_scatter of the output. + + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute pre_act and gelu_out in the bwd + heuristic: + -1: don't fuse gemm + gelu (separate kernel) + 0..4: use this heuristic for the algo section in the fused gemm + gelu + 'auto': heuristic will be picked automatically: + For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + assert checkpoint_lvl in [0, 1, 2] + assert activation in ["gelu_approx", "relu"] + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if out_features is None: + out_features = in_features + self.activation = activation + self.return_residual = return_residual + self.checkpoint_lvl = checkpoint_lvl + self.heuristic = heuristic + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) + self.fc2 = nn.Linear( + hidden_features, out_features, bias=bias2, **factory_kwargs + ) + + def forward(self, x, process_group=None): + dtype = ( + x.dtype + if not torch.is_autocast_enabled() + else torch.get_autocast_gpu_dtype() + ) + if self.heuristic == "auto": + if self.activation == "gelu_approx": + cuda_ver = tuple(map(int, torch.version.cuda.split("."))) + heuristic = ( + 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) + ) + else: + heuristic = 0 + else: + heuristic = self.heuristic + out = fused_mlp_func( + x, + self.fc1.weight, + self.fc2.weight, + self.fc1.bias, + self.fc2.bias, + activation=self.activation, + save_pre_act=self.training, + return_residual=self.return_residual, + checkpoint_lvl=self.checkpoint_lvl, + heuristic=heuristic, + process_group=process_group, + ) + if self.return_residual: + out, x = out + if process_group is not None: + out = reduce_scatter(out, process_group) + return out if not self.return_residual else (out, x) + + +class ParallelFusedMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features, + out_features=None, + activation="gelu_approx", + process_group: ProcessGroup = None, + bias1=True, + bias2=True, + sequence_parallel=True, + layer_idx=None, + checkpoint_lvl=0, + heuristic="auto", + device=None, + dtype=None, + sp_kwargs=None, + ): + """ + process_group is required. We're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul, gelu, then matmul. + Finally we do a reduce_scatter of the output. + + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute pre_act and gelu_out in the bwd + heuristic: + -1: don't fuse gemm + gelu (separate kernel) + 0..4: use this heuristic for the algo section in the fused gemm + gelu + 'auto': heuristic will be picked automatically: + For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. + """ + assert checkpoint_lvl in [0, 1, 2] + assert activation in ["gelu_approx", "relu"] + assert process_group is not None + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if out_features is None: + out_features = in_features + self.activation = activation + self.process_group = process_group + self.sequence_parallel = sequence_parallel + self.checkpoint_lvl = checkpoint_lvl + self.heuristic = heuristic + self.fc1 = ColumnParallelLinear( + in_features, hidden_features, process_group, bias=bias1, **factory_kwargs + ) + self.fc2 = RowParallelLinear( + hidden_features, out_features, process_group, bias=bias2, **factory_kwargs + ) + + def forward(self, x): + if self.heuristic == "auto": + dtype = ( + x.dtype + if not torch.is_autocast_enabled() + else torch.get_autocast_gpu_dtype() + ) + if self.activation == "gelu_approx": + cuda_ver = tuple(map(int, torch.version.cuda.split("."))) + heuristic = ( + 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) + ) + else: + heuristic = 0 + else: + heuristic = self.heuristic + + out = fused_mlp_func( + x, + self.fc1.weight, + self.fc2.weight, + self.fc1.bias, + self.fc2.bias, + activation=self.activation, + save_pre_act=self.training, + checkpoint_lvl=self.checkpoint_lvl, + heuristic=heuristic, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + ) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return reduce_fn(out, self.process_group) + + +class RowParallelLinearNoReduce(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + device=None, + dtype=None, + ) -> None: + world_size = torch.distributed.get_world_size(process_group) + rank = torch.distributed.get_rank(process_group) + if in_features % world_size != 0: + raise ValueError( + f"in_features ({in_features}) must be divisible by " + f"world_size ({world_size})" + ) + # Only rank 0 will have bias + super().__init__( + in_features // world_size, + out_features, + bias=bias and rank == 0, + device=device, + dtype=dtype, + ) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + + def forward(self, x): + """ + We're doing Tensor Parallel with sequence parallelism: we do the matmul and then + a reduce_scatter of the result. + """ + out = fused_dense_func(x, self.weight, self.bias) + return out + + +class ParallelFusedMLPDejavu(nn.Module): + def __init__( + self, + in_features, + hidden_features, + sp_kwargs, + out_features=None, + activation="gelu_approx", + layer_idx=None, + process_group: ProcessGroup = None, + bias1=True, + bias2=True, + sequence_parallel=True, + checkpoint_lvl=0, + heuristic="auto", + device=None, + dtype=None, + ): + """ + process_group is required. We're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul, gelu, then matmul. + Finally we do a reduce_scatter of the output. + + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute pre_act and gelu_out in the bwd + heuristic: + -1: don't fuse gemm + gelu (separate kernel) + 0..4: use this heuristic for the algo section in the fused gemm + gelu + 'auto': heuristic will be picked automatically: + For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. + """ + assert checkpoint_lvl in [0, 1, 2] + assert activation in ["gelu_approx", "relu"] + assert process_group is not None + assert sp_kwargs != None, "sparse predictor parameters are not passed in." + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if out_features is None: + out_features = in_features + self.activation = activation + self.process_group = process_group + self.sequence_parallel = sequence_parallel + self.checkpoint_lvl = checkpoint_lvl + self.heuristic = heuristic + self.fc1 = ColumnParallelLinear( + in_features, hidden_features, process_group, bias=bias1, **factory_kwargs + ) + self.fc2 = RowParallelLinear( + hidden_features, out_features, process_group, bias=bias2, **factory_kwargs + ) + self.layer_idx = layer_idx + # sparse related + self.fc2_weight_t = self.register_buffer("fc2_weigth_t", None) + self.sp = ParallelSP( + layer_idx=layer_idx, + device=device, + dtype=dtype, + process_group=process_group, + sequence_parallel=sequence_parallel, + **sp_kwargs, + ) + self.sp_stream = torch.cuda.Stream(device="cuda", priority=0) + self.event_mlp = torch.cuda.Event(enable_timing=False, blocking=False) + self.event_mlp_sp = torch.cuda.Event(enable_timing=False, blocking=False) + + def forward(self, x, residual, idx=None): + if self.heuristic == "auto": + dtype = ( + x.dtype + if not torch.is_autocast_enabled() + else torch.get_autocast_gpu_dtype() + ) + if self.activation == "gelu_approx": + cuda_ver = tuple(map(int, torch.version.cuda.split("."))) + heuristic = ( + 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) + ) + else: + heuristic = 0 + else: + heuristic = self.heuristic + curr_stream = torch.cuda.current_stream() + do_token_generation = x.size(1) == 1 + if idx != None: + assert x.size(1) == 1 + from einops import rearrange + from src.ops.triton.gather_gemv import mlp_sparse + + if self.fc2_weight_t is None: + self.fc2_weight_t = self.fc2.weight.t().contiguous() + out = mlp_sparse( + rearrange(x, "b 1 d -> b d"), + self.fc1.weight, + self.fc2_weight_t, + idx, + self.fc1.bias, + self.fc2.bias, + ) + out = rearrange(out, "b d -> b 1 d") + else: + out = fused_mlp_func( + x, + self.fc1.weight, + self.fc2.weight, + self.fc1.bias, + self.fc2.bias, + activation=self.activation, + save_pre_act=self.training, + checkpoint_lvl=self.checkpoint_lvl, + heuristic=heuristic, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + ) + curr_stream.record_event(self.event_mlp) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + + out = reduce_fn(out, self.process_group) + + with torch.cuda.stream(self.sp_stream): + self.sp_stream.wait_event(self.event_mlp) + mlp_logit = None + if do_token_generation: + mlp_logit = self.sp(residual) + self.sp_stream.record_event(self.event_mlp_sp) + + curr_stream.wait_event(self.event_mlp_sp) + + return out, mlp_logit + \ No newline at end of file diff --git a/HybridTensor/modules/references/mha_dejavu.py b/HybridTensor/modules/references/mha_dejavu.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb0681695c5c9891523c48b84b3c21399cda320 --- /dev/null +++ b/HybridTensor/modules/references/mha_dejavu.py @@ -0,0 +1,1506 @@ +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func + from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func +except ImportError: + flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func = None, None + +try: + from flash_attn.ops.flash_attn_triton import ( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + ) +except ImportError: + flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None + +try: + from src.ops.fused_dense_sparse_dejavu import ( + FusedDense, + ColumnParallelLinear, + RowParallelLinear, + RowParallelLinearNoReduce, + ) +except ImportError: + FusedDense, ColumnParallelLinear, RowParallelLinear, RowParallelLinearNoReduce = ( + None, + None, + None, + None, + ) + +try: + from flash_attn.layers.rotary import RotaryEmbedding +except ImportError: + RotaryEmbedding = None + +try: + import ft_attention +except ImportError: + ft_attention = None + +from flash_attn.utils.distributed import all_reduce + + +class FlashSelfAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__( + self, causal=False, softmax_scale=None, attention_dropout=0.0, triton=False + ): + super().__init__() + if attention_dropout != 0.0 or not triton: + assert ( + flash_attn_unpadded_qkvpacked_func is not None + ), "FlashAttention is not installed" + if attention_dropout == 0.0 and triton: + assert ( + flash_attn_qkvpacked_func is not None + ), "FlashAttention Triton is not installed" + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + self.triton = triton + + def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. + If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D). + If cu_seqlens is not None and max_seqlen is not None, then qkv has shape + (total, 3, H, D), where total is the sum of the sequence lengths in the batch. + causal: if passed, will override self.causal + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into qkv. + max_seqlen: int. Maximum sequence length in the batch. + Returns: + -------- + out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None, + else (B, S, H, D). + """ + assert qkv.dtype in [torch.float16, torch.bfloat16] + assert qkv.is_cuda + causal = self.causal if causal is None else causal + unpadded = cu_seqlens is not None + if unpadded: + assert cu_seqlens.dtype == torch.int32 + assert max_seqlen is not None + assert isinstance(max_seqlen, int) + return flash_attn_unpadded_qkvpacked_func( + qkv, + cu_seqlens, + max_seqlen, + self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal=causal, + ) + else: + batch_size, seqlen = qkv.shape[0], qkv.shape[1] + # Triton version doesn't support dropout + if self.triton and (self.dropout_p == 0 or not self.training): + output = flash_attn_qkvpacked_func( + qkv, None, causal, self.softmax_scale + ) + else: + qkv = rearrange(qkv, "b s ... -> (b s) ...") + max_seqlen = seqlen + cu_seqlens = torch.arange( + 0, + (batch_size + 1) * seqlen, + step=seqlen, + dtype=torch.int32, + device=qkv.device, + ) + output = flash_attn_unpadded_qkvpacked_func( + qkv, + cu_seqlens, + max_seqlen, + self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal=causal, + ) + output = rearrange(output, "(b s) ... -> b s ...", b=batch_size) + return output + + +class FlashCrossAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__( + self, causal=False, softmax_scale=None, attention_dropout=0.0, triton=False + ): + super().__init__() + if attention_dropout != 0.0 or not triton: + assert ( + flash_attn_unpadded_kvpacked_func is not None + ), "FlashAttention is not installed" + if attention_dropout == 0.0 and triton: + assert ( + flash_attn_kvpacked_func is not None + ), "FlashAttention Triton is not installed" + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + self.triton = triton + + def forward( + self, + q, + kv, + causal=None, + cu_seqlens=None, + max_seqlen=None, + cu_seqlens_k=None, + max_seqlen_k=None, + ): + """Implements the multihead softmax attention. + Arguments + --------- + q: The tensor containing the query. (B, Sq, H, D) + kv: The tensor containing the key and value. (B, Sk, 2, H, D) + causal: if passed, will override self.causal + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + max_seqlen: int. Maximum sequence length in the batch of q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_k: int. Maximum sequence length in the batch of k and v. + """ + assert q.dtype in [torch.float16, torch.bfloat16] + assert q.is_cuda and kv.is_cuda + causal = self.causal if causal is None else causal + unpadded = cu_seqlens is not None + if unpadded: + assert cu_seqlens.dtype == torch.int32 + assert max_seqlen is not None + assert isinstance(max_seqlen, int) + assert cu_seqlens_k is not None + assert cu_seqlens_k.dtype == torch.int32 + assert max_seqlen_k is not None + assert isinstance(max_seqlen, int) + return flash_attn_unpadded_kvpacked_func( + q, + kv, + cu_seqlens, + cu_seqlens_k, + max_seqlen, + max_seqlen_k, + self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal=causal, + ) + else: + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = kv.shape[1] + assert ( + kv.shape[0] == batch_size + and kv.shape[3] == q.shape[2] + and kv.shape[4] == q.shape[3] + ) + if self.triton and ( + self.dropout_p == 0.0 or not self.training + ): # Triton version doesn't support dropout + output = flash_attn_kvpacked_func( + q, kv, None, causal, self.softmax_scale + ) + else: + q = rearrange(q, "b s ... -> (b s) ...") + kv = rearrange(kv, "b s ... -> (b s) ...") + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q.device, + ) + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=kv.device, + ) + output = flash_attn_unpadded_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + seqlen_q, + seqlen_k, + self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal=causal, + ) + output = rearrange(output, "(b s) ... -> b s ...", b=batch_size) + return output + +class SelfAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, qkv, causal=None, key_padding_mask=None): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) + causal: if passed, will override self.causal + key_padding_mask: boolean mask to apply to the attention weights. True means to keep, + False means to mask out. (B, S) + """ + batch_size, seqlen = qkv.shape[0], qkv.shape[1] + causal = self.causal if causal is None else causal + q, k, v = qkv.unbind(dim=2) + softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if key_padding_mask is not None: + padding_mask = torch.full( + (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device + ) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu( + torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1 + ) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + return output + +class CrossAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, q, kv, causal=None, key_padding_mask=None): + """Implements the multihead softmax attention. + Arguments + --------- + q: The tensor containing the query. (B, Sq, H, D) + kv: The tensor containing the key and value. (B, Sk, 2, H, D) + causal: if passed, will override self.causal + key_padding_mask: boolean mask to apply to the attention weights. True means to keep, + False means to mask out. (B, Sk) + """ + batch_size, seqlen_q = q.shape[0], q.shape[1] + causal = self.causal if causal is None else causal + seqlen_k = kv.shape[1] + assert ( + kv.shape[0] == batch_size + and kv.shape[3] == q.shape[2] + and kv.shape[4] == q.shape[3] + ) + k, v = kv.unbind(dim=2) + softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if key_padding_mask is not None: + padding_mask = torch.full( + (batch_size, seqlen_k), + -10000.0, + dtype=scores.dtype, + device=scores.device, + ) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu( + torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1 + ) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + return output + +class LinearResidual(nn.Linear): + """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.""" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return super().forward(input), input + + +def _update_kv_cache(kv, inference_params, layer_idx): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + # Pre-allocate memory for key-values for inference. + num_heads, head_dim = kv.shape[-2:] + if layer_idx not in inference_params.key_value_memory_dict: + kv_cache = torch.empty( + inference_params.max_batch_size, + inference_params.max_sequence_len, + 2, + num_heads, + head_dim, + dtype=kv.dtype, + device=kv.device, + ) + inference_params.key_value_memory_dict[layer_idx] = kv_cache + else: + if not inference_params.fused_ft_kernel: + kv_cache = inference_params.key_value_memory_dict[layer_idx] + else: + # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize) + # where packsize = 4 if fp32, 8 if fp16 or bf16. + # v_cache has shape (b, h, s, headdim) + k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx] + kv_cache = None + # Adjust key and value for inference + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + kv.shape[1] + assert batch_end <= ( + kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0] + ) + assert sequence_end <= ( + kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2] + ) + # Copy key and values. + if not inference_params.fused_ft_kernel: + assert kv_cache is not None + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + kv = kv_cache[batch_start:batch_end, :sequence_end, ...] + return kv + else: + assert inference_params.sequence_len_offset == 0 + # FT kernel requires different layouts for the k_cache and v_cache. + assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32] + packsize = 4 if kv.dtype == torch.float32 else 8 + if kv_cache is not None: + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + k_cache = rearrange( + kv_cache[:, :, 0], + "b s h (d packsize) -> b h d s packsize", + packsize=packsize, + ).contiguous() + v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous() + inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache) + else: + k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange( + kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize + ) + v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange( + kv[:, :, 1], "b s h d -> b h s d" + ) + return kv + + +def split_topk(input, k, sorted=False, num_chunks=2): + N0, N1 = list(input.shape) + NN0, NN1 = N0 * num_chunks, N1 // num_chunks + part_input = input.reshape([NN0, NN1]) + part_val, part_idx = torch.topk(part_input, k=k, dim=1) + key = (N0, N1) + if not key in self.part_idx_offset: + pio = ( + torch.tensor( + [NN1 * i for i in range(num_chunks)] * N0, + dtype=part_idx.dtype, + pin_memory=True, + ) + .to(device="cuda", non_blocking=True) + .reshape([NN0, 1]) + ) + self.part_idx_offset[key] = pio + else: + pio = self.part_idx_offset[key] + part_idx.add_(pio) + part_val = part_val.reshape([N0, k * num_chunks]) + part_idx = part_idx.reshape([N0, k * num_chunks]) + top_val, idx2 = torch.topk(part_val, k=k, sorted=sorted, dim=1) + top_idx = torch.gather(part_idx, 1, idx2) + return top_val, top_idx + + +class MHA(nn.Module): + """Multi-head self-attention and cross-attention""" + + def __init__( + self, + embed_dim, + num_heads, + cross_attn=False, + bias=True, + dropout=0.0, + softmax_scale=None, + causal=False, + layer_idx=None, + dwconv=False, + rotary_emb_dim=0, + rotary_emb_scale_base=0, + fused_bias_fc=False, + use_flash_attn=False, + return_residual=False, + checkpointing=False, + device=None, + dtype=None, + sp_kwargs=None, + ) -> None: + """ + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + assert sp_kwargs == None, "sparse predictor not support in MHA" + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.cross_attn = cross_attn + self.causal = causal + self.layer_idx = layer_idx + self.dwconv = dwconv + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.return_residual = return_residual + self.checkpointing = checkpointing + + self.num_heads = num_heads + assert ( + self.embed_dim % num_heads == 0 + ), "self.kdim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + + if self.rotary_emb_dim > 0: + assert ( + not cross_attn + ), "MHA with rotary embedding does not support cross-attention yet" + assert RotaryEmbedding is not None, "rotary_emb is not installed" + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device + ) + + if fused_bias_fc and FusedDense is None: + raise ImportError("fused_dense is not installed") + linear_cls = nn.Linear if not fused_bias_fc else FusedDense + linear_resid_cls = ( + LinearResidual + if not fused_bias_fc + else partial(FusedDense, return_residual=True) + ) + inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + if not self.cross_attn: + if not self.return_residual: + self.Wqkv = linear_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) + else: + self.Wqkv = linear_resid_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) + if self.dwconv: + self.dwconv_qkv = nn.Conv1d( + 3 * embed_dim, + 3 * embed_dim, + kernel_size=3, + padding=2, + groups=3 * embed_dim, + ) + else: + self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs) + if not self.return_residual: + self.Wkv = linear_cls( + embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs + ) + else: + self.Wkv = linear_resid_cls( + embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs + ) + if self.dwconv: + self.dwconv_q = nn.Conv1d( + embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim + ) + self.dwconv_kv = nn.Conv1d( + 2 * embed_dim, + 2 * embed_dim, + kernel_size=3, + padding=2, + groups=2 * embed_dim, + ) + self.inner_attn = inner_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + self.inner_cross_attn = inner_cross_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + # output projection always have the bias (for now) + self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs) + + def _update_kv_cache(self, kv, inference_params): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + assert not self.dwconv, "Generation does not support dwconv yet" + assert ( + self.layer_idx is not None + ), "Generation requires layer_idx in the constructor" + return _update_kv_cache(kv, inference_params, self.layer_idx) + + def forward( + self, + x, + x_kv=None, + key_padding_mask=None, + cu_seqlens=None, + max_seqlen=None, + mixer_subset=None, + inference_params=None, + **kwargs + ): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if + cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total + is the is the sum of the sequence lengths in the batch. + x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x. + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into x. Only applicable when using + FlashAttention. + max_seqlen: int. Maximum sequence length in the batch. + key_padding_mask: boolean mask, True means to keep, False means to mask out. + (batch, seqlen). Only applicable when not using FlashAttention. + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. + inference_params: for generation. Adapted from Megatron-LM (and Apex) + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 + """ + if cu_seqlens is not None: + assert max_seqlen is not None + assert key_padding_mask is None + assert self.use_flash_attn + assert not self.dwconv + assert self.rotary_emb_dim == 0 + if key_padding_mask is not None: + assert cu_seqlens is None + assert max_seqlen is None + assert not self.use_flash_attn + if inference_params is not None: + assert key_padding_mask is None + assert cu_seqlens is None and max_seqlen is None + assert not self.dwconv + + kwargs = ( + {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs} + if self.use_flash_attn + else {"key_padding_mask": key_padding_mask, **kwargs} + ) + if not self.cross_attn: + assert x_kv is None and mixer_subset is None + if not self.return_residual: + qkv = self.Wqkv(x) + else: + qkv, x = self.Wqkv(x) + if self.dwconv: + qkv = rearrange( + self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], + "b d s -> b s d", + ).contiguous() + qkv = rearrange( + qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim + ) + if inference_params is None: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv) + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint( + self.inner_attn, qkv, **kwargs + ) + else: + if ( + not inference_params.fused_ft_kernel + ) or inference_params.sequence_len_offset == 0: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=inference_params.sequence_len_offset + ) + q = qkv[:, :, 0] + kv = self._update_kv_cache(qkv[:, :, 1:], inference_params) + # If we're processing the prompt, causal=None (use self.causal). + # If we're decoding, then causal=False. + causal = ( + None if inference_params.sequence_len_offset == 0 else False + ) + context = self.inner_cross_attn(q, kv, causal=causal) + else: + assert inference_params.fused_ft_kernel + assert ft_attention is not None + context = ft_attention.single_query_attention( + *rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1), + *inference_params.key_value_memory_dict[self.layer_idx], + inference_params.lengths_per_sample, + inference_params.sequence_len_offset, + self.rotary_emb_dim, + ) + context = rearrange(context, "b h d -> b 1 h d") + else: + if not self.return_residual: + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + kv = self.Wkv(x_kv if x_kv is not None else x) + else: + if x_kv is not None: + kv, x_kv = self.Wkv(x_kv) + else: + kv, x = self.Wkv(x) + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) + kv = rearrange(kv, "... (two h d) -> ... two h d", two=2, d=self.head_dim) + if self.dwconv: + q = rearrange( + self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], + "b d s -> b s d", + ).contiguous() + kv = rearrange( + self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], + "b d s -> b s d", + ).contiguous() + if inference_params is None: + if not self.checkpointing: + context = self.inner_cross_attn(q, kv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint( + self.inner_cross_attn, q, kv, **kwargs + ) + else: + kv = self._update_kv_cache(kv) + context = self.inner_cross_attn(q, kv, causal=False) + out = self.out_proj(rearrange(context, "... h d -> ... (h d)")) + return out if not self.return_residual else (out, x) + + +class ParallelMHA(nn.Module): + """Multi-head self-attention and cross-attention""" + + def __init__( + self, + embed_dim, + num_heads, + process_group, + bias=True, + dropout=0.0, + softmax_scale=None, + causal=False, + layer_idx=None, + rotary_emb_dim=0, + rotary_emb_scale_base=0, + use_flash_attn=False, + checkpointing=False, + sequence_parallel=True, + device=None, + dtype=None, + sp_kwargs=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + assert sp_kwargs == None, "sparse predictor not support in ParallelMHA" + self.embed_dim = embed_dim + self.causal = causal + self.layer_idx = layer_idx + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.checkpointing = checkpointing + + self.num_heads = num_heads + assert ( + self.embed_dim % num_heads == 0 + ), "self.kdim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + + if self.rotary_emb_dim > 0: + assert RotaryEmbedding is not None, "rotary_emb is not installed" + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device + ) + + if ColumnParallelLinear is None or RowParallelLinear is None: + raise ImportError("fused_dense is not installed") + self.Wqkv = ColumnParallelLinear( + embed_dim, + 3 * embed_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + self.inner_attn = inner_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + self.inner_cross_attn = inner_cross_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + # output projection always have the bias (for now) + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + process_group, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + self.num_active_heads = self.num_heads // process_group.size() + self.process_group = process_group + + def forward(self, x, seqlen=None, inference_params=None, **kwargs): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + """ + if ( + inference_params is not None + and inference_params.sequence_len_offset > 0 + and self.num_active_heads + != self.num_heads // self.Wqkv.process_group.size() + ): + active_dim = self.num_active_heads * self.head_dim + # from flash_attn.utils.distributed import all_reduce + # return all_reduce(x, self.out_proj.process_group) + # return x + qkv = F.linear( + x, self.Wqkv.weight[: 3 * active_dim], self.Wqkv.bias[: 3 * active_dim] + ) + else: + qkv = self.Wqkv(x) + if seqlen is None: + qkv = rearrange( + qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim + ) + else: + qkv = rearrange( + qkv, + "(b s) (three h d) -> b s three h d", + s=seqlen, + three=3, + d=self.head_dim, + ) + if inference_params is None: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv) + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint( + self.inner_attn, qkv, **kwargs + ) + else: + if ( + not inference_params.fused_ft_kernel + ) or inference_params.sequence_len_offset == 0: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=inference_params.sequence_len_offset + ) + q = qkv[:, :, 0] + assert ( + self.layer_idx is not None + ), "Generation requires layer_idx in the constructor" + kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx) + # If we're processing the prompt, causal=None (use self.causal). + # If we're decoding, then causal=False. + causal = None if inference_params.sequence_len_offset == 0 else False + context = self.inner_cross_attn(q, kv, causal=causal) + else: + assert inference_params.fused_ft_kernel + assert ft_attention is not None + k_cache, v_cache = inference_params.key_value_memory_dict[ + self.layer_idx + ] + if ( + self.num_active_heads + != self.num_heads // self.Wqkv.process_group.size() + ): + k_cache = k_cache[:, : self.num_active_heads] + v_cache = v_cache[:, : self.num_active_heads] + context = ft_attention.single_query_attention( + *rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1), + # *inference_params.key_value_memory_dict[self.layer_idx], + k_cache, + v_cache, + inference_params.lengths_per_sample, + inference_params.sequence_len_offset, + self.rotary_emb_dim, + ) + context = rearrange(context, "b h d -> b 1 h d") + if seqlen is None: + context = rearrange(context, "b s h d -> b s (h d)") + else: + context = rearrange(context, "b s h d -> (b s) (h d)") + if ( + inference_params is not None + and inference_params.sequence_len_offset > 0 + and self.num_active_heads + != self.num_heads // self.Wqkv.process_group.size() + ): + active_dim = self.num_active_heads * self.head_dim + out = F.linear( + context, self.out_proj.weight[:, :active_dim], self.out_proj.bias + ) + # Emma added: have to do all reduce after out projection + return all_reduce(out, self.process_group) + else: + out = self.out_proj(context) + return out + + +class ParallelMHADejavu(nn.Module): + """Multi-head self-attention and cross-attention""" + + def __init__( + self, + embed_dim, + num_heads, + process_group, + sp_kwargs=None, + bias=True, + dropout=0.0, + softmax_scale=None, + causal=False, + layer_idx=None, + rotary_emb_dim=0, + rotary_emb_scale_base=0, + use_flash_attn=False, + checkpointing=False, + sequence_parallel=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.causal = causal + self.layer_idx = layer_idx + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.checkpointing = checkpointing + + self.num_heads = num_heads + assert ( + self.embed_dim % num_heads == 0 + ), "self.kdim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + + if self.rotary_emb_dim > 0: + assert RotaryEmbedding is not None, "rotary_emb is not installed" + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device + ) + + if ColumnParallelLinear is None or RowParallelLinear is None: + raise ImportError("fused_dense is not installed") + self.num_head_per_node = self.num_heads // process_group.size() + self.Wqkv = ColumnParallelLinear( + embed_dim, + 3 * embed_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + self.inner_attn = inner_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + self.inner_cross_attn = inner_cross_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + # output projection always have the bias (for now) + self.out_proj = RowParallelLinearNoReduce( + embed_dim, + embed_dim, + process_group, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + self.process_group = process_group + + self.sp_stream = torch.cuda.Stream(device="cuda", priority=0) + self.event_out = torch.cuda.Event(enable_timing=False, blocking=False) + self.event_mlp_sp = torch.cuda.Event(enable_timing=False, blocking=False) + + def forward( + self, x, mlp_sp_logit=None, seqlen=None, inference_params=None, **kwargs + ): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + mlp_sp_logit: (b, 4*hidden_dim), calculate topk neuron to activate for MLP + head_idx: (b, k), k is the number of selected heads. + """ + curr_stream = torch.cuda.current_stream() + qkv = self.Wqkv(x) + + if seqlen is None: + qkv = rearrange( + qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim + ) + else: + qkv = rearrange( + qkv, + "(b s) (three h d) -> b s three h d", + s=seqlen, + three=3, + d=self.head_dim, + ) + if inference_params is None: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv) + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint( + self.inner_attn, qkv, **kwargs + ) + else: + if ( + not inference_params.fused_ft_kernel + ) or inference_params.sequence_len_offset == 0: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=inference_params.sequence_len_offset + ) + q = qkv[:, :, 0] + assert ( + self.layer_idx is not None + ), "Generation requires layer_idx in the constructor" + kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx) + # If we're processing the prompt, causal=None (use self.causal). + # If we're decoding, then causal=False. + causal = None if inference_params.sequence_len_offset == 0 else False + context = self.inner_cross_attn(q, kv, causal=causal) + else: + assert inference_params.fused_ft_kernel + assert ft_attention is not None + k_cache, v_cache = inference_params.key_value_memory_dict[ + self.layer_idx + ] + context = ft_attention.single_query_attention( + *rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1), + # *inference_params.key_value_memory_dict[self.layer_idx], + k_cache, + v_cache, + inference_params.lengths_per_sample, + inference_params.sequence_len_offset, + self.rotary_emb_dim, + ) + context = rearrange(context, "b h d -> b 1 h d") + if seqlen is None: + context = rearrange(context, "b s h d -> b s (h d)") + else: + context = rearrange(context, "b s h d -> (b s) (h d)") + + out = self.out_proj(context) + curr_stream.record_event(self.event_out) + + out = all_reduce(out, self.process_group) + + mlp_idx = None + with torch.cuda.stream(self.sp_stream): + self.sp_stream.wait_event(self.event_out) + if mlp_sp_logit != None: + _, mlp_idx = mlp_sp_logit.topk(self.mlp_k, sorted=False) + + self.sp_stream.record_event(self.event_mlp_sp) + + curr_stream.wait_event(self.event_mlp_sp) + + return out, mlp_idx + + +class ParallelSP(nn.Module): + """ + A Near Neighbor Classifier + """ + + def __init__( + self, + layer_idx=None, + embed_dim=None, + low_rank_dim=None, + out_dim=None, + K=None, + process_group=None, + sequence_parallel=False, + device=None, + dtype=None, + ) -> None: + super().__init__() + assert ( + process_group is not None + ), "sparse predictor only implemented with parallel for now" + + factory_kwargs = {"device": device, "dtype": dtype} + self.process_group = process_group + self.layer_idx = layer_idx + self.embed_dim = embed_dim + self.fc0 = nn.Linear( + embed_dim, low_rank_dim, bias=False, device=device, dtype=dtype + ) + self.out = out_dim // self.process_group.size() + self.fc1 = ColumnParallelLinear( + low_rank_dim, + out_dim, + process_group, + bias=False, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + self.K = K // self.process_group.size() + + def forward(self, x): + x = self.fc0(x.view(self.embed_dim)) # b x 1000 + x = self.fc1(x) + return x + + +class ParallelTracker(nn.Module): + def __init__(self, process_group, num_head_per_node, seq_len, device=None) -> None: + self.process_group = process_group + self.num_head_per_node = num_head_per_node + self.tracker = torch.arange( + 0, seq_len, dtype=torch.int32, device=device + ).repeat(self.num_head_per_node, 1) + super().__init__() + + def get_batch_idx(self, head_idx, seq_idx): + self.tracker = self.tracker.to(head_idx.device) + return self.tracker[head_idx][:, :seq_idx] + + def update(self, head_idx, seq_idx, compute_idx): + # mark all compute position as -1 + + temp = self.tracker[:, : seq_idx + 1][head_idx] + temp[compute_idx != -1] = -1 + self.tracker[:, : seq_idx + 1][head_idx] = temp + + # self.tracker[:, seq_idx][head_idx] = -1 + + def reset(self, device): + self.tracker = torch.arange(0, 16, dtype=torch.int32, device=device).repeat( + self.num_head_per_node, 1 + ) + + +class ParallelMHASparseAttMlp(nn.Module): + """Multi-head self-attention and cross-attention""" + + def __init__( + self, + embed_dim, + num_heads, + process_group, + sp_kwargs, + bias=True, + dropout=0.0, + softmax_scale=None, + causal=False, + layer_idx=None, + rotary_emb_dim=0, + rotary_emb_scale_base=0, + use_flash_attn=False, + checkpointing=False, + sequence_parallel=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + assert sp_kwargs != None, "sparse predictor parameters are not passed in." + self.embed_dim = embed_dim + self.causal = causal + self.layer_idx = layer_idx + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.checkpointing = checkpointing + self.sp_kwargs = sp_kwargs + self.num_heads = num_heads + assert ( + self.embed_dim % num_heads == 0 + ), "self.kdim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + + if self.rotary_emb_dim > 0: + assert RotaryEmbedding is not None, "rotary_emb is not installed" + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device + ) + + if ColumnParallelLinear is None or RowParallelLinear is None: + raise ImportError("fused_dense is not installed") + self.num_head_per_node = self.num_heads // process_group.size() + self.Wqkv = ColumnParallelLinear( + embed_dim, + 3 * embed_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + self.inner_attn = inner_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + self.inner_cross_attn = inner_cross_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + # output projection always have the bias (for now) + self.out_proj = RowParallelLinearNoReduce( + embed_dim, + embed_dim, + process_group, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + self.process_group = process_group + + """sparse related""" + + # sparse predictor related + ( + self.head_idx_qkv_idx_mapping, + self.head_idx_out_idx_mapping, + ) = self.generate_idx_mapping(device) + + self.out_proj_weight_t = self.register_buffer("out_proj_weight_t", None) + self.sp = ParallelSP( + layer_idx=layer_idx, + device=device, + dtype=dtype, + process_group=process_group, + sequence_parallel=sequence_parallel, + **sp_kwargs, + ) + # cache to store previous x for up to past 16 tokens to recompute KV for skiped head + self.x_cache = torch.empty((16, embed_dim), dtype=torch.float16, device=device) + + self.to_compute_token_idx = ParallelTracker( + process_group=process_group, + num_head_per_node=self.num_head_per_node, + seq_len=16, + ) + # generation counter used to clear x_cache + self.counter = 0 + # generation counter used to copy kv cache + self.cache_offset = None + + self.sp_stream = torch.cuda.Stream(device="cuda", priority=0) + self.event_qkv = torch.cuda.Event(enable_timing=False, blocking=False) + self.event_out = torch.cuda.Event(enable_timing=False, blocking=False) + self.event_mlp_sp = torch.cuda.Event(enable_timing=False, blocking=False) + self.event_att_sp = torch.cuda.Event(enable_timing=False, blocking=False) + + def generate_idx_mapping(self, device): + # assert == 128 * (96/8) * 3, 12288 + q_idx = torch.arange( + 0, + self.embed_dim // torch.distributed.get_world_size(), + dtype=torch.long, + device=device, + ).reshape(-1, self.head_dim) + k_idx = q_idx.clone() + 1 * ( + self.embed_dim // torch.distributed.get_world_size() + ) + + v_idx = q_idx.clone() + 2 * ( + self.embed_dim // torch.distributed.get_world_size() + ) + + out_idx = q_idx.clone() + + qkv_idx = torch.cat((q_idx, k_idx, v_idx), dim=1) + + return qkv_idx, out_idx + + def forward( + self, + x, + residual, + head_idx=None, + mlp_sp_logit=None, + seqlen=None, + inference_params=None, + **kwargs + ): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + mlp_sp_logit: (b, 4*hidden_dim), calculate topk neuron to activate for MLP + head_idx: (b, k), k is the number of selected heads. + """ + do_token_generation = x.size(1) == 1 + + # prompting sequence length + if x.size(1) != 1: + self.cache_offset = torch.tensor([x.size(1)], dtype=torch.int32).to( + x.device + ) + + curr_stream = torch.cuda.current_stream() + + if ( + inference_params is not None + and inference_params.sequence_len_offset > 0 + and head_idx != None + ): + assert x.size(1) == 1 + self.x_cache[self.counter] = x + # get the token index to compute + if self.counter == 15: + # computer key value for past 16 tokens to clear x cache + self.to_compute_token_idx.reset(device=x.device) + head_idx = torch.arange( + 0, self.num_head_per_node, dtype=torch.int32, device=x.device + ) + + head_recompute_token_idx = self.to_compute_token_idx.get_batch_idx( + head_idx, self.counter + 1 + ) + + from src.ops.triton.attention_proj_sparse import qkv_proj_sparse + + qkv = qkv_proj_sparse( + self.x_cache[: self.counter + 1], + rearrange( + self.Wqkv.weight, + "(three n m) d -> three n m d", + three=3, + n=self.num_head_per_node, + m=self.head_dim, + d=self.embed_dim, + ), + head_idx, + head_recompute_token_idx, + rearrange( + self.Wqkv.bias, + "(three n m) -> three n m", + three=3, + n=self.num_head_per_node, + m=self.head_dim, + ), + ) + + # update trackers + if self.counter != 15: + self.to_compute_token_idx.update( + head_idx, self.counter, head_recompute_token_idx + ) + self.counter += 1 + else: + qkv = self.Wqkv(x) + curr_stream.record_event(self.event_qkv) + + # mlp sp topk + mlp_idx = None + with torch.cuda.stream(self.sp_stream): + self.sp_stream.wait_event(self.event_qkv) + if mlp_sp_logit != None: + _, mlp_idx = mlp_sp_logit.topk(self.mlp_k, sorted=False) + self.sp_stream.record_event(self.event_mlp_sp) + + if seqlen is None and head_idx is None: + qkv = rearrange( + qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim + ) + elif seqlen is not None: + qkv = rearrange( + qkv, + "(b s) (three h d) -> b s three h d", + s=seqlen, + three=3, + d=self.head_dim, + ) + + if inference_params is None: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv) + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint( + self.inner_attn, qkv, **kwargs + ) + else: + if ( + not inference_params.fused_ft_kernel + ) or inference_params.sequence_len_offset == 0: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=inference_params.sequence_len_offset + ) + q = qkv[:, :, 0] + assert ( + self.layer_idx is not None + ), "Generation requires layer_idx in the constructor" + kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx) + # If we're processing the prompt, causal=None (use self.causal). + # If we're decoding, then causal=False. + causal = None if inference_params.sequence_len_offset == 0 else False + context = self.inner_cross_attn(q, kv, causal=causal) + else: + assert inference_params.fused_ft_kernel + assert ft_attention is not None + k_cache, v_cache = inference_params.key_value_memory_dict[ + self.layer_idx + ] + + from src.ops.triton.attention_proj_sparse import k_cache_copy_sparse + from src.ops.triton.attention_proj_sparse import v_cache_copy_sparse + + # copy calculate KV to kvcahce + if head_idx != None: + assert x.size(1) == 1 + # qkv: b, 3, h d + k_cache_copy_sparse( + qkv[:, 1], + k_cache, + head_idx, + head_recompute_token_idx, + self.cache_offset, + ) + v_cache_copy_sparse( + qkv[:, 2], + v_cache, + head_idx, + head_recompute_token_idx, + self.cache_offset, + ) + context = ft_attention.single_query_attention( + *qkv[-1:, :, :].unbind(dim=1), + k_cache, + v_cache, + inference_params.lengths_per_sample, + inference_params.sequence_len_offset, + self.rotary_emb_dim, + ) + else: + context = ft_attention.single_query_attention( + *rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1), + # *inference_params.key_value_memory_dict[self.layer_idx], + k_cache, + v_cache, + inference_params.lengths_per_sample, + inference_params.sequence_len_offset, + self.rotary_emb_dim, + ) + context = rearrange(context, "b h d -> b 1 h d") + if seqlen is None: + if head_idx == None: + context = rearrange(context, "b s h d -> b s (h d)") + else: + context = rearrange(context, "b s h d -> (b s) (h d)") + + if ( + inference_params is not None + and inference_params.sequence_len_offset > 0 + and head_idx != None + ): + assert context.size(0) == 1 + from src.ops.triton.attention_proj_sparse import out_proj_sparse + + out = out_proj_sparse( + context, + rearrange( + self.out_proj.weight, + "d (n m) -> d n m", + n=self.num_head_per_node, + m=self.head_dim, + ), + head_idx, + self.out_proj.bias, + ) + + # clear token cache and tracker + if self.counter == 16: + self.counter = 0 + self.x_cache = torch.empty( + (16, self.embed_dim), + device=x.device, + dtype=torch.float16, + ) + self.cache_offset += 16 + else: + out = self.out_proj(context) + + curr_stream.record_event(self.event_out) + + out = all_reduce(out, self.process_group) + + # parallel attention sparse prediction with Attention AllReduce + att_idx = None + with torch.cuda.stream(self.sp_stream): + self.sp_stream.wait_event(self.event_out) + if do_token_generation: + att_sp_logit = self.sp(residual) + _, att_idx = att_sp_logit.topk(self.att_k, sorted=False) + att_idx = att_idx.to(torch.int32) + self.sp_stream.record_event(self.event_att_sp) + curr_stream.wait_event(self.event_mlp_sp) + curr_stream.wait_event(self.event_att_sp) + return out, mlp_idx, att_idx diff --git a/HybridTensor/routers/README.md b/HybridTensor/routers/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ba16fa1649259c7cd73510fa49d1c6e7bd3375b9 --- /dev/null +++ b/HybridTensor/routers/README.md @@ -0,0 +1,46 @@ +# Train activation routers for MLP and MHA layers + +## MLP + +Update the bash file in HybridTensor/routers/mlp/train_mlp_routers.sh to include details about the job + +```bash +./HybridTensor/routers/mlp/train_mlp_routers.sh +``` + + +## MLP Router Optimization + + +A5000 +```bash +python -m HybridTensor.routers.mlp.mlp_router_optim_fast --model_index 5 \ +--batch_size 32 --delta 200 +``` + +Single batch optimization + +```bash +python -m HybridTensor.routers.mlp.mlp_router_optim --model_index 8 \ +--stats_dir results/mlp_results/batch_activations/opt-66b/ \ +--mlp_ckpt_dir /pscratch/sd/s//HybridTensor/checkpoint/opt-66b-routers/mlp/ \ +--act_data_dir /pscratch/sd/s//HybridTensor/opt-66b_act_data/ +``` + +Batched optimization + +```bash +python -m HybridTensor.routers.mlp.mlp_router_optim_fast --model_index 8 \ +--stats_dir results/mlp_results/batch_activations/opt-66b/ \ +--mlp_ckpt_dir /pscratch/sd/s//HybridTensor/checkpoint/opt-66b-routers/mlp/ \ +--act_data_dir /pscratch/sd/s//HybridTensor/opt-66b_act_data/ \ +--batch_size 2 --delta 200 --device 2 +``` + +## MHA + +Update the bash file in HybridTensor/routers/mha/train_mha_routers.sh + +```bash +./HybridTensor/routers/mha/train_mha_routers_topk.sh +``` diff --git a/HybridTensor/routers/__pycache__/router_utils.cpython-39.pyc b/HybridTensor/routers/__pycache__/router_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..794aa4e5cea66a591b4c55abe5d6093346c8a730 Binary files /dev/null and b/HybridTensor/routers/__pycache__/router_utils.cpython-39.pyc differ diff --git a/HybridTensor/routers/archive/fused_router/fused_utils.py b/HybridTensor/routers/archive/fused_router/fused_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..494ebc2f681d1f6f55f12a42c168779aad6fc51d --- /dev/null +++ b/HybridTensor/routers/archive/fused_router/fused_utils.py @@ -0,0 +1,162 @@ +import numpy as np +import os +import json +import logging + +import torch +from torch.utils.data import Dataset, DataLoader +from sklearn.model_selection import train_test_split + +# router training functions + +def get_layer_data(data_dir, layer_idx, total_samples=None): + """ + Load hidden_states, attn_norms, and mlp_activations data for a specific layer. + + Args: + data_dir (str): Directory where the data is stored. + layer_idx (int): The index of the layer to load data for. + total_samples (int, optional): The number of samples to load. If None, load all samples. + + Returns: + tuple: A tuple containing: + - hidden_states (np.ndarray): Array of shape (num_samples, hidden_size) + - attn_norms (np.ndarray): Array of shape (num_samples, num_heads) + - mlp_activations (np.ndarray): Array of shape (num_samples, hidden_size * 4) + """ + # Load metadata + metadata_path = os.path.join(data_dir, 'metadata.json') + if not os.path.exists(metadata_path): + raise FileNotFoundError(f"Metadata file not found at {metadata_path}") + + with open(metadata_path, 'r') as f: + metadata = json.load(f) + + num_layers = metadata['num_layers'] + hidden_size = metadata['hidden_size'] + num_heads = metadata['num_heads'] + max_samples = metadata['max_samples'] + + # Validate layer index + if layer_idx < 0 or layer_idx >= num_layers: + raise ValueError(f"Invalid layer_idx: {layer_idx}. Must be between 0 and {num_layers - 1}.") + + # Determine number of samples to load + if total_samples is not None: + if total_samples <= 0: + raise ValueError(f"total_samples must be a positive integer, got {total_samples}") + num_samples = min(total_samples, max_samples) + else: + num_samples = max_samples + + # File names and feature sizes + hidden_states_feature_size = hidden_size + hidden_states_filename = f"hidden_states_layer_{layer_idx}.dat" + + attn_norms_feature_size = num_heads + attn_norms_filename = f"attn_norms_layer_{layer_idx}.dat" + + mlp_activations_feature_size = hidden_size * 4 + mlp_activations_filename = f"mlp_activations_layer_{layer_idx}.dat" + + # Paths to the mmap files + hidden_states_path = os.path.join(data_dir, hidden_states_filename) + attn_norms_path = os.path.join(data_dir, attn_norms_filename) + mlp_activations_path = os.path.join(data_dir, mlp_activations_filename) + + # Load hidden_states data + if not os.path.exists(hidden_states_path): + raise FileNotFoundError(f"Hidden states file not found at {hidden_states_path}") + + logging.info(f"Reading hidden_states from {hidden_states_path}") + + hidden_states_mmap = np.memmap(hidden_states_path, dtype='float16', mode='r', + shape=(max_samples, hidden_states_feature_size)) + hidden_states = np.array(hidden_states_mmap[:num_samples]) + del hidden_states_mmap # Close the memmap + + # Load attn_norms data + if not os.path.exists(attn_norms_path): + raise FileNotFoundError(f"Attn norms file not found at {attn_norms_path}") + + logging.info(f"Reading attn_norms from {attn_norms_path}") + + attn_norms_mmap = np.memmap(attn_norms_path, dtype='float16', mode='r', + shape=(max_samples, attn_norms_feature_size)) + attn_norms = np.array(attn_norms_mmap[:num_samples]) + del attn_norms_mmap # Close the memmap + + # Load mlp_activations data + if not os.path.exists(mlp_activations_path): + raise FileNotFoundError(f"MLP activations file not found at {mlp_activations_path}") + + logging.info(f"Reading mlp_activations from {mlp_activations_path}") + + mlp_activations_mmap = np.memmap(mlp_activations_path, dtype='float16', mode='r', + shape=(max_samples, mlp_activations_feature_size)) + mlp_activations = np.array(mlp_activations_mmap[:num_samples]) + del mlp_activations_mmap # Close the memmap + + logging.info(f"Loaded {num_samples} samples for layer {layer_idx}.") + + return hidden_states, attn_norms, mlp_activations + +def create_labels(mlp_activations, attn_norms, mlp_threshold=0, attn_top_k_ratio=0.3): + # Create neuron labels + neuron_labels = (mlp_activations > mlp_threshold).astype(np.float32) + + # Create head labels + num_heads = attn_norms.shape[1] + top_k = int(num_heads * attn_top_k_ratio) + head_labels = np.zeros_like(attn_norms, dtype=np.float32) + + for i in range(attn_norms.shape[0]): + indices = np.argpartition(-attn_norms[i], top_k)[:top_k] + head_labels[i, indices] = 1.0 + + return neuron_labels, head_labels + +class FusedRouterDataset(Dataset): + def __init__(self, hidden_states, neuron_labels, head_labels): + self.hidden_states = hidden_states + self.neuron_labels = neuron_labels + self.head_labels = head_labels + + def __len__(self): + return len(self.hidden_states) + + def __getitem__(self, idx): + x = self.hidden_states[idx] + neuron_label = self.neuron_labels[idx] + head_label = self.head_labels[idx] + return x, neuron_label, head_label + +def create_dataloaders(hidden_states, neuron_labels, head_labels, batch_size=64): + # Define your split ratio and random seed + validation_split = 0.2 # 20% for validation + random_seed = 11 + + # Convert numpy arrays to tensors + hidden_states_tensor = torch.from_numpy(hidden_states).float() + neuron_labels_tensor = torch.from_numpy(neuron_labels).float() + head_labels_tensor = torch.from_numpy(head_labels).float() + + # Split the data + hidden_states_train, hidden_states_val, neuron_labels_train, neuron_labels_val, head_labels_train, head_labels_val = train_test_split( + hidden_states_tensor, + neuron_labels_tensor, + head_labels_tensor, + test_size=validation_split, + random_state=random_seed, + shuffle=True + ) + + # Create dataset instances + train_dataset = FusedRouterDataset(hidden_states_train, neuron_labels_train, head_labels_train) + val_dataset = FusedRouterDataset(hidden_states_val, neuron_labels_val, head_labels_val) + + # Create dataloaders + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + return train_loader, val_loader \ No newline at end of file diff --git a/HybridTensor/routers/archive/fused_router/router.py b/HybridTensor/routers/archive/fused_router/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6d0024cbdb2ad285daab1eae5ad2296adf3728 --- /dev/null +++ b/HybridTensor/routers/archive/fused_router/router.py @@ -0,0 +1,105 @@ +import torch + +def select_heads(heads, top_k): + _, top_indices = torch.topk(heads, top_k, dim=1) + return top_indices + +def select_neurons(neurons, th): + activated = (neurons > th).sum(dim = 0) + index_vec = activated.nonzero().flatten() + return index_vec + +class FusedRouter(torch.nn.Module): + def __init__(self, embed_dim, num_heads, dim = 1024, mlp_th = 0.5, attn_top_k = 0.3, device='cuda', dtype=torch.float16): + super(FusedRouter, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.device = device + self.dtype = dtype + self.total_neurons = embed_dim * 4 + + self.mlp_th = mlp_th + self.top_k = int(num_heads * attn_top_k) + + self.fc1 = torch.nn.Linear(embed_dim, dim, bias=False, device=device, dtype=dtype) + self.layer_norm1 = torch.nn.LayerNorm(dim, device=device, dtype=dtype) + self.fc2 = torch.nn.Linear(dim, self.total_neurons + num_heads, bias=False, device=device, dtype=dtype) + + def forward(self, x): + out = self.fc1(x) + out = self.layer_norm1(out) + out = self.fc2(out) + neurons = out[:, :self.total_neurons] + heads = out[:, self.total_neurons:] + + return neurons, heads + + def select_neurons_heads(self, x): + neurons, heads = self(x) + selected_neurons = select_neurons(neurons, self.mlp_th) + selected_heads = select_heads(heads, self.top_k) + return selected_neurons, selected_heads + +class HybridFusedRouter(torch.nn.Module): + def __init__(self, embed_dim, num_heads, mlp_dim = 1024, mha_dim = 128, mlp_th = 0.5, attn_top_k = 0.3, device='cuda', dtype=torch.float16): + super(HybridFusedRouter, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.device = device + self.dtype = dtype + self.total_neurons = embed_dim * 4 + self.mlp_dim = mlp_dim + self.mha_dim = mha_dim + + self.mlp_th = mlp_th + self.top_k = int(num_heads * attn_top_k) + + self.fc1 = torch.nn.Linear(embed_dim, mlp_dim + mha_dim, bias=False, device=device, dtype=dtype) + self.layer_norm1 = torch.nn.LayerNorm(mlp_dim, device=device, dtype=dtype) + self.activation = torch.nn.ReLU() + + self.fc2_mlp = torch.nn.Linear(mlp_dim, self.total_neurons, bias=False, device=device, dtype=dtype) + self.fc2_mha = torch.nn.Linear(mha_dim, num_heads, bias=False, device=device, dtype=dtype) + + # cache for neurons + self.mlp_neurons = None + + def forward(self, x): + out = self.fc1(x) + mlp_out, mha_out = out[:, :self.mlp_dim], out[:, self.mlp_dim:] + + # mlp router + neurons = self.layer_norm1(mlp_out) + neurons = self.fc2_mlp(mlp_out) + + # mha router + heads = self.activation(mha_out) + heads = self.fc2_mha(mha_out) + + return neurons, heads + + def select_heads_(self, x): + out = self.fc1(x) + mlp_out, mha_out = out[:, :self.mlp_dim], out[:, self.mlp_dim:] + self.mlp_neurons = mlp_out + + # mha router + heads = self.activation(mha_out) + heads = self.fc2_mha(mha_out) + + selected_heads = select_heads(heads, self.top_k) + return selected_heads + + def select_neurons_(self, x): + neurons = self.layer_norm1(self.mlp_neurons) + neurons = self.fc2_mlp(neurons) + selected_neurons = select_neurons(neurons, self.mlp_th) + return selected_neurons + + + def select_neurons_heads(self, x): + neurons, heads = self(x) + selected_neurons = select_neurons(neurons, self.mlp_th) + selected_heads = select_heads(heads, self.top_k) + return selected_neurons, selected_heads + diff --git a/HybridTensor/routers/archive/run_c4_att.sh b/HybridTensor/routers/archive/run_c4_att.sh new file mode 100644 index 0000000000000000000000000000000000000000..75a6ce364eaf73cbd087c54d9edcdff6a0de5edb --- /dev/null +++ b/HybridTensor/routers/archive/run_c4_att.sh @@ -0,0 +1,13 @@ +for l in $(seq 0 8 64) +do + (trap 'kill 0' SIGINT; \ + CUDA_VISIBLE_DEVICES=0 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L ${l} > logs/c4_att_out_${l}.txt & \ + CUDA_VISIBLE_DEVICES=1 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L $((l+1)) > logs/c4_att_out_$((l+1)).txt & \ + CUDA_VISIBLE_DEVICES=2 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L $((l+2)) > logs/c4_att_out_$((l+2)).txt & \ + CUDA_VISIBLE_DEVICES=3 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L $((l+3)) > logs/c4_att_out_$((l+3)).txt & \ + CUDA_VISIBLE_DEVICES=4 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L $((l+4)) > logs/c4_att_out_$((l+4)).txt & \ + CUDA_VISIBLE_DEVICES=5 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L $((l+5)) > logs/c4_att_out_$((l+5)).txt & \ + CUDA_VISIBLE_DEVICES=6 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L $((l+6)) > logs/c4_att_out_$((l+6)).txt & \ + CUDA_VISIBLE_DEVICES=7 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L $((l+7)) > logs/c4_att_out_$((l+7)).txt & \ + wait) +done \ No newline at end of file diff --git a/HybridTensor/routers/archive/run_c4_att_custom.sh b/HybridTensor/routers/archive/run_c4_att_custom.sh new file mode 100644 index 0000000000000000000000000000000000000000..d898fe9f1c150ec0233f2a1c4b1949690aee0b6e --- /dev/null +++ b/HybridTensor/routers/archive/run_c4_att_custom.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# update this based on the model and the number of layers +total_num_layers=32 + +for l in $(seq 0 4 32) +do + (trap 'kill 0' SIGINT; \ + CUDA_VISIBLE_DEVICES=0 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L ${l} > logs/c4_att_out_${l}.txt & \ + CUDA_VISIBLE_DEVICES=1 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L $((l+1)) > logs/c4_att_out_$((l+1)).txt & \ + CUDA_VISIBLE_DEVICES=0 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L $((l+2)) > logs/c4_att_out_$((l+2)).txt & \ + CUDA_VISIBLE_DEVICES=1 python main_att.py --dataset c4 --lr 0.0001 --k 0.3 --L $((l+3)) > logs/c4_att_out_$((l+3)).txt & \ + + wait) +done \ No newline at end of file diff --git a/HybridTensor/routers/archive/run_c4_mlp.sh b/HybridTensor/routers/archive/run_c4_mlp.sh new file mode 100644 index 0000000000000000000000000000000000000000..654136fa504b74b490f6f17872de118d3173716b --- /dev/null +++ b/HybridTensor/routers/archive/run_c4_mlp.sh @@ -0,0 +1,13 @@ +for l in $(seq 0 8 64) +do + (trap 'kill 0' SIGINT; \ + CUDA_VISIBLE_DEVICES=0 python main_mlp.py --dataset c4 --lr 0.001 --L ${l} > logs/c4_mlp_out_${l}.txt & \ + CUDA_VISIBLE_DEVICES=1 python main_mlp.py --dataset c4 --lr 0.001 --L $((l+1)) > logs/c4_mlp_out_$((l+1)).txt & \ + CUDA_VISIBLE_DEVICES=2 python main_mlp.py --dataset c4 --lr 0.001 --L $((l+2)) > logs/c4_mlp_out_$((l+2)).txt & \ + CUDA_VISIBLE_DEVICES=3 python main_mlp.py --dataset c4 --lr 0.001 --L $((l+3)) > logs/c4_mlp_out_$((l+3)).txt & \ + CUDA_VISIBLE_DEVICES=4 python main_mlp.py --dataset c4 --lr 0.001 --L $((l+4)) > logs/c4_mlp_out_$((l+4)).txt & \ + CUDA_VISIBLE_DEVICES=5 python main_mlp.py --dataset c4 --lr 0.001 --L $((l+5)) > logs/c4_mlp_out_$((l+5)).txt & \ + CUDA_VISIBLE_DEVICES=6 python main_mlp.py --dataset c4 --lr 0.001 --L $((l+6)) > logs/c4_mlp_out_$((l+6)).txt & \ + CUDA_VISIBLE_DEVICES=7 python main_mlp.py --dataset c4 --lr 0.001 --L $((l+7)) > logs/c4_mlp_out_$((l+7)).txt & \ + wait) +done diff --git a/HybridTensor/routers/archive/run_c4_mlp_custom.sh b/HybridTensor/routers/archive/run_c4_mlp_custom.sh new file mode 100644 index 0000000000000000000000000000000000000000..f75ba21923c36d431aa2dc92a8f244ff2b9b96c5 --- /dev/null +++ b/HybridTensor/routers/archive/run_c4_mlp_custom.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +# update this based on the model and the number of layers +total_num_layers=32 + +for l in $(seq 0 2 $total_num_layers) +do + (trap 'kill 0' SIGINT; \ + CUDA_VISIBLE_DEVICES=0 python main_mlp.py --dataset c4 --lr 0.001 --L ${l} > logs/c4_mlp_out_${l}.txt & \ + CUDA_VISIBLE_DEVICES=1 python main_mlp.py --dataset c4 --lr 0.001 --L $((l+1)) > logs/c4_mlp_out_$((l+1)).txt & \ + wait) +done \ No newline at end of file diff --git a/HybridTensor/routers/datacollection/data_collection.py b/HybridTensor/routers/datacollection/data_collection.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc70f9e5325ed16705b0994c031be2dd9ac40a0 --- /dev/null +++ b/HybridTensor/routers/datacollection/data_collection.py @@ -0,0 +1,442 @@ +import torch +import numpy as np +import os + +from hf_models.opt.modeling_opt import OPTForCausalLM +from hf_models.llama.modeling_llama import LlamaForCausalLM +from transformers import AutoTokenizer + +# from tests.activation_tests.eval_sparse_perplexity import OPT_MODELS, build_data_loader +from HybridTensor.benchmarks.opt_attn_sparse_topk_perplexity import build_data_loader +from HybridTensor.utils.utils import extract_model_name + +from datasets import load_dataset +import json + +from tqdm import tqdm +import argparse +from HybridTensor.utils.activations import MODELS + + +def load_layer_data(data_dir, layer_idx, data_type): + """ + Load data for a specific layer and data type. + + Args: + data_dir (str): Directory where data is stored. + layer_idx (int): Layer index. + data_type (str): One of 'hidden_states', 'mlp_activations', 'attn_norms'. + + Returns: + np.ndarray: The data array of shape (num_samples, feature_size). + """ + # Load metadata + metadata_filename = os.path.join(data_dir, 'metadata.json') + with open(metadata_filename, 'r') as f: + metadata = json.load(f) + + num_layers = metadata['num_layers'] + hidden_size = metadata['hidden_size'] + num_heads = metadata['num_heads'] + max_samples = metadata['max_samples'] + + # Validate layer index + if layer_idx < 0 or layer_idx >= num_layers: + raise ValueError(f"Invalid layer_idx: {layer_idx}. Must be between 0 and {num_layers - 1}.") + + # Get the sample count and feature size + if data_type == 'hidden_states': + sample_counts = metadata['hidden_states_counters'] + sample_count = sample_counts[layer_idx] + feature_size = hidden_size + elif data_type == 'mlp_activations': + sample_counts = metadata['mlp_activations_counters'] + sample_count = sample_counts[layer_idx] + feature_size = hidden_size * 4 + elif data_type == 'attn_norms': + sample_counts = metadata['attn_norms_counters'] + sample_count = sample_counts[layer_idx] + feature_size = num_heads + else: + raise ValueError(f"Invalid data_type: {data_type}. Must be 'hidden_states', 'mlp_activations', or 'attn_norms'.") + + # Load the data file + filename = os.path.join(data_dir, f'{data_type}_layer_{layer_idx}.dat') + data_mmap = np.memmap(filename, dtype='float16', mode='r', shape=(max_samples, feature_size)) + + # Return only the valid data + data = np.array(data_mmap[:sample_count]) + del data_mmap # Close the memmap + return data + +def initialize_data_structures(data_dir, num_layers, hidden_size, num_heads, num_neurons, max_samples, + mlp_activation=True, attn_norm=True): + """ + Initialize mmap files and counters for hidden_states, mlp_activations, and attn_norms. + + Args: + data_dir (str): Directory where data is stored. + num_layers (int): Number of transformer layers. + hidden_size (int): Hidden size of the model. + num_heads (int): Number of attention heads. + max_samples (int): Maximum number of samples to collect. + + Returns: + tuple: Contains lists of mmap files and counters for each data type. + """ + # Initialize lists to hold mmap files and counters + hidden_states_files = [] + mlp_activations_files = [] + attn_norms_files = [] + + hidden_states_counters = [] + mlp_activations_counters = [] + attn_norms_counters = [] + + for layer_idx in range(num_layers): + # Hidden states + hs_filename = os.path.join(data_dir, f'hidden_states_layer_{layer_idx}.dat') + hs_file = np.memmap(hs_filename, dtype='float16', mode='w+', shape=(max_samples, hidden_size)) + hidden_states_files.append(hs_file) + hidden_states_counters.append(0) + + # MLP activations + if mlp_activation: + mlp_filename = os.path.join(data_dir, f'mlp_activations_layer_{layer_idx}.dat') + mlp_file = np.memmap(mlp_filename, dtype='float16', mode='w+', shape=(max_samples, num_neurons)) + mlp_activations_files.append(mlp_file) + mlp_activations_counters.append(0) + + # Attention norms + if attn_norm: + attn_filename = os.path.join(data_dir, f'attn_norms_layer_{layer_idx}.dat') + attn_file = np.memmap(attn_filename, dtype='float16', mode='w+', shape=(max_samples, num_heads)) + attn_norms_files.append(attn_file) + attn_norms_counters.append(0) + + return ( + hidden_states_files, + hidden_states_counters, + mlp_activations_files, + mlp_activations_counters, + attn_norms_files, + attn_norms_counters + ) + +def process_hidden_states(layer_idx, hidden_states_layer, valid_token_indices, hidden_size, hidden_states_files, hidden_states_counters): + """ + Process and store hidden states for a specific layer. + """ + hs = hidden_states_layer.view(-1, hidden_size) # Shape: (batch_size * seq_len, hidden_size) + hs_valid = hs[valid_token_indices.cpu()] # Select valid tokens + hs_counter = hidden_states_counters[layer_idx] + hs_file = hidden_states_files[layer_idx] + hs_file[hs_counter:hs_counter+hs_valid.shape[0], :] = hs_valid.cpu().numpy().astype('float16') + hidden_states_counters[layer_idx] += hs_valid.shape[0] + +def process_mlp_activations(layer_idx, mlp_activations_layer, valid_token_indices, hidden_size, mlp_activations_files, mlp_activations_counters): + """ + Process and store MLP activations for a specific layer. + """ + neuron_shape = mlp_activations_layer.shape[-1] + mlp_activations_layer = mlp_activations_layer.view(-1, neuron_shape) # Shape: (batch_size * seq_len, hidden_size) + + mlp_valid = mlp_activations_layer[valid_token_indices.cpu()] # Select valid tokens + mlp_counter = mlp_activations_counters[layer_idx] + mlp_file = mlp_activations_files[layer_idx] + mlp_file[mlp_counter:mlp_counter+mlp_valid.shape[0], :] = mlp_valid.cpu().numpy().astype('float16') + mlp_activations_counters[layer_idx] += mlp_valid.shape[0] + +def process_attn_norms(layer_idx, attn_outputs_layer, valid_token_indices, num_heads, attn_norms_files, attn_norms_counters): + """ + Process and store attention norms for a specific layer. + """ + # attn_outputs_layer has shape (batch_size, seq_len, num_heads, head_dim) + attn = attn_outputs_layer # Shape: (batch_size, seq_len, num_heads, head_dim) + attn_norms = torch.norm(attn, dim=-1) # Shape: (batch_size, seq_len, num_heads) + attn_norms = attn_norms.view(-1, num_heads) # Shape: (batch_size * seq_len, num_heads) + attn_valid = attn_norms[valid_token_indices.cpu()] # Select valid tokens + attn_counter = attn_norms_counters[layer_idx] + attn_file = attn_norms_files[layer_idx] + attn_file[attn_counter:attn_counter+attn_valid.shape[0], :] = attn_valid.cpu().numpy().astype('float16') + attn_norms_counters[layer_idx] += attn_valid.shape[0] + +def process_batch( + outputs, + input_ids, + attention_mask, + total_samples, + max_samples, + num_layers, + hidden_size, + num_heads, + hidden_states_files, + hidden_states_counters, + mlp_activations_files, + mlp_activations_counters, + attn_norms_files, + attn_norms_counters, + args +): + """ + Process a batch of model outputs and update the data files. + + Returns: + total_samples (int): Updated total number of samples processed. + reached_max_samples (bool): Indicates if the maximum number of samples has been reached. + """ + # Extract the outputs + # hidden_states = outputs['hidden_states'][:-1] # Ignore the last hidden state + + # use router_inputs instead of hidden_states + hidden_states = outputs['router_inputs'] + + if args.mlp_activation: + mlp_activations = outputs['mlp_activations'] # Tuple of MLP activations + else: + mlp_activations = None + + if args.attn_norm: + attn_outputs = outputs['attn_outputs'] # Tuple of attention outputs + else: + attn_outputs = None + + batch_size, seq_len = input_ids.shape + + # Flatten attention_mask to (batch_size * seq_len) + attention_mask_flat = attention_mask.view(-1).bool() + num_valid_tokens = attention_mask_flat.sum().item() + + # Determine tokens to process + if total_samples + num_valid_tokens >= max_samples: + tokens_to_process = max_samples - total_samples + total_samples = max_samples + reached_max_samples = True + else: + tokens_to_process = num_valid_tokens + total_samples += num_valid_tokens + reached_max_samples = False + + # Find indices of valid tokens + valid_token_indices = attention_mask_flat.nonzero(as_tuple=False).view(-1) + # If tokens_to_process < num_valid_tokens, limit the indices + valid_token_indices = valid_token_indices[:tokens_to_process] + + for layer_idx in range(num_layers): + # Process hidden states + process_hidden_states( + layer_idx, + hidden_states[layer_idx], + valid_token_indices, + hidden_size, + hidden_states_files, + hidden_states_counters + ) + if args.mlp_activation: + # Process MLP activations + process_mlp_activations( + layer_idx, + mlp_activations[layer_idx], + valid_token_indices, + hidden_size, + mlp_activations_files, + mlp_activations_counters + ) + + if args.attn_norm: + # Process attention norms + process_attn_norms( + layer_idx, + attn_outputs[layer_idx], + valid_token_indices, + num_heads, + attn_norms_files, + attn_norms_counters + ) + + return total_samples, reached_max_samples, num_valid_tokens + +def finalize_data_collection( + data_dir, + num_layers, + hidden_size, + num_heads, + max_samples, + hidden_states_files, + mlp_activations_files, + attn_norms_files, + hidden_states_counters, + mlp_activations_counters, + attn_norms_counters, + args +): + """ + Finalize the data collection by flushing and closing mmap files and saving metadata. + + Args: + data_dir (str): Directory where data is stored. + num_layers (int): Number of transformer layers. + hidden_size (int): Hidden size of the model. + num_heads (int): Number of attention heads. + max_samples (int): Maximum number of samples to collect. + hidden_states_files (list): List of mmap files for hidden states. + mlp_activations_files (list): List of mmap files for MLP activations. + attn_norms_files (list): List of mmap files for attention norms. + hidden_states_counters (list): List of counters for hidden states. + mlp_activations_counters (list): List of counters for MLP activations. + attn_norms_counters (list): List of counters for attention norms. + """ + + for layer_idx in range(num_layers): + # Hidden states + hs_file = hidden_states_files[layer_idx] + hs_file.flush() + del hs_file # Close the memmap + + if args.mlp_activation: + # MLP activations + mlp_file = mlp_activations_files[layer_idx] + mlp_file.flush() + del mlp_file # Close the memmap + + if args.attn_norm: + # Attention norms + attn_file = attn_norms_files[layer_idx] + attn_file.flush() + del attn_file # Close the memmap + + # Prepare metadata + metadata = { + 'num_layers': num_layers, + 'hidden_size': hidden_size, + 'num_heads': num_heads, + 'max_samples': max_samples, + 'hidden_states_counters': hidden_states_counters, + 'mlp_activations_counters': mlp_activations_counters, + 'attn_norms_counters': attn_norms_counters + } + + # Save metadata to a JSON file + metadata_filename = os.path.join(data_dir, 'metadata.json') + with open(metadata_filename, 'w') as f: + json.dump(metadata, f) + + print("Finalization complete. Metadata saved.") + +def arg_parser(): + parser = argparse.ArgumentParser(description='Sparse Perplexity Evaluation') + parser.add_argument('--model_index', type=int, default=5, help='Index of the model to evaluate') + parser.add_argument('--batch_size', type=int, default=4, help='Batch size for evaluation') + parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length') + parser.add_argument('--data_collection', type=bool, default=False, help='Collect data for different activation thresholds') + parser.add_argument('--device_map', type=str, default='cuda:0', help='Device to use for evaluation') + parser.add_argument('--interactive', type=bool, default=False, help='Interactive mode for model selection') + parser.add_argument('--data_dir', type=str, default='', help='Directory to store generated data') + parser.add_argument('--max_samples', type=int, default=5000, help='Maximum number of samples to collect') + parser.add_argument('--model_family', type=str, default='opt', choices= ["opt", "llama"], help='Model family to evaluate') + parser.add_argument('--mlp_activation', type=bool, default=False, help='Collect MLP activations') + parser.add_argument('--attn_norm', type=bool, default=True, help='Collect attention norms') + + return parser.parse_args() + +if __name__ =="__main__": + args = arg_parser() + model_name = MODELS[args.model_index-1] + batch_size = args.batch_size + max_length = args.max_length + data_collection = args.data_collection + device_map = args.device_map + + # Load the model + if args.model_family == 'opt': + model = OPTForCausalLM.from_pretrained( + model_name, device_map=device_map, torch_dtype=torch.float16, + attn_implementation="flash_attention_2", output_hidden_states=True, output_attentions=True, + return_dict=True + ) + num_neurons = model.config.ffn_dim + + elif args.model_family == 'llama': + model = LlamaForCausalLM.from_pretrained( + model_name, device_map=device_map, torch_dtype=torch.float16, + attn_implementation="flash_attention_2", output_hidden_states=True, output_attentions=True, + return_dict=True + ) + num_neurons = model.config.intermediate_size + + data_loader = build_data_loader( + model_name, "wikitext", "wikitext-2-raw-v1", batch_size, max_length, split='train' + ) + if args.device_map == "auto": + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + device = torch.device(device_map if torch.cuda.is_available() else 'cpu') + + # Ensure data directory exists + model_name_clean = extract_model_name(model_name) + folder_name = f"{model_name_clean}_act_data" + data_dir = os.path.join(args.data_dir, folder_name) + + if not os.path.exists(data_dir): + os.makedirs(data_dir) + + num_layers = model.config.num_hidden_layers + hidden_size = model.config.hidden_size + num_heads = model.config.num_attention_heads + max_samples = args.max_samples + + # print if we are using mlp activations and attn norms + print(f"Collecting data for model: {model_name}") + print(f"Data directory: {data_dir}") + print(f"Number of layers: {num_layers}") + print(f"Hidden size: {hidden_size}") + print(f"Number of heads: {num_heads}") + print(f"Number of neurons: {num_neurons}") + print(f"Max samples: {max_samples}") + print(f"Collecting MLP activations: {args.mlp_activation}") + print(f"Collecting attention norms: {args.attn_norm}") + + # Initialize data structures using the function + (hidden_states_files, hidden_states_counters, mlp_activations_files, + mlp_activations_counters, attn_norms_files, attn_norms_counters) = initialize_data_structures(data_dir, num_layers, + hidden_size, num_heads, num_neurons, max_samples, + mlp_activation=args.mlp_activation, attn_norm=args.attn_norm) + + total_samples = 0 + # Data collection loop + with torch.no_grad(): + with tqdm(total=max_samples, desc="Router training data collection") as pbar: + for batch in data_loader: + input_ids = batch['input_ids'].to(device) + attention_mask = batch['attention_mask'].to(device) + + outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, + output_attentions=False, return_dict=True, output_mlp_activation=args.mlp_activation, + output_attn_output=args.attn_norm, output_router_inputs=True) + + # Process the batch outputs + total_samples, reached_max_samples, num_valid_tokens = process_batch( + outputs, input_ids, attention_mask, + total_samples, max_samples, num_layers, + hidden_size, num_heads, hidden_states_files, + hidden_states_counters, mlp_activations_files, mlp_activations_counters, + attn_norms_files, attn_norms_counters, + args=args) + + pbar.update(num_valid_tokens) + + if reached_max_samples: + break + + # Finalization: Flush and clean up using the new function + finalize_data_collection( + data_dir, num_layers, hidden_size, + num_heads, max_samples, hidden_states_files, + mlp_activations_files, attn_norms_files, hidden_states_counters, + mlp_activations_counters, attn_norms_counters, + args + ) + + if reached_max_samples: + print(f"Reached maximum number of samples. total_samples = {total_samples}") + print(f"Data collection complete. Data saved to {data_dir}") \ No newline at end of file diff --git a/HybridTensor/routers/mha/main_att.py b/HybridTensor/routers/mha/main_att.py new file mode 100644 index 0000000000000000000000000000000000000000..167575949a77c4e48568c9510b018ba67977d848 --- /dev/null +++ b/HybridTensor/routers/mha/main_att.py @@ -0,0 +1,91 @@ +import torch +import numpy as np +import argparse +import random +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from HybridTensor.routers.router_utils import get_data_, create_dataset +from HybridTensor.routers.mha.trainer_att import train +from HybridTensor.utils.utils import _get_device, extract_model_name +# from HybridTensor.utils.activations import OPT_CONFIGS, OPT_MODELS +from HybridTensor.utils.activations import CONFIGS, MODELS +from HybridTensor.routers.mha.main_attn_regression import MHA_Router, MHA_Router_Linear + +import os + + +def arg_parser(): + parser = argparse.ArgumentParser(description=" MHA Router Training") + parser.add_argument("--model", type=str, default="facebook/opt-6.7b", choices=MODELS) + parser.add_argument("--model_name", type=str, default="opt", help="model name") + parser.add_argument("--model_index", type=int, default=8, help="model index") + parser.add_argument("--L", type=int, default=4, help="which layer") + parser.add_argument("--D", type=int, default=1024, help="low rank dimension") + parser.add_argument("--batch_size", type=int, default=512, help="batch size") # a smaller batch size results in better precision and recall of the model + parser.add_argument("--epochs", type=int, default=5, help="epochs") + parser.add_argument("--lr", type=float, default=0.0001, help="learning rate") + parser.add_argument("--k", type=float, default=0.3, help="top k percent to mark as activate head") # top k from attention sweep + parser.add_argument("--hidden_activation", type=str, default="none", choices=["none", "relu", "gelu"]) + parser.add_argument("--data_dir", type=str, default="", help="data directory") # add an argument for data dir + parser.add_argument("--ckpt_dir", type=str, default="", help="checkpoint directory") # add an argument for checkpoint dir + parser.add_argument("--total_samples", type=int, default=0, help="total samples") # add a argument for total samples + parser.add_argument("--gpu", type=int, default=0, help="which gpu to use") # add a argument for which gpu to use + parser.add_argument("--norm", type=str, default="none", choices=["layer", "batch", "none"]) + parser.add_argument("--router_arch", type=str, default="linear", choices=["mlp", "linear"]) + + return parser.parse_args() + + +def main(): + + args = arg_parser() + print(args) + random.seed(0) + torch.manual_seed(0) + np.random.seed(0) + + # load the attention activation data + if args.total_samples == 0: + args.total_samples = 400000 + else: + args.total_samples = args.total_samples + + model_name = MODELS[args.model_index -1] + model_config = CONFIGS[model_name] + model_dim = model_config['d'] + num_heads = model_config['h'] + inner_dim = args.D + device = _get_device(args.gpu) + + print("=" * 40, "Layer", args.L, "=" * 40) + + # query, labels = get_data(args, args.L) + hidden_states, attn_norms = get_data_(args.data_dir, args.L, data_type="attn_norms", total_samples=args.total_samples) + train_loader, test_loader = create_dataset(hidden_states, attn_norms, args) + + if args.router_arch == "mlp": + mha_router = MHA_Router(model_dim, inner_dim, num_heads, norm=args.norm, hidden_activation=args.hidden_activation) + else: + mha_router = MHA_Router_Linear(model_dim, inner_dim, num_heads) + + print(mha_router) + print("Start Training") + best_model, eval_result = train( + mha_router, train_loader, test_loader, args, device, verbal=True + ) + + # save the checkpoint + model_name_clean = extract_model_name(model_name) + ckpt_path = f"{args.ckpt_dir}/{model_name_clean}-routers/attn_classifier" + + # create the checkpoint directory if it does not exist + if not os.path.exists(ckpt_path): + os.makedirs(ckpt_path) + + file_name = f"{ckpt_path}/attn_router_{args.L}-{eval_result['Recall']:.4f}-{eval_result['Precision']:.4f}.pt" + torch.save(best_model, file_name) + print(f"Model saved at {file_name}") + + +if __name__ == "__main__": + main() diff --git a/HybridTensor/routers/mha/train_mha_routers_topk.sh b/HybridTensor/routers/mha/train_mha_routers_topk.sh new file mode 100644 index 0000000000000000000000000000000000000000..aaac295dfc03ab9a57ece42bac0daab9c1e510ac --- /dev/null +++ b/HybridTensor/routers/mha/train_mha_routers_topk.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +# Number of GPUs available +NUM_GPUS=2 + +# Each GPU can handle JOBS_PER_GPU concurrent trainings +JOBS_PER_GPU=8 + +# Total layers you want to process +START_LAYER=0 +END_LAYER=31 + +MODEL_INDEX=14 +D=128 +DATA_DIR="" +CKPT_DIR="" + +BATCH_SIZE=$((NUM_GPUS * JOBS_PER_GPU)) + +# print all the configs +echo "NUM_GPUS: $NUM_GPUS, JOBS_PER_GPU: $JOBS_PER_GPU, START_LAYER: $START_LAYER, END_LAYER: $END_LAYER, MODEL_INDEX: $MODEL_INDEX, DATA_DIR: $DATA_DIR, CKPT_DIR: $CKPT_DIR, BATCH_SIZE: $BATCH_SIZE" + +layer=$START_LAYER +while [ $layer -le $END_LAYER ]; do + # Launch a batch of jobs + for (( i=0; i<$BATCH_SIZE && layer<=$END_LAYER; i++ )); do + GPU_ID=$(( i / JOBS_PER_GPU )) + CUDA_VISIBLE_DEVICES=$GPU_ID python -m HybridTensor.routers.mha.main_att\ + --model_index $MODEL_INDEX \ + --L $layer \ + --D $D \ + --k 0.5 \ + --data_dir $DATA_DIR \ + --ckpt_dir $CKPT_DIR & + layer=$((layer+1)) + done + + # Wait for this batch to finish + wait +done + +echo "All training jobs finished!" \ No newline at end of file diff --git a/HybridTensor/routers/mha/trainer_att.py b/HybridTensor/routers/mha/trainer_att.py new file mode 100644 index 0000000000000000000000000000000000000000..25da232968e2be99bb379e09457a41d04d57673c --- /dev/null +++ b/HybridTensor/routers/mha/trainer_att.py @@ -0,0 +1,316 @@ +import torch +import copy +import numpy as np +from tqdm import tqdm + +from torch.cuda.amp import autocast, GradScaler + +def eval_print(validation_results): + result = "" + for metric_name, metirc_val in validation_results.items(): + result = f"{result}{metric_name}: {metirc_val:.4f} " + return result + +def generate_label(norm, k): + K=int(norm.shape[1]*k) + indices = norm.topk(K, dim=1).indices + one_hot = torch.zeros_like(norm).scatter_(1, indices, 1) + return one_hot + + +def evaluate_better(model, device, loader, args, smalltest=False): + model.eval() + + eval = { + "Loss": [], + "Loss Weight": [], + "Precision": [], + "Recall": [], + "F1 Score": [], + "Classifier Sparsity": [], + "True Sparsity": [], + } + with torch.no_grad(): + for batch_idx, batch in enumerate(loader): + x, y = batch + y = y.float().to(device) + y = generate_label(y, args.k) + + logits = model(x.to(device)) + probs = logits.sigmoid() + preds = probs >= 0.5 + + preds = preds.int() + y = y.int() + + dif = y - preds + miss = dif > 0.0 # classifier didn't activate target head + + weight = y.sum() / y.numel() + loss_weight = y * (1 - weight) + weight + eval["Loss Weight"].append(weight.item()) + eval["Loss"].append( + torch.nn.functional.binary_cross_entropy( + probs, y.float(), weight=loss_weight + ).item() + ) + + TP = (preds & y).sum(dim=1).float() + FP = (preds & (1 - y)).sum(dim=1).float() + FN = ((1 - preds) & y).sum(dim=1).float() + + total_predicted_positives = preds.sum(dim=1).float() + total_actual_positives = y.sum(dim=1).float() + + per_sample_precision = torch.where( + total_predicted_positives > 0, + TP / total_predicted_positives, + torch.zeros_like(total_predicted_positives) + ) + per_sample_recall = torch.where( + total_actual_positives > 0, + TP / total_actual_positives, + torch.zeros_like(total_actual_positives) + ) + per_sample_f1 = torch.where( + (per_sample_precision + per_sample_recall) > 0, + 2 * per_sample_precision * per_sample_recall / (per_sample_precision + per_sample_recall), + torch.zeros_like(per_sample_precision) + ) + + eval["Precision"].append(per_sample_precision.mean().item()) + eval["Recall"].append(per_sample_recall.mean().item()) + eval["F1 Score"].append(per_sample_f1.mean().item()) + eval["True Sparsity"].append(total_actual_positives.mean().item()) + eval["Classifier Sparsity"].append(total_predicted_positives.mean().item()) + + if batch_idx >= 100 and smalltest: + break + + for k, v in eval.items(): + eval[k] = np.array(v).mean() + + return eval + + +def train(model, train_loader, valid_loader, args, device, verbal=True): + num_val = 0 + early_stop_waiting = 5 + val_inter = len(train_loader) // (num_val + 1) + 1 + num_print = 0 + print_inter = len(train_loader) // (num_print + 1) + 1 + model.to(device) + + optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=0.01) + + eval_results = evaluate_better(model, device, valid_loader, args, smalltest=True) + if verbal: + print(f"[Start] {eval_print(eval_results)}") + + best_model = copy.deepcopy(model.state_dict()) + base_acc = eval_results["Recall"] + best_eval = eval_results + no_improve = 0 + + for e in range(args.epochs): + + for batch_idx, batch in enumerate(tqdm(train_loader)): + model.train() + x, y = batch + optimizer.zero_grad() + y = y.float().to(device) + y = generate_label(y, args.k) + logits = model(x.to(device)) + probs = logits.sigmoid() + + weight = ( y.sum() / y.numel()) + loss_weight = y * (1 - weight) + weight + loss = torch.nn.functional.binary_cross_entropy( + probs, y, weight=loss_weight + ) + loss.backward() + optimizer.step() + + if (batch_idx + 1) % print_inter == 0: + print( + f"[ {e}, {batch_idx}] Loss: {loss.item():.4f}, Loss weight: { weight.item():.4f}" + ) + + if ( + (batch_idx + 1) % val_inter == 0 + and batch_idx != 0 + and batch_idx != len(train_loader) + ): + valid_eval_results = evaluate_better( + model, device, valid_loader, smalltest=False + ) + train_eval_results = evaluate_better( + model, device, train_loader, smalltest=True + ) + model.train() + print( + f"[{e}, {batch_idx}] Validation: {eval_print(valid_eval_results)}" + ) + print(f"[{e}, {batch_idx}] Train: {eval_print(train_eval_results)}") + + train_eval_results = evaluate_better(model, device, train_loader, args, smalltest=True) + epoch_eval_results = evaluate_better( + model, device, valid_loader, args, smalltest=False + ) + if verbal: + print(f"[Epoch {e+1}] [Train] {eval_print(train_eval_results)}") + print(f"[Epoch {e+1}] [Valid] {eval_print(epoch_eval_results)}\n") + + if epoch_eval_results["Recall"] > base_acc: + base_acc = epoch_eval_results["Recall"] + best_eval = epoch_eval_results + model.cpu() + best_model = copy.deepcopy(model.state_dict()) + model.to(device) + no_improve = 0 + else: + no_improve += 1 + + if no_improve >= early_stop_waiting or base_acc > 0.99: + break + + return best_model, best_eval + +def evaluate(model, device, loader, args, smalltest=False): + model.eval() + + eval = { + "Loss": [], + "Loss Weight": [], + "Recall": [], + "Classifier Sparsity": [], + "True Sparsity": [], + } + with torch.no_grad(): + for batch_idx, batch in enumerate((loader)): + x, y = batch + y = y.float().to(device) + y = generate_label(y, args.k) + + logits = model(x.to(device)) + probs = logits.sigmoid() + preds = probs >= 0.5 + + dif = y.int() - preds.int() + miss = dif > 0.0 # classifier didn't activated target head + + + weight = ( y.sum() / y.numel()) + loss_weight = y * (1 - weight) + weight + eval["Loss Weight"] += [ + weight.item() + ] + eval["Loss"] += [ + torch.nn.functional.binary_cross_entropy( + probs, y, weight=loss_weight + ).item() + ] + + eval["Recall"] += [ + ((y.sum(dim=1).float() - miss.sum(dim=1).float()).mean().item()) + ] + eval["True Sparsity"] += [y.sum(dim=1).float().mean().item()] + eval["Classifier Sparsity"] += [preds.sum(dim=1).float().mean().item()] + + if batch_idx >= 100 and smalltest: + break + + for k, v in eval.items(): + eval[k] = np.array(v).mean() + + eval["Recall"] = eval["Recall"] / eval["True Sparsity"] + return eval + + +def train_v1(model, train_loader, valid_loader, args, device, verbal=True): + early_stop_waiting = 5 + num_val = 0 + val_inter = len(train_loader) // (num_val + 1) + 1 + num_print = 0 + print_inter = len(train_loader) // (num_print + 1) + 1 + model.to(device) + + optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=0.01) + scaler = GradScaler() # Initialize GradScaler + + eval_results = evaluate(model, device, valid_loader, args, smalltest=True) + if verbal: + print(f"[Start] {eval_print(eval_results)}") + + best_model = copy.deepcopy(model.state_dict()) + base_acc = eval_results["Recall"] + best_eval = eval_results + no_improve = 0 + + for e in range(args.epochs): + + for batch_idx, batch in enumerate(tqdm(train_loader)): + model.train() + x, y = batch + optimizer.zero_grad() + y = y.float().to(device) + y = generate_label(y, args.k) + + with autocast(): # Enable autocast + logits = model(x.to(device)) + probs = logits.sigmoid() + + weight = (y.sum() / y.numel()) + loss_weight = y * (1 - weight) + weight + loss = torch.nn.functional.binary_cross_entropy( + probs, y, weight=loss_weight + ) + + scaler.scale(loss).backward() # Scale the loss + scaler.step(optimizer) # Step the optimizer + scaler.update() # Update the scaler + + if (batch_idx + 1) % print_inter == 0: + print( + f"[{e}, {batch_idx}] Loss: {loss.item():.4f}, Loss weight: {weight.item():.4f}" + ) + + if ( + (batch_idx + 1) % val_inter == 0 + and batch_idx != 0 + and batch_idx != len(train_loader) + ): + valid_eval_results = evaluate_better( + model, device, valid_loader, args, smalltest=False + ) + train_eval_results = evaluate_better( + model, device, train_loader, args, smalltest=True + ) + model.train() + print( + f"[{e}, {batch_idx}] Validation: {eval_print(valid_eval_results)}" + ) + print(f"[{e}, {batch_idx}] Train: {eval_print(train_eval_results)}") + + train_eval_results = evaluate_better(model, device, train_loader, args, smalltest=True) + epoch_eval_results = evaluate_better( + model, device, valid_loader, args, smalltest=False + ) + if verbal: + print(f"[Epoch {e+1}] [Train] {eval_print(train_eval_results)}") + print(f"[Epoch {e+1}] [Valid] {eval_print(epoch_eval_results)}\n") + + if epoch_eval_results["Recall"] > base_acc: + base_acc = epoch_eval_results["Recall"] + best_eval = epoch_eval_results + model.cpu() + best_model = copy.deepcopy(model.state_dict()) + model.to(device) + no_improve = 0 + else: + no_improve += 1 + + if no_improve >= early_stop_waiting or base_acc > 0.99: + break + + return best_model, best_eval diff --git a/HybridTensor/routers/mha/trainer_regression.py b/HybridTensor/routers/mha/trainer_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..c55de549345618b0f028eacabe993e143b268368 --- /dev/null +++ b/HybridTensor/routers/mha/trainer_regression.py @@ -0,0 +1,177 @@ + +import torch +import copy +import numpy as np +from tqdm import tqdm + +# from torch.cuda.amp import autocast, GradScaler +from torch.amp import GradScaler, autocast + +def eval_print(validation_results): + result = "" + for metric_name, metric_val in validation_results.items(): + result = f"{result}{metric_name}: {metric_val:.4f} " + return result + +def generate_topk_labels(norm, k): + """ + Generate multi-hot labels by selecting top k indices per sample. + """ + K = k + if K <= 0: + raise ValueError("k must be a positive integer.") + # Handle cases where k > num_heads + K = min(K, norm.shape[1]) + _, indices = norm.topk(K, dim=1) + one_hot = torch.zeros_like(norm).scatter_(1, indices, 1) + return one_hot + +def evaluate_regression(model, device, loader, args, smalltest=False): + """ + Evaluate the model using regression loss and classification metrics based on top k predictions. + """ + model.eval() + + eval_metrics = { + "Loss": [], + "Precision": [], + "Recall": [], + "F1 Score": [], + "Classifier Sparsity": [], + "True Sparsity": [], + } + mse_loss = torch.nn.MSELoss() + with torch.no_grad(): + for batch_idx, batch in enumerate(loader): + x, y = batch # x: [batch_size, model_dim], y: [batch_size, num_heads] + x = x.to(device) + y = y.float().to(device) + + logits = model(x) + preds = logits # [batch_size, num_heads] + + loss = mse_loss(preds, y) + eval_metrics["Loss"].append(loss.item()) + + # Generate top k labels for true and predictions + K = int(args.k * y.shape[1]) + K_pred = int((args.k+0.1) * y.shape[1]) + topk_true = generate_topk_labels(y, K) + topk_pred = generate_topk_labels(preds, K_pred) + + # Convert to integer for metric calculations + topk_true = topk_true.int() + topk_pred = topk_pred.int() + + # Compute precision, recall, and F1 per sample + TP = (topk_pred & topk_true).sum(dim=1).float() + FP = (topk_pred & (1 - topk_true)).sum(dim=1).float() + FN = ((1 - topk_pred) & topk_true).sum(dim=1).float() + + per_sample_precision = torch.where( + (TP + FP) > 0, + TP / (TP + FP), + torch.zeros_like(TP) + ) + per_sample_recall = torch.where( + (TP + FN) > 0, + TP / (TP + FN), + torch.zeros_like(TP) + ) + per_sample_f1 = torch.where( + (per_sample_precision + per_sample_recall) > 0, + 2 * per_sample_precision * per_sample_recall / (per_sample_precision + per_sample_recall), + torch.zeros_like(per_sample_precision) + ) + + eval_metrics["Precision"].append(per_sample_precision.mean().item()) + eval_metrics["Recall"].append(per_sample_recall.mean().item()) + eval_metrics["F1 Score"].append(per_sample_f1.mean().item()) + + # Sparsity metrics + # eval_metrics["True Sparsity"].append(topk_true.sum(dim=1).mean().item()) + # eval_metrics["Classifier Sparsity"].append(topk_pred.sum(dim=1).mean().item()) + + eval_metrics["True Sparsity"].append(topk_true.sum(dim=1).float().mean().item()) + eval_metrics["Classifier Sparsity"].append(topk_pred.sum(dim=1).float().mean().item()) + + + if smalltest and batch_idx >= 100: + break + + # Average metrics over all batches + for k, v in eval_metrics.items(): + eval_metrics[k] = np.mean(v) + + return eval_metrics + +def train_regression(model, train_loader, valid_loader, args, device, verbal=True): + """ + Train the MHA_Router model as a regression model to predict attention norms. + """ + early_stop_waiting = 2 + scaler = GradScaler() + # scaler = GradScaler(device_type='cuda') + model.to(device) + + optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=0.01) + mse_loss_fn = torch.nn.MSELoss() + + # Initial evaluation + eval_results = evaluate_regression(model, device, valid_loader, args, smalltest=True) + if verbal: + print(f"[Start] {eval_print(eval_results)}") + + best_model = copy.deepcopy(model.state_dict()) + base_recall = eval_results["Recall"] + best_eval = eval_results + no_improve = 0 + + for epoch in range(args.epochs): + model.train() + epoch_losses = [] + for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")): + x, y = batch + x = x.to(device) + y = y.float().to(device) + + optimizer.zero_grad() + + with autocast(device_type='cuda'): + preds = model(x) # [batch_size, num_heads] + loss = mse_loss_fn(preds, y) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + epoch_losses.append(loss.item()) + + if verbal and (batch_idx + 1) % 100 == 0: + # print(f"[Epoch {epoch+1}, Batch {batch_idx+1}] Loss: {loss.item():.4f}") + pass + + # End of epoch evaluation + train_eval_results = evaluate_regression(model, device, train_loader, args, smalltest=True) + valid_eval_results = evaluate_regression(model, device, valid_loader, args, smalltest=False) + + if verbal: + print(f"[Epoch {epoch+1}] [Train] {eval_print(train_eval_results)}") + print(f"[Epoch {epoch+1}] [Valid] {eval_print(valid_eval_results)}\n") + + # Check for improvement + if valid_eval_results["Recall"] > base_recall: + base_recall = valid_eval_results["Recall"] + best_eval = valid_eval_results + best_model = copy.deepcopy(model.state_dict()) + no_improve = 0 + else: + no_improve += 1 + + # Early stopping + if no_improve >= early_stop_waiting or base_recall > 0.99: + if verbal: + print(f"Early stopping triggered at epoch {epoch+1}.") + break + + return best_model, best_eval \ No newline at end of file diff --git a/HybridTensor/routers/mlp/MLP_router.py b/HybridTensor/routers/mlp/MLP_router.py new file mode 100644 index 0000000000000000000000000000000000000000..c7d0a6cc7c00ec9b78c0d66ce0b6fe1a34f1ba3b --- /dev/null +++ b/HybridTensor/routers/mlp/MLP_router.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from HybridTensor.routers.router_utils import CONFIG + + +class Router(torch.nn.Module): + def __init__(self, model_dim, inner_dim = 1024, out_dim = None, norm="layer", hidden_activation="none"): + super(Router, self).__init__() + self.fc1 = torch.nn.Linear(model_dim, inner_dim, bias=None) + self.use_act = True + + if norm == 'batch': + self.norm = torch.nn.BatchNorm1d(inner_dim) + elif norm == 'layer': + self.norm = torch.nn.LayerNorm(inner_dim) + else: + self.norm = nn.Identity() + + self.activation_fn(hidden_activation) + + if out_dim is None: + out_dim = model_dim * 4 + self.fc2 = torch.nn.Linear(inner_dim, out_dim, bias=None) + + def activation_fn(self, hidden_activation): + if hidden_activation == "relu": + self.activation = torch.nn.ReLU() + elif hidden_activation == 'relu_squared': + self.activation = ReLUSquared() + elif hidden_activation == 'leaky_relu': + self.activation = nn.LeakyReLU(negative_slope=0.01) + elif hidden_activation == 'swish': + self.activation = nn.SiLU() + elif hidden_activation == 'elu': + self.activation = nn.ELU(alpha=1.0) + elif hidden_activation == 'gelu': + self.activation = nn.GELU() + elif hidden_activation == 'selu': + self.activation = nn.SELU() + elif hidden_activation == 'mish': + self.activation = Mish() + elif hidden_activation == 'none': + self.activation = nn.Identity() + + def use_dropout(self, args): + if args.dropout > 0: + self.use_dropout = True + self.dropout = torch.nn.Dropout(p=args.dropout) + else: + self.use_dropout = False + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.norm(x) + x = self.fc2(x) + + return x + + +# Custom activation functions +class ReLUSquared(nn.Module): + def forward(self, x): + return F.relu(x) ** 2 + +class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + +class SwiGLU(nn.Module): + def __init__(self, dim): + super(SwiGLU, self).__init__() + self.dim = dim + self.activation = nn.SiLU() + + def forward(self, x): + x1, x2 = x.chunk(2, dim=-1) + return self.activation(x1) * x2 + +class Mish(nn.Module): + def forward(self, x): + return x * torch.tanh(F.softplus(x)) \ No newline at end of file diff --git a/HybridTensor/routers/mlp/README.md b/HybridTensor/routers/mlp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c2ae771ddfd6b501fb3fff1e6d47dafb81bed9da --- /dev/null +++ b/HybridTensor/routers/mlp/README.md @@ -0,0 +1,10 @@ +# To run the mlp router training use the following scripts + +for a single layer + +python -m HybridTensor.routers.mlp.main_mlp --model_index 5 --L 4 --data_dir --ckpt_dir --gpu 1 + + +for all the layers + +./HybridTensor/routers/mlp/train_mlp_routers.sh \ No newline at end of file diff --git a/HybridTensor/routers/mlp/__pycache__/mlp_router_optim.cpython-39.pyc b/HybridTensor/routers/mlp/__pycache__/mlp_router_optim.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ee6ee018dcd174cbafcea244d1d57ce415d43a1 Binary files /dev/null and b/HybridTensor/routers/mlp/__pycache__/mlp_router_optim.cpython-39.pyc differ diff --git a/HybridTensor/routers/mlp/main_mlp.py b/HybridTensor/routers/mlp/main_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..ef27abd941e880b5f024987c95353eadec7ececd --- /dev/null +++ b/HybridTensor/routers/mlp/main_mlp.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +import numpy as np +import argparse +import random +import os +from HybridTensor.routers.mlp.trainer_mlp import train +from HybridTensor.routers.router_utils import DATA, MODEL_CHOICES, DATA_CHOICES, CONFIG, BasicDataset +from HybridTensor.routers.router_utils import get_data, create_dataset, create_log_path, augment_data, get_data_ +from HybridTensor.routers.router_utils import generate_experiment_id, create_date_directory, generate_log_filename, setup_logging, generate_model_filename +import logging +from HybridTensor.routers.mlp.MLP_router import Router + +from HybridTensor.utils.utils import _get_device, extract_model_name +from HybridTensor.utils.activations import OPT_CONFIGS, OPT_MODELS + +def arg_parser(): + parser = argparse.ArgumentParser(description="PyTorch OPT Full Model") + parser.add_argument("--model", type=str, default="6_7b", choices=MODEL_CHOICES) + parser.add_argument("--model_name", type=str, default="opt", help="model name") + parser.add_argument("--model_index", type=int, default=8, help="model index") + parser.add_argument("--dataset", type=str, default="c4", choices=DATA_CHOICES) + parser.add_argument("--L", type=int, default=0, help="which layer") + parser.add_argument("--D", type=int, default=1024, help="low rank dimension") + parser.add_argument("--batch_size", type=int, default=64, help="batch size") # a smaller batch size results in better precision and recall of the model + parser.add_argument("--norm", type=str, default="none", choices=["batch", "layer", "none"]) + parser.add_argument("--epochs", type=int, default=20, help="epochs") + parser.add_argument("--lr", type=float, default=0.0001, help="learning rate") + parser.add_argument("--ckpt_dir", type=str, default="", help="checkpoint directory") # add a argument for checkpoint dir + parser.add_argument("--data_dir", type=str, default="", help="data directory") # add a argument for data dir + parser.add_argument("--gpu", type=int, default=0, help="which gpu to use") # add a argument for which gpu to use + parser.add_argument("--max_samples", type=int, default=0, help="total samples to process") # add a argument for total samples + parser.add_argument("--dropout", type=float, default=0) + parser.add_argument("--alpha", type=float, default=1) + parser.add_argument("--gamma", type=float, default=1) + parser.add_argument("--threshold", type=float, default=0.5) + parser.add_argument("--data_augmentation", type=int, default=0) # use cold neurons for routing + parser.add_argument("--loss_fn", type=str, default="bce", choices=["bce", "focal", "mse"]) + parser.add_argument("--hidden_activation", type=str, default="none", choices=["relu", "gelu", "relu_squared", "leaky_relu", "swish", "swiglu", "elu", "selu", "mish", "none"]) + parser.add_argument("--debug", type=bool, default=False) + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + + args = arg_parser() + model_name = OPT_MODELS[args.model_index -1] + model_config = OPT_CONFIGS[model_name] + model_dim = model_config['d'] + inner_dim = args.D + device = _get_device(args.gpu) + + # Generate a unique experiment ID, date directory, and log file path + if args.debug: + experiment_id = generate_experiment_id() + date_dir = create_date_directory(args.ckpt_dir, args.model) # This will create {base_dir}/{YYYY-MM-DD}/ + log_file_path = generate_log_filename(experiment_id, args.model_name, args.L, date_dir) + else: + date_dir = "None" + experiment_id = "None" + model_name_clean = extract_model_name(model_name) + log_file_path = f"{args.ckpt_dir}/{model_name_clean}-routers/mlp/logs/{args.model_name}-{args.L}.log" + ckpt_path = f"{args.ckpt_dir}/{model_name_clean}-routers/mlp" + + # Ensure the logs directory exists + log_dir = os.path.dirname(log_file_path) + os.makedirs(log_dir, exist_ok=True) + + # Initialize logging + setup_logging(log_file_path) + + # Log the experiment ID and setup details + logging.info(f"Experiment ID: {experiment_id}") + logging.info(f"Model Name: {args.model_name}") + logging.info(f"Layer Index: {args.L}") + logging.info(f"Checkpoint Directory: {args.ckpt_dir}") + logging.info(f"Date Directory: {date_dir}") + logging.info(f"Using Dropout: {args.dropout}") + logging.info(f"loss_fn: {args.loss_fn}") + logging.info(f"hidden_activation: {args.hidden_activation}") + logging.info(f"Learning Rate: {args.lr}") + + random.seed(0) + torch.manual_seed(0) + np.random.seed(0) + + device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu") + + print("=" * 40, "Layer", args.L, "=" * 40) + # query, labels = get_data(args, args.L) # Deja vu style data loading + if args.max_samples == 0: + total_samples = 400000 + else: + total_samples = args.max_samples + + query, labels = get_data_(args.data_dir, args.L, data_type = 'mlp_activations', total_samples=total_samples) # New data loading style + + + if args.data_augmentation: # hasn't been helpful yet + logging.info("Using Data Augmentation") + index_metadata, cold_labels = augment_data(labels, device=device) + hot_index, cold_index, pruned_index = index_metadata # need to store this meta data to storage + cold_neurons = len(cold_index) + train_loader, test_loader = create_dataset(query, cold_labels, args) + else: + logging.info("Not Using Data Augmentation") + train_loader, test_loader = create_dataset(query, labels, args) + cold_neurons = 0 + + # Initialize the model + # router = Router(args, cold_neurons=cold_neurons).to(device) + mlp_router = Router(model_dim, inner_dim, norm=args.norm, hidden_activation=args.hidden_activation) + + # Log the model architecture + logging.info("Model Architecture:\n{}".format(mlp_router)) + logging.info("Training Arguments:\n{}".format(args)) + + best_model, eval_result = train( + mlp_router, train_loader, test_loader, args, device, verbal=True + ) + # Generate model filename based on evaluation results + if args.debug: + model_filename = generate_model_filename(experiment_id, args.model_name, args.L, eval_result) + model_file_path = date_dir / model_filename + else: + model_file_path = f"{ckpt_path}/mlp_router_{args.L}-{eval_result['Recall']:.2f}-{eval_result['Precision']:.2f}-{eval_result['Classifier Sparsity']:.2f}.pt" + + # Save the best model + torch.save(best_model, model_file_path) + logging.info(f"Model saved at {model_file_path}") diff --git a/HybridTensor/routers/mlp/mlp_router_optim.py b/HybridTensor/routers/mlp/mlp_router_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..c7e28d8a217f6765ca109cc754ae4c875c029530 --- /dev/null +++ b/HybridTensor/routers/mlp/mlp_router_optim.py @@ -0,0 +1,326 @@ +# python -m HybridTensor.routers.mlp.mlp_router_optim --model_index 5 --batch_size --mlp_ckpt_dir --act_data_dir + +from HybridTensor.routers.router_utils import get_data_ +from HybridTensor.modules.SelectiveMLP import MLPRouter +from HybridTensor.utils.activations import MODELS, CONFIGS +from HybridTensor.utils.activations import build_mlp_topk_lookup +from HybridTensor.utils.utils import extract_model_name +import os +import torch +import argparse +import re +import numpy as np +import csv +from tqdm import tqdm + +def save_dict_to_csv(data, batch_size, model_name, results_dir): + """ + Saves a dictionary to a CSV file with columns 'layer', 'batch_size', and 'optimal_k'. + A new directory named {model_name} is created under results_dir, and the file is saved there as + mlp_router_configs_bsize_{batch_size}.csv. + + Args: + data (dict): Dictionary with layer index (int) as key and optimal k (int) as value. + batch_size (int): Fixed batch size for all layers. + model_name (str): Model name used for directory creation. + results_dir (str): Directory where the model-specific folder will be created. + """ + # Create new directory for the model inside results_dir + model_name = extract_model_name(model_name) + model_dir = os.path.join(results_dir, model_name) + os.makedirs(model_dir, exist_ok=True) + + # Define the filename and complete path + filename = f"mlp_router_configs_bsize_{batch_size}.csv" + file_path = os.path.join(model_dir, filename) + + # Write CSV file + with open(file_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['layer', 'batch_size', 'optimal_k']) + for layer, optimal_k in data.items(): + writer.writerow([layer, batch_size, optimal_k]) + + print(f"File saved as {file_path}") + +def load_router_dict_from_csv(file_path: str, batch_size: int) -> dict: + """ + Loads a dictionary from a CSV file with columns 'layer', 'batch_size', and 'optimal_k'. + The batch_size column is ignored; only the mapping of layer -> optimal_k is returned. + + Args: + file_path (str): Full path to the input CSV file. + + Returns: + dict: Dictionary mapping layer index (int) to optimal k (int). + """ + file_name = f"mlp_router_configs_bsize_{batch_size}.csv" + file_path = os.path.join(file_path, file_name) + + data = {} + with open(file_path, newline='') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + layer = int(row['layer']) + optimal_k = int(row['optimal_k']) + data[layer] = optimal_k + return data + +def build_default_mlp_topk_lookup(num_layers, default_topk = 512): + """ + Creates a default lookup table for MLP top-k values. + + The lookup maps each layer to a default top-k value. + + Parameters: + num_layers (int): The number of layers in the model. + default_topk (int): The default top-k value to assign to each layer. + Returns: + dict: A mapping from layer id to the default top-k value. + """ + + mlp_lookup = {i: default_topk for i in range(num_layers)} + return mlp_lookup + +def load_mlp_router(layer_id, directory, embed_dim, low_rank_dim, out_dim, act_th, device=None, dtype=None): + """ + Load the MLP router for a specific layer. + + Args: + layer_id (int): The ID of the layer. + directory (str): Path to the directory containing the router checkpoint files. + embed_dim (int): Input embedding dimension. + low_rank_dim (int): Low-rank intermediate dimension. + out_dim (int): Number of neurons. + act_th (float): Activation threshold. + device: Device to load the model onto. + dtype: Data type. + + Returns: + MLPRouter: Loaded MLP router model. + """ + pattern = re.compile(rf"mlp_router_{layer_id}-.*\.pt") + files = os.listdir(directory) + router_file = next((f for f in files if pattern.match(f)), None) + + if router_file is None: + raise FileNotFoundError(f"No router file found for layer {layer_id} in {directory}") + + router_path = os.path.join(directory, router_file) + + router = MLPRouter(embed_dim, low_rank_dim, out_dim, act_th, device=device, dtype=dtype) + state_dict = torch.load(router_path, map_location=device, weights_only=True) + router.load_state_dict(state_dict) + + return router + +def compute_recall(router_output, ground_truth, k, eps=1e-8): + """ + Computes recall for router predictions. + + Args: + router_output (torch.Tensor): Router's output of shape (B, N). + ground_truth (torch.Tensor): True activations of shape (B, N), where values > 0 indicate active neurons. + k (int): Number of top neurons to select per sample. + eps (float): Small value to avoid division by zero. + + Returns: + tuple: (precision, recall) averaged over the batch. + """ + # Create a binary mask for ground truth activations + gt_mask = (ground_truth > 0).float() + + # Get top-k indices per sample from router output + topk_indices = torch.topk(router_output, k, dim=1).indices + + # Create prediction mask: mark top-k as active (1) + pred_mask = torch.zeros_like(router_output) + pred_mask.scatter_(1, topk_indices, 1) + + # Compute true positives per sample + true_positives = (pred_mask * gt_mask).sum(dim=1) + + # Precision: fraction of top-k that are correct + # precision = true_positives / k + + # Recall: fraction of ground truth active neurons that are predicted + recall = true_positives / (gt_mask.sum(dim=1) + eps) + + # return precision.mean().item(), recall.mean().item() + return recall.mean().item() + + +def _optimize_layer(args, layer_idx, starting_k = 128, target_recall = 0.99, delta = 100, min_k = 1024, max_k = 4096): + ''' + This function optimizes the router for a single layer. + Given the hidden states and true activations, it uses the model to predict the activations. + The top-k values of the router outputs are used as the predictions for the activations. + It then computes the precision and recall of the predictions. + While the recall is less than the target recall, the model is optimized. + + ''' + + # load the router for the layer + mlp_router_layer = load_mlp_router( + layer_id=layer_idx, + directory= args.mlp_ckpt_dir, + embed_dim=args.embed_dim, + low_rank_dim=1024, + out_dim=args.total_neurons, + act_th=0.5, + device="cuda", + dtype=torch.float16, + ) + + hidden_states, true_activations = get_data_(args.act_data_dir, + layer_idx=layer_idx, + data_type='mlp_activations', + total_neurons=args.total_neurons, + total_samples=args.total_samples) + + hidden_states = torch.from_numpy(hidden_states).to("cuda") + hidden_states.requires_grad = False + true_activations = torch.from_numpy(true_activations).to("cuda") + true_activations.requires_grad = False + + current_recall = 0 + k = starting_k + + while current_recall < target_recall: + + with torch.no_grad(): + predictions = mlp_router_layer(hidden_states) + recall = compute_recall(predictions, true_activations, k=k) + if recall < target_recall: + k = k + delta + current_recall = recall + # print(f"Layer {layer_idx}: Recall = {recall:.4f} at k = {k}") + + k = int(np.ceil(k /128) * 128) + k = min(max_k, max(min_k, k)) + return k + +def _optimize_router(args): + starting_topk_layer = build_mlp_topk_lookup(data_path=args.stats_dir, batch_size=args.batch_size, delta=0) + total_layers = args.num_layers + optim_k = {} + + for layer_idx in tqdm(range(total_layers), desc="Optimizing MLP Router"): + optim_k[layer_idx] = _optimize_layer(args, layer_idx, starting_k = starting_topk_layer[layer_idx], + target_recall=args.target_recall, delta=args.delta, + min_k=args.min_k, max_k=args.max_k + ) + + # save the optimized topk values in a csv file + save_dict_to_csv(optim_k, args.batch_size, args.model_name, args.results_dir) + + +def compute_recall_batch(router_output, ground_truth, k, eps=1e-8): + # Ground truth: if any sample in the batch is active, mark neuron as active. + gt_sum = (ground_truth > 0).float().sum(dim=0) + gt_mask = (gt_sum >= 1).float() + # Prediction: zero-out negatives, sum over batch, then select top-k neurons. + neurons_nonzero = torch.nn.ReLU()(router_output) + pred_sum = neurons_nonzero.sum(dim=0) + _, topk_indices = torch.topk(pred_sum, k, sorted=False) + pred_mask = torch.zeros_like(pred_sum) + pred_mask[topk_indices] = 1.0 + true_positives = (pred_mask * gt_mask).sum() + recall = true_positives / (gt_mask.sum() + eps) + return recall.item() + +def _optimize_layer_batched(args, layer_idx, starting_k=128, target_recall=0.99, delta=100, min_k=512, max_k=4096): + mlp_router_layer = load_mlp_router( + layer_id=layer_idx, + directory=args.mlp_ckpt_dir, + embed_dim=args.embed_dim, + low_rank_dim=1024, + out_dim=args.total_neurons, + act_th=0.5, + device="cuda", + dtype=torch.float16, + ) + + hidden_states, true_activations = get_data_(args.act_data_dir, + layer_idx=layer_idx, + data_type='mlp_activations', + total_neurons=args.total_neurons, + total_samples=args.total_samples) + hidden_states = torch.from_numpy(hidden_states).to("cuda") + true_activations = torch.from_numpy(true_activations).to("cuda") + + total_samples = hidden_states.shape[0] + batch_size = args.batch_size + num_batches = total_samples // batch_size + k = starting_k + current_recall = 0 + + while current_recall < target_recall: + batch_recalls = [] + indices = torch.randperm(total_samples) + for i in range(num_batches): + batch_idx = indices[i*batch_size:(i+1)*batch_size] + hs_batch = hidden_states[batch_idx] + gt_batch = true_activations[batch_idx] + with torch.no_grad(): + predictions = mlp_router_layer(hs_batch) + batch_recall = compute_recall_batch(predictions, gt_batch, k) + batch_recalls.append(batch_recall) + current_recall = sum(batch_recalls) / len(batch_recalls) + if current_recall < target_recall: + k += delta + k = int(np.ceil(k / 128) * 128) + # k = min(max_k, max(min_k, k)) + return k + +def _optimize_router_batched(args): + if args.stats_dir is not None: + starting_topk_layer = build_mlp_topk_lookup(data_path=args.stats_dir, batch_size=args.batch_size, delta=0) + else: + starting_topk_layer = build_default_mlp_topk_lookup(args.num_layers, default_topk=args.min_k) + + total_layers = args.num_layers + optim_k = {} + for layer_idx in tqdm(range(total_layers), desc="Optimizing MLP Router (Batched)"): + optim_k[layer_idx] = _optimize_layer_batched(args, layer_idx, starting_k=starting_topk_layer[layer_idx], + target_recall=args.target_recall, delta=args.delta, + min_k=args.min_k, max_k=args.max_k) + save_dict_to_csv(optim_k, args.batch_size, args.model_name, args.results_dir) + + +def arg_parser(): + parser = argparse.ArgumentParser(description=' MLP Router Optimizations') + parser.add_argument('--batch_size', type=int, default=1) + parser.add_argument('--model_index', type=int, default=5) + parser.add_argument('--results_dir', type=str, default='configs/mlp_router') + parser.add_argument('--stats_dir', type=str, default=None) + parser.add_argument('--mlp_ckpt_dir', type=str, default='') + parser.add_argument('--act_data_dir', type=str, default='') + parser.add_argument('--max_batch_size', type=int, default=32) + parser.add_argument('--total_samples', type=int, default=50000) + parser.add_argument('--target_recall', type=float, default=0.99) + parser.add_argument('--delta', type=int, default=100) + parser.add_argument('--min_k', type=int, default=512) + parser.add_argument('--max_k', type=int, default=8192) + parser.add_argument('--device', type=int, default=0) + parser.add_argument('--mode', type=str, default='row', choices=['row', 'col', 'auto']) + + return parser.parse_args() + + +if __name__ == "__main__": + args = arg_parser() + + # model config + model_name = MODELS[args.model_index-1] + config = CONFIGS[model_name] + args.model_name = model_name + args.embed_dim = config['d'] + args.total_neurons = config['neurons'] + args.num_layers = config['num_layer'] + + if args.batch_size == 1: + _optimize_router(args) + else: + _optimize_router_batched(args) + \ No newline at end of file diff --git a/HybridTensor/routers/mlp/mlp_router_optim_fast.py b/HybridTensor/routers/mlp/mlp_router_optim_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..011b83e49fdf40b1b28e0a74d8c1150c18401c57 --- /dev/null +++ b/HybridTensor/routers/mlp/mlp_router_optim_fast.py @@ -0,0 +1,148 @@ +# python -m HybridTensor.routers.mlp.mlp_router_optim_fast --model_index 5 --batch_size 256 --delta 200 --device 0 +from HybridTensor.routers.router_utils import get_data_ +from HybridTensor.modules.SelectiveMLP import MLPRouter +from HybridTensor.utils.activations import MODELS, CONFIGS +from HybridTensor.utils.activations import build_mlp_topk_lookup +from HybridTensor.utils.utils import extract_model_name +from HybridTensor.routers.mlp.mlp_router_optim import save_dict_to_csv, load_router_dict_from_csv, load_mlp_router, _optimize_layer, _optimize_router +from HybridTensor.utils.utils import _get_device + +import os +import torch +import argparse +import re +import numpy as np +import csv +from tqdm import tqdm + +# --- New vectorized batch recall function --- +def vectorized_compute_recall_batch(predictions_batches, gt_batches, k, eps=1e-8): + # Aggregate ground truth: if any sample in the batch is active, mark neuron as active. + gt_mask = ((gt_batches > 0).float().sum(dim=1) >= 1).float() # shape: (num_batches, num_neurons) + # Sum predictions over batch dimension (predictions are already ReLU-ed) + pred_sum = predictions_batches.sum(dim=1) # shape: (num_batches, num_neurons) + topk_indices = pred_sum.topk(k, dim=1).indices # shape: (num_batches, k) + pred_mask = torch.zeros_like(pred_sum) + pred_mask.scatter_(1, topk_indices, 1) + true_positives = (pred_mask * gt_mask).sum(dim=1) + batch_recall = true_positives / (gt_mask.sum(dim=1) + eps) + return batch_recall.mean().item() + +def build_default_mlp_topk_lookup(num_layers, default_topk = 512): + """ + Creates a default lookup table for MLP top-k values. + + The lookup maps each layer to a default top-k value. + + Parameters: + num_layers (int): The number of layers in the model. + default_topk (int): The default top-k value to assign to each layer. + Returns: + dict: A mapping from layer id to the default top-k value. + """ + + mlp_lookup = {i: default_topk for i in range(num_layers)} + return mlp_lookup + +def _optimize_layer_batched(args, layer_idx, starting_k=128, target_recall=0.99, delta=100, min_k=512, max_k=4096): + device = _get_device(args.device) + + mlp_router_layer = load_mlp_router( + layer_id=layer_idx, + directory=args.mlp_ckpt_dir, + embed_dim=args.embed_dim, + low_rank_dim=1024, + out_dim=args.total_neurons, + act_th=0.5, + device=device, + dtype=torch.float16, + ) + + + + hidden_states, true_activations = get_data_(args.act_data_dir, + layer_idx=layer_idx, + data_type='mlp_activations', + total_neurons=args.total_neurons, + total_samples=args.total_samples) + hidden_states = torch.from_numpy(hidden_states).to(device) + true_activations = torch.from_numpy(true_activations).to(device) + + with torch.no_grad(): + predictions_all = mlp_router_layer(hidden_states) + predictions_all = torch.nn.functional.relu(predictions_all) + + total_samples = hidden_states.shape[0] + batch_size = args.batch_size + num_batches = total_samples // batch_size + k = starting_k + max_topk = starting_k + args.lambda_param # New maximum topk constraint + current_recall = 0.0 + + while current_recall < target_recall and k < max_topk: + indices = torch.randperm(total_samples) + predictions_shuffled = predictions_all[indices] + true_activations_shuffled = true_activations[indices] + trimmed = num_batches * batch_size + predictions_batches = predictions_shuffled[:trimmed].view(num_batches, batch_size, -1) + gt_batches = true_activations_shuffled[:trimmed].view(num_batches, batch_size, -1) + current_recall = vectorized_compute_recall_batch(predictions_batches, gt_batches, k) + if current_recall < target_recall: + k += delta + if k >= max_topk: + k = max_topk + break + k = int(np.ceil(k / 128) * 128) + print(f"Layer {layer_idx}: k = {k}, recall = {current_recall}") + # k = min(max_k, max(min_k, k)) + return k + +def _optimize_router_batched(args): + # starting_topk_layer = build_mlp_topk_lookup(data_path=args.stats_dir, batch_size=args.batch_size, delta=0) + if args.stats_dir is not None: + starting_topk_layer = build_mlp_topk_lookup(data_path=args.stats_dir, batch_size=args.batch_size, delta=0) + else: + starting_topk_layer = build_default_mlp_topk_lookup(args.num_layers, default_topk=args.min_k) + total_layers = args.num_layers + optim_k = {} + for layer_idx in tqdm(range(total_layers), desc="Optimizing MLP Router (Batched)"): + optim_k[layer_idx] = _optimize_layer_batched(args, layer_idx, starting_k=starting_topk_layer[layer_idx], + target_recall=args.target_recall, delta=args.delta, + min_k=args.min_k, max_k=args.max_k) + + save_dict_to_csv(optim_k, args.batch_size, args.model_name, args.results_dir) + +def arg_parser(): + parser = argparse.ArgumentParser(description='MLP Router Optimizations') + parser.add_argument('--batch_size', type=int, default=1) + parser.add_argument('--model_index', type=int, default=5) + parser.add_argument('--results_dir', type=str, default='configs/mlp_router') + parser.add_argument('--stats_dir', type=str, default=None) + parser.add_argument('--mlp_ckpt_dir', type=str, default='') + parser.add_argument('--act_data_dir', type=str, default='') + parser.add_argument('--max_batch_size', type=int, default=32) + parser.add_argument('--total_samples', type=int, default=50000) + parser.add_argument('--target_recall', type=float, default=0.99) + parser.add_argument('--delta', type=int, default=100) + parser.add_argument('--lambda_param', type=int, default=2048, + help='Maximum extra topk offset allowed (max_topk = starting_k + lambda_param)') + parser.add_argument('--min_k', type=int, default=512) + parser.add_argument('--max_k', type=int, default=4096) + parser.add_argument('--device', type=int, default=0) + parser.add_argument('--mode', type=str, default='row', choices=['row', 'col', 'auto']) + return parser.parse_args() + +if __name__ == "__main__": + args = arg_parser() + + model_name = MODELS[args.model_index-1] + config = CONFIGS[model_name] + args.model_name = model_name + args.embed_dim = config['d'] + args.total_neurons = config['neurons'] + args.num_layers = config['num_layer'] + + if args.batch_size == 1: + _optimize_router(args) + else: + _optimize_router_batched(args) diff --git a/HybridTensor/routers/mlp/router_batch_optim.sh b/HybridTensor/routers/mlp/router_batch_optim.sh new file mode 100644 index 0000000000000000000000000000000000000000..bff25dbd3043d50304ef08c6b7ad43419559ea54 --- /dev/null +++ b/HybridTensor/routers/mlp/router_batch_optim.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +# Define parameters +# model_index=7 +# stats_dir="results/mlp_results/batch_activations/opt-30b/" +# mlp_ckpt_dir="/pscratch/sd/s//HybridTensor/checkpoint/opt-30b-routers/mlp/" +# act_data_dir="/pscratch/sd/s//HybridTensor/opt-30b_act_data/" + +model_index=8 +stats_dir="results/mlp_results/batch_activations/opt-66b/" +mlp_ckpt_dir="/pscratch/sd/s//HybridTensor/checkpoint/opt-66b-routers/mlp/" +act_data_dir="/pscratch/sd/s//HybridTensor/opt-66b_act_data/" + +# Array of batch sizes to test (adjust as needed) +batch_sizes=(48 64) + +# Concurrency control +max_concurrent=4 +current_jobs=0 +counter=2 + +for batch_size in "${batch_sizes[@]}"; do + # Choose GPU in a round-robin manner (assuming 4 GPUs) + gpu=$(( counter % 4 )) + + # Run the command in the background + python -m HybridTensor.routers.mlp.mlp_router_optim_fast \ + --model_index "$model_index" \ + --stats_dir "$stats_dir" \ + --mlp_ckpt_dir "$mlp_ckpt_dir" \ + --act_data_dir "$act_data_dir" \ + --batch_size "$batch_size" \ + --delta 200 \ + --device "$gpu" & + + current_jobs=$(( current_jobs + 1 )) + counter=$(( counter + 1 )) + + # Check if we've reached the maximum number of concurrent jobs + if [ "$current_jobs" -ge "$max_concurrent" ]; then + wait + current_jobs=0 + fi +done + +# Final wait for any remaining background jobs +wait + +echo "All tests completed." \ No newline at end of file diff --git a/HybridTensor/routers/mlp/train_mlp_routers.sh b/HybridTensor/routers/mlp/train_mlp_routers.sh new file mode 100644 index 0000000000000000000000000000000000000000..5924120a5b0f4537c809a8a1b75a63630f1dbaf9 --- /dev/null +++ b/HybridTensor/routers/mlp/train_mlp_routers.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +# Number of GPUs available +NUM_GPUS=2 + +# Each GPU can handle JOBS_PER_GPU concurrent trainings +JOBS_PER_GPU=1 + +# Total layers you want to process +START_LAYER=16 +END_LAYER=31 + + +MODEL_INDEX=5 +D=1024 +DATA_DIR="/home/grads/s//nvme/HybridTensor/opt-6.7b_act_data/" +CKPT_DIR="/home/grads/s//nvme/HybridTensor/checkpoint" + +BATCH_SIZE=$((NUM_GPUS * JOBS_PER_GPU)) + +layer=$START_LAYER +while [ $layer -le $END_LAYER ]; do + # Launch a batch of jobs + for (( i=0; i<$BATCH_SIZE && layer<=$END_LAYER; i++ )); do + GPU_ID=$(( i / JOBS_PER_GPU )) + CUDA_VISIBLE_DEVICES=$GPU_ID python -m HybridTensor.routers.mlp.main_mlp\ + --model_index $MODEL_INDEX \ + --L $layer \ + --D $D \ + --data_dir $DATA_DIR \ + --ckpt_dir $CKPT_DIR & + layer=$((layer+1)) + done + + # Wait for this batch to finish + wait +done + +echo "All training jobs finished!" \ No newline at end of file diff --git a/HybridTensor/routers/mlp/trainer_mlp.py b/HybridTensor/routers/mlp/trainer_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..fa22dad8d40e3b672d9681f0a880b9264985d107 --- /dev/null +++ b/HybridTensor/routers/mlp/trainer_mlp.py @@ -0,0 +1,220 @@ +import torch +import copy +import numpy as np +from tqdm import tqdm +import logging + +class FocalLoss(torch.nn.Module): + def __init__(self, alpha=1, gamma=2): + super(FocalLoss, self).__init__() + self.alpha = alpha + self.gamma = gamma + + def forward(self, logits, targets): + BCE_loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction='none') + pt = torch.exp(-BCE_loss) + F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss + return F_loss.mean() + +class LossFunction(torch.nn.Module): + def __init__(self, args, device='cuda:0'): + super(LossFunction, self).__init__() + self.loss_fn_str = args.loss_fn + if args.loss_fn == "bce": + self.loss_fn = torch.nn.functional.binary_cross_entropy + elif args.loss_fn == "focal": + self.loss_fn = FocalLoss(alpha=args.alpha, gamma=args.gamma) + elif args.loss_fn == "mse": + self.loss_fn = torch.nn.functional.mse_loss + + def forward(self, logits, targets): + if self.loss_fn_str == "bce": + probs = logits.sigmoid() + weight = (targets.sum() / targets.numel()) + 0.005 + loss_weight = targets * (1 - weight) + weight + loss = self.loss_fn( + probs, targets, weight=loss_weight + ) + return loss + else: + return self.loss_fn(logits, targets) + +def eval_print(validation_results): + result = "" + for metric_name, metirc_val in validation_results.items(): + result = f"{result}{metric_name}: {metirc_val:.4f} " + return result + + +def generate_label(y): + # positive + one_hot = (y > 0).to(y.dtype) + return one_hot + + +def evaluate_better(model, device, loader, args, smalltest=False): + model.eval() + + eval = { + "Loss": [], + "Loss Weight": [], + "Classifier Sparsity": [], + "True Sparsity": [], + } + + total_TP = 0 # True Positives + total_FP = 0 # False Positives + total_FN = 0 # False Negatives + total_TN = 0 # True Negatives + + th = args.threshold + + with torch.no_grad(): + for batch_idx, batch in enumerate(loader): + x, y = batch + x, y = x.to(device), y.to(device) + if args.loss_fn != "mse": + y = generate_label(y) + + logits = model(x) + if args.loss_fn == "mse": + preds = (logits > th).float() + else: + probs = logits.sigmoid() + preds = (probs >= th).float() + + if args.loss_fn == "mse": + loss = torch.nn.functional.mse_loss(preds, y) + else: + # Compute loss with weighting + weight = (y.sum() / y.numel()) + 0.005 + loss_weight = y * (1 - weight) + weight + eval["Loss Weight"].append(weight.item()) + loss = torch.nn.functional.binary_cross_entropy( + probs, y, weight=loss_weight + ) + eval["Loss"].append(loss.item()) + + # Compute sparsity metrics + eval["True Sparsity"].append(y.sum(dim=1).float().mean().item()) + eval["Classifier Sparsity"].append(preds.sum(dim=1).float().mean().item()) + + # Flatten predictions and labels for metric computation + preds_flat = preds.view(-1) + y_flat = y.view(-1) + + # Calculate True Positives, False Positives, False Negatives, True Negatives + TP = ((preds_flat == 1) & (y_flat == 1)).sum().item() + FP = ((preds_flat == 1) & (y_flat == 0)).sum().item() + FN = ((preds_flat == 0) & (y_flat == 1)).sum().item() + TN = ((preds_flat == 0) & (y_flat == 0)).sum().item() + + total_TP += TP + total_FP += FP + total_FN += FN + total_TN += TN + + if batch_idx >= 100 and smalltest: + break + + # Average loss and sparsity metrics over all batches + for k in ["Loss", "Loss Weight", "Classifier Sparsity", "True Sparsity"]: + eval[k] = np.mean(eval[k]) + + # Compute Precision, Recall, and F1 Score + if total_TP + total_FP > 0: + precision = total_TP / (total_TP + total_FP) + else: + precision = 0.0 + + if total_TP + total_FN > 0: + recall = total_TP / (total_TP + total_FN) + else: + recall = 0.0 + + if precision + recall > 0: + f1_score = 2 * (precision * recall) / (precision + recall) + else: + f1_score = 0.0 + + eval["Precision"] = precision + eval["Recall"] = recall + eval["F1 Score"] = f1_score + + return eval + +def evaluate_thresholds(model, device, loader, args): + thresholds = [0.5, 0.4, 0.3, 0.2] + for th in thresholds: + eval_results = evaluate_better(model, device, loader, args, smalltest=True) + print(f"[Threshold {th}] {eval_print(eval_results)}") + + +def train(model, train_loader, valid_loader, args, device, verbal=True): + num_val = 0 + early_stop_waiting = 5 + val_inter = len(train_loader) // (num_val + 1) + 1 + # num_print = 0 + # print_inter = len(train_loader) // (num_print + 1) + 1 + model.to(device) + + optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=0.01) # 0.01 + eval_results = evaluate_better(model, device, valid_loader, args, smalltest=True) + + if verbal: + # print(f"[Start] {eval_print(eval_results)}") + logging.info(f"[Start] {eval_print(eval_results)}") + + best_model = copy.deepcopy(model.state_dict()) + base_acc = eval_results["F1 Score"] + best_eval = eval_results + no_improve = 0 + + # Instantiate the loss function + loss_fn = LossFunction(args, device) + logging.info(f"Loss Function: {loss_fn}") + logging.info("Training started.") + for e in range(args.epochs): + for batch_idx, batch in enumerate(tqdm(train_loader)): + model.train() + x, y = batch + optimizer.zero_grad() + + y = y.float().to(device) + if args.loss_fn != "mse": + y = generate_label(y) + + logits = model(x.to(device)) + loss = loss_fn(logits, y) + + loss.backward() + optimizer.step() + + if ((batch_idx + 1) % val_inter == 0 and batch_idx != 0 and batch_idx != len(train_loader)): + valid_eval_results = evaluate_better(model, device, valid_loader, smalltest=False) + model.train() + logging.info(f"[{e}, {batch_idx}] Validation: {eval_print(valid_eval_results)}") + + # evalute on validation set and save the best model + if (e+1) % 1 == 0: + epoch_eval_results = evaluate_better(model, device, valid_loader, args, smalltest=False) + if verbal: + logging.info(f"[Epoch {e+1}] [Valid] {eval_print(epoch_eval_results)}") + + if epoch_eval_results["F1 Score"] > base_acc: + base_acc = epoch_eval_results["F1 Score"] + best_eval = epoch_eval_results + model.cpu() + best_model = copy.deepcopy(model.state_dict()) + model.to(device) + no_improve = 0 + else: + no_improve += 1 + + if no_improve >= early_stop_waiting or base_acc > 0.99: + break + + # After training, log the best evaluation results + logging.info("Training completed.") + logging.info(f"Best Evaluation Results: {eval_print(best_eval)}") + return best_model, best_eval diff --git a/HybridTensor/routers/router_latency.py b/HybridTensor/routers/router_latency.py new file mode 100644 index 0000000000000000000000000000000000000000..78e2d6fbf7d0514d64e2e05aba99a0e5e3817e5e --- /dev/null +++ b/HybridTensor/routers/router_latency.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import argparse +import time +import numpy as np +import os + +''' +python -m HybridTensor.routers.router_latency --block mlp --D 1024 --gpu 0 --model_dim 9216 --out_dim $((9216*4)) +python -m HybridTensor.routers.router_latency --block attn --D 1024 --gpu 0 --model_dim 9216 --heads 72 + +''' + +from HybridTensor.routers.router_utils import DATA, MODEL_CHOICES, DATA_CHOICES, CONFIG +from HybridTensor.utils.profiling import cuda_profiler +from HybridTensor.modules.SelectiveMLP import MLPRouter +from HybridTensor.modules.SelectiveMHA import MHARouter +from HybridTensor.routers.archive.fused_router.router import FusedRouter, HybridFusedRouter + +# --------------------------- +# Latency Measurement Function +# --------------------------- + +class Router(torch.nn.Module): + def __init__(self, args): + super(Router, self).__init__() + self.fc1 = torch.nn.Linear(args.model_dim, args.D, bias=None) + # self.activation = torch.nn.ReLU() + self.bn = torch.nn.BatchNorm1d(args.D) + # output dimension + if args.block == "mlp": + out_dim = args.out_dim + else: # args.block == "attn" + out_dim = args.heads + self.fc2 = torch.nn.Linear(args.D, out_dim, bias=None) + + def forward(self, x): + x = self.fc1(x) + # x = self.bn(x) + x = self.fc2(x) + return x + + +def measure_inference_latency(model, input_data, device, num_runs=100, warm_up=10): + """ + Measures the average inference latency of a PyTorch model. + + Args: + model (torch.nn.Module): The PyTorch model to evaluate. + input_data (torch.Tensor): Input tensor for the model. + device (torch.device): The device to perform inference on. + num_runs (int, optional): Number of forward passes to measure. Defaults to 100. + warm_up (int, optional): Number of warm-up runs before measurement. Defaults to 10. + + Returns: + float: Average inference latency in milliseconds. + """ + model.eval() # Set model to evaluation mode + input_data = input_data.to(device) + model.to(device) + + start_time = torch.cuda.Event(enable_timing=True) + end_time = torch.cuda.Event(enable_timing=True) + + # Warm-up runs (not timed) + with torch.no_grad(): + for _ in range(warm_up): + _ = model(input_data) + + # Start timing + start_time.record() + with torch.no_grad(): + for _ in range(num_runs): + out = model(input_data) + + end_time.record() + torch.cuda.synchronize() + + avg_latency = start_time.elapsed_time(end_time) / num_runs + + return avg_latency + + +def arg_parser(): + parser = argparse.ArgumentParser(description="PyTorch OPT Full Model") + parser.add_argument("--model", type=str, default="6_7b", choices=MODEL_CHOICES) + parser.add_argument("--model_name", type=str, default="opt", help="model name") + parser.add_argument("--block", type=str, default="mlp", choices=["mlp", "attn"]) + parser.add_argument("--L", type=int, default=0, help="which layer") + parser.add_argument("--D", type=int, default=1024, help="low rank dimension") + parser.add_argument("--batch_size", type=int, default=64, help="batch size") # a smaller batch size results in better precision and recall of the model + parser.add_argument("--ckpt_dir", type=str, default="/home/grads/s//nvme/HybridTensor/checkpoint", help="checkpoint directory") # add a argument for checkpoint dir + parser.add_argument("--gpu", type=int, default=0, help="which gpu to use") # add a argument for which gpu to use + parser.add_argument("--model_dim", type=int, default=8192, help="model dimension") + parser.add_argument("--heads", type=int, default=64, help="number of heads") + parser.add_argument("--dropout", type=float, default=0) + parser.add_argument("--out_dim", type=int, default=32768, help="output dimension") + args = parser.parse_args() + return args + + +# --------------------------- +# 4. Main Execution +# --------------------------- + +if __name__ == "__main__": + # 4.1. Define Arguments + args = arg_parser() + device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu") + print(args) + # 4.2. Instantiate the Model + # router_model = Router(args) + mlp_router = MLPRouter(embed_dim=args.model_dim, low_rank_dim=args.D, out_dim=args.out_dim, act_th=0.5, device="cuda:0", dtype=torch.float16) + mha_router = MHARouter(embed_dim=args.model_dim, low_rank_dim=128, out_dim=args.heads, top_k=0.3, device="cuda:0", dtype=torch.float16) + fused_router = FusedRouter(embed_dim=args.model_dim, num_heads=args.heads, dim=args.D, mlp_th=0.5, attn_top_k=0.3, device="cuda:0", dtype=torch.float16) + hybrid_fused_router = HybridFusedRouter(embed_dim=args.model_dim, num_heads=args.heads, mlp_dim=1024, mha_dim=128, mlp_th=0.5, attn_top_k=0.3, device="cuda:0", dtype=torch.float16) + + # 4.3. Select Device + print(f"Using device: {device}") + print(f"Model dimension: {args.model_dim}") + print(f"Block: {args.block}") + print(f"Low rank dimension: {args.D}") + + + # 4.4. Prepare Dummy Input Data + # batch_size = args.batch_size # You can adjust the batch size as needed + # input_dim = CONFIG[args.model]['d'] + input_dim = args.model_dim + + # batch_range= [1, 8, 16, 32, 64] + batch_range = [64] + for batch_size in batch_range: + test_input = torch.randn(batch_size, input_dim, device = device, dtype = torch.float16) # Random input tensor + + selected_neurons, mlp_router_time = cuda_profiler(mlp_router._select_neurons, test_input) + selected_heads, mha_router_time = cuda_profiler(mha_router._select_heads, test_input) + fused_activations, fused_router_time = cuda_profiler(fused_router.select_neurons_heads, test_input) + hybrid_fused_activations, hybrid_fused_router_time = cuda_profiler(hybrid_fused_router.select_neurons_heads, test_input) + hybrid_mha_activations, hybrid_mha_router_time = cuda_profiler(hybrid_fused_router.select_heads_, test_input) + hybrid_mlp_activations, hybrid_mlp_router_time = cuda_profiler(hybrid_fused_router.select_neurons_, test_input) + + # 4.5. Measure Inference Latency + # avg_latency = measure_inference_latency( + # model=router_model, + # input_data=test_input, + # device=device, + # num_runs=100, # Number of measurements + # warm_up=10 # Number of warm-up runs + # ) + + total_router_time = mlp_router_time + mha_router_time + total_hybrid_router_time = hybrid_mlp_router_time + hybrid_mha_router_time + difference = total_router_time - total_hybrid_router_time + + print(f"Batch_size: {batch_size}, MLP Router latency: {mlp_router_time:.4f} ms") + print(f"Batch_size: {batch_size}, MHA Router latency: {mha_router_time:.4f} ms") + print(f"Batch_size: {batch_size}, Total Router latency: {total_router_time:.4f} ms") + # print(f"Batch_size: {batch_size}, Fused Router latency: {fused_router_time:.4f} ms") + + # print(f"Batch_size: {batch_size}, Hybrid Fused Router latency: {hybrid_fused_router_time:.4f} ms") + # print(f"Batch_size: {batch_size}, Hybrid MHA Router latency: {hybrid_mha_router_time:.4f} ms") + # print(f"Batch_size: {batch_size}, Hybrid MLP Router latency: {hybrid_mlp_router_time:.4f} ms") + # print(f"Batch_size: {batch_size}, Total Hybrid Router latency: {total_hybrid_router_time:.4f} ms") + # print(f"Batch_size: {batch_size}, Difference: {difference:.4f} ms") + print() diff --git a/HybridTensor/routers/router_utils.py b/HybridTensor/routers/router_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..820e07974360882894989f1128136224fd52c1e1 --- /dev/null +++ b/HybridTensor/routers/router_utils.py @@ -0,0 +1,499 @@ +import os +import matplotlib.pyplot as plt +from datetime import datetime +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +import json + +import numpy as np +from pathlib import Path +from datetime import datetime + +DATA = { + "175b": { + "c4": "../data/175b_c4", + }, + "66b": { + "c4": "../data/66b_c4", + }, + "30b": { + "c4": "../data/30b_c4", + }, + "6_7b": { + "c4": "/home/grads/s//nvme/ssd_backup/dejavu/data/6_7b_c4" + } +} + +MODEL_CHOICES = ['175b', '66b', '30b', '6_7b'] +DATA_CHOICES = ['c4'] +CONFIG = { + '175b':{ + 'num_layer': 95, + 'ckt_storage': "bylayer", + 'd':12288, + 'h': 96, + 'N':400000, + }, + '66b':{ + 'num_layer': 64, + 'ckt_storage': "bylayer", + 'd':9216, + 'h': 72, + 'N':400000, + }, + '30b':{ + 'num_layer': 24, + 'ckt_storage': "bylayer", + 'd':2048, + 'h': 32, + 'N':400000, + }, + '6_7b':{ + 'num_layer': 32, + 'ckt_storage': "bylayer", + 'd':4096, + 'h': 32, + 'N':400000, + }, +} + +class BasicDataset(Dataset): + def __init__(self, X, Y, n, train ): + self.X = X + self.Y = Y + self.n = n + self.train = train + + def __len__(self): + return self.n + + def __getitem__(self, idx): + if self.train: + x = torch.Tensor(self.X[idx]) + y = torch.Tensor(self.Y[idx]) + else: + x = torch.Tensor(self.X[-idx]) + y = torch.Tensor(self.Y[-idx]) + if y.sum()== 0: + #print("all zero y") + # exit() + pass + return x, y + + +def get_data(args, l): + if CONFIG[args.model]['ckt_storage'] == "bylayer": + #path = f"{DATA[args.model][args.dataset]}/mlp_x_{l}.mmap" + path = f"{DATA[args.model][args.dataset]}/mlp_sp_x_{l}.mmap" + logging.info(f"Reading query from {path}") + query = np.array(np.memmap(path, dtype='float16', mode='r', shape=(400000,CONFIG[args.model]['d']))[: CONFIG[args.model]['N']]) + path = f"{DATA[args.model][args.dataset]}/mlp_label_{l}.mmap" + logging.info(f"Reading MLP label from {path}") + label = np.array(np.memmap(path, dtype='float16', mode='r', shape=(400000,CONFIG[args.model]['d'] * 4))[: CONFIG[args.model]['N']]) + + num_valid = (label.sum(-1) > 0).sum() + return query[:num_valid], label[:num_valid] + #return query, label + +import numpy as np +import os +import json + +# HybridTensor get data +def get_data_(data_dir, layer_idx, data_type, total_neurons = None, total_samples = None): + """ + Load query and label data for a specific layer based on the configuration. + + Args: + args (argparse.Namespace): Contains configuration parameters such as model, dataset, and data_type. + layer_idx (int): The index of the layer to load data for. + + Returns: + tuple: A tuple containing: + - query (np.ndarray): The query data array of shape (num_valid, feature_size_query). + - label (np.ndarray): The label data array of shape (num_valid, feature_size_label). + """ + + # Ensure data_type is valid + if data_type not in ['mlp_activations', 'attn_norms']: + raise ValueError(f"Invalid data_type: {data_type}. Must be 'mlp_activations' or 'attn_norms'.") + + # logging.info(f"Loading data from directory: {data_dir}") + + # Load metadata + metadata_path = os.path.join(data_dir, 'metadata.json') + if not os.path.exists(metadata_path): + raise FileNotFoundError(f"Metadata file not found at {metadata_path}") + + with open(metadata_path, 'r') as f: + metadata = json.load(f) + + num_layers = metadata['num_layers'] + hidden_size = metadata['hidden_size'] + num_heads = metadata['num_heads'] + max_samples = metadata['max_samples'] + if total_neurons == None: + total_neurons = hidden_size * 4 + + # Validate layer index + if layer_idx < 0 or layer_idx >= num_layers: + raise ValueError(f"Invalid layer_idx: {layer_idx}. Must be between 0 and {num_layers - 1}.") + + # Determine feature sizes and corresponding file names + if data_type == 'mlp_activations': + label_feature_size = total_neurons + label_filename = f"mlp_activations_layer_{layer_idx}.dat" + elif data_type == 'attn_norms': + label_feature_size = num_heads + label_filename = f"attn_norms_layer_{layer_idx}.dat" + + # Query corresponds to hidden_states + query_feature_size = hidden_size + query_filename = f"hidden_states_layer_{layer_idx}.dat" + + # Paths to the mmap files + query_path = os.path.join(data_dir, query_filename) + label_path = os.path.join(data_dir, label_filename) + + # Log the paths + logging.info(f"Reading query from {query_path}") + logging.info(f"Reading label from {label_path}") + + # Determine number of samples to load + if total_samples is not None: + if total_samples <= 0: + raise ValueError(f"total_samples must be a positive integer, got {total_samples}") + num_samples = min(total_samples, max_samples) + else: + num_samples = max_samples + + # Load query data + if not os.path.exists(query_path): + raise FileNotFoundError(f"Query file not found at {query_path}") + + query_mmap = np.memmap(query_path, dtype='float16', mode='r', shape=(max_samples, query_feature_size)) + query = np.array(query_mmap[:num_samples]) + del query_mmap # Close the memmap + + # Load label data + if not os.path.exists(label_path): + raise FileNotFoundError(f"Label file not found at {label_path}") + + label_mmap = np.memmap(label_path, dtype='float16', mode='r', shape=(max_samples, label_feature_size)) + label = np.array(label_mmap[:num_samples]) + del label_mmap # Close the memmap + + logging.info(f"Loaded {label.shape[0]} valid samples for layer {layer_idx}.") + + return query, label + + +def get_expert_config(expert_dir, layer_idx): + expert_file = f"{expert_dir}/layer_{layer_idx}_experts.json" + + # if expert_file does not exist, then we need to create the file before starting the training + if os.path.exists(expert_file): + # load the expert file + with open(expert_file, 'r') as f: + experts = json.load(f) + print(f"Expert file {expert_file} read successfully.") + else: + raise FileNotFoundError(f"Expert file {expert_file} not found.") + return experts + + +def augment_data(labels, device='cpu'): + """ + Data augmentation function that processes the input query and labels, + identifies hot and cold neurons based on activation counts, and returns + the cold labels (corresponding to cold neurons). + + Args: + query (torch.Tensor): Input data or query tensor. + labels (torch.Tensor): Corresponding labels tensor. + device (str): The device ('cpu' or 'cuda') to perform operations on. + + Returns: + torch.Tensor: Cold labels corresponding to the cold neurons. + """ + + # Convert labels into binary labels (labels > 0) + # print("Convert labels into binary labels (labels > 0):") + binary_labels = (labels > 0).astype(int) + + # print("Binary Labels:", binary_labels) + # print("Binary Labels shape ", binary_labels.shape) + + # Count activations for each neuron + activation_counts = binary_labels.sum(axis=0) + activation_counts = torch.tensor(activation_counts, device=device) + + # Sort neurons by activation counts in descending order + sorted_counts, sorted_indices = torch.sort(activation_counts, descending=True) + + # Compute cumulative sum of sorted activation counts + cumulative_counts = torch.cumsum(sorted_counts, dim=0) + + # Compute total activations + total_activations = cumulative_counts[-1] + + # Find the index where cumulative sum reaches 80% of total activations + threshold = 0.8 * total_activations + hot_neuron_threshold_idx = torch.nonzero(cumulative_counts >= threshold)[0].item() + + # Indices of hot and cold neurons + hot_neuron_indices = sorted_indices[:hot_neuron_threshold_idx + 1] + cold_neuron_indices = sorted_indices[hot_neuron_threshold_idx + 1:] + + # Move indices back to CPU + cold_neuron_indices = cold_neuron_indices.cpu() + hot_neuron_indices = hot_neuron_indices.cpu() + + # Filter cold and pruned (zero-activation) neurons + pruned_cold_neurons = torch.nonzero(activation_counts == 0).cpu().flatten() + filtered_cold_neurons = cold_neuron_indices[~cold_neuron_indices.unsqueeze(1).eq(pruned_cold_neurons).any(dim=1)] + + logging.info(f"Hot neurons count: {len(hot_neuron_indices)}") + logging.info(f"Cold neurons count: {len(cold_neuron_indices)}") + logging.info(f"Non-zero neurons count: {len(filtered_cold_neurons)}") + logging.info(f"Zero neurons count: {len(pruned_cold_neurons)}") + + # Extract the cold neurons from the binary labels + cold_labels = labels[:, filtered_cold_neurons] + + return (hot_neuron_indices, filtered_cold_neurons, pruned_cold_neurons), cold_labels + + +def create_dataset(query, labels, args): + total = len(query) + num_train = int(0.95 * total) + num_test = int(0.05 * total) + + logging.info(f"Query shape: {query.shape}, Label shape: {labels.shape}") + logging.info(f"# training data: {num_train}, # test data: {num_test}") + + train_ds = BasicDataset(query, labels, num_train, True) + test_ds = BasicDataset(query, labels, num_test, False) + + train_dataloader = DataLoader( + train_ds, args.batch_size, shuffle=True, num_workers=0 + ) + test_dataloader = DataLoader(test_ds, args.batch_size, shuffle=False, num_workers=0) + return train_dataloader, test_dataloader + +def create_log_path(model: str, layer_idx: int) -> str: + """ + Creates the necessary directories and generates a log file path. + + Parameters: + - model (str): The name of the model. + - layer_idx (int): The index of the MLP layer. + + Returns: + - str: The full path to the log file. + """ + # Define the base log directory + base_log_dir = 'log' + + # Get today's date in YYYY-MM-DD format + today_str = datetime.now().strftime('%Y-%m-%d') + + # Path to today's log directory + today_log_dir = os.path.join(base_log_dir, today_str) + + # Create the directories if they don't exist + os.makedirs(today_log_dir, exist_ok=True) + + # Get the current time in HH-MM-SS format for the filename + current_time_str = datetime.now().strftime('%H-%M-%S') + + # Create the log filename + log_name = f'mlp_{model}_{layer_idx}_{current_time_str}.log' + + # Combine the directory and filename to get the full path + log_path = os.path.join(today_log_dir, log_name) + + return log_path + +def generate_label(y): + # positive + one_hot = (y > 0).to(y.dtype) + return one_hot + +def evaluate(model, device, loader, args, smalltest=False): + model.eval() + + eval = { + "Loss": [], + "Loss Weight": [], + "Recall": [], + "Classifier Sparsity": [], + "True Sparsity": [], + } + + with torch.no_grad(): + for batch_idx, batch in enumerate((loader)): + x, y = batch + y = y.float().to(device) + + y = generate_label(y) + + logits = model(x.to(device)) + probs = logits.sigmoid() + preds = probs >= 0.5 + + dif = y.int() - preds.int() + miss = dif > 0.0 # classifier didn't activated target neuron + + weight = (y.sum() / y.numel()) + 0.005 + loss_weight = y * (1 - weight) + weight + eval["Loss Weight"] += [weight.item()] + eval["Loss"] += [ + torch.nn.functional.binary_cross_entropy( + probs, y, weight=loss_weight + ).item() + ] + + eval["Recall"] += [ + ((y.sum(dim=1).float() - miss.sum(dim=1).float()).mean().item()) + ] + eval["True Sparsity"] += [y.sum(dim=1).float().mean().item()] + eval["Classifier Sparsity"] += [preds.sum(dim=1).float().mean().item()] + + if batch_idx >= 100 and smalltest: + break + + for k, v in eval.items(): + eval[k] = np.array(v).mean() + + eval["Recall"] = eval["Recall"] / eval["True Sparsity"] + return eval + +def generate_experiment_id(): + """ + Generates a unique experiment identifier based on the current time in HHMMSS format. + + Returns: + - str: A string representing the current time as HHMMSS. + """ + return datetime.now().strftime('%H%M%S') + +def create_date_directory(ckpt_dir: str, model_name: str) -> Path: + # Step 2: Define the base directory and create date-based directory + base_dir = f"{ckpt_dir}/opt-{model_name}-sparse-predictor" + base_dir_path = Path(base_dir) + base_dir_path.mkdir(parents=True, exist_ok=True) # Ensure base directory exists + + today_str = datetime.now().strftime('%Y-%m-%d') + date_dir = Path(base_dir) / today_str + date_dir.mkdir(parents=True, exist_ok=True) + return date_dir + +def generate_log_filename(experiment_id: str, model_name: str, layer_idx: int, date_dir: str) -> str: + log_filename = f"{experiment_id}_{model_name}_mlp_layer{layer_idx}.log" + log_file_path = date_dir / log_filename + return log_file_path + +def generate_model_filename(experiment_id: str, model_name: str, layer_idx: int, eval_result: dict) -> str: + recall = eval_result.get('Recall', 0) + precision = eval_result.get('Precision', 0) + sparsity = eval_result.get('Classifier Sparsity', 0) + return f"{experiment_id}_{model_name}_mlp_layer{layer_idx}_-{recall:.4f}-{precision:.4f}-{sparsity:.0f}.pt" + +def setup_logging(log_file= Path): + """ + Sets up the logging configuration. + + Parameters: + - log_file (str): Path to the log file. + """ + logging.basicConfig( + level=logging.INFO, # Set the logging level + format='%(asctime)s - %(levelname)s - %(message)s', # Log message format + datefmt='%Y-%m-%d %H:%M:%S', # Date format + handlers=[ + logging.FileHandler(log_file, mode='w'), # Write logs to a file + logging.StreamHandler() # Also output logs to the console + ] + ) + + + +def plot_neuron_activation_frequency( + labels: torch.Tensor, + threshold: float = 0.0, + max_xticks: int = 20, + figsize: tuple = (12, 6), + title: str = 'Neuron Activation Frequency in MLP Layer', + color: str = 'skyblue' +): + """ + Plots the activation frequency of neurons in a Multi-Layer Perceptron (MLP) layer. + + Args: + labels (torch.Tensor): A 2D tensor of shape (num_samples, num_neurons) containing label data. + threshold (float, optional): Threshold to convert labels into binary. Labels > threshold are set to 1. Defaults to 0.0. + max_xticks (int, optional): Maximum number of x-axis ticks to display. Defaults to 20. + figsize (tuple, optional): Size of the matplotlib figure. Defaults to (12, 6). + title (str, optional): Title of the plot. Defaults to 'Neuron Activation Frequency in MLP Layer'. + color (str or list, optional): Color of the bars in the plot. Defaults to 'skyblue'. + + Raises: + ValueError: If `labels` is not a 2D tensor. + """ + + binary_labels = (labels > threshold).astype(int) + activation_counts = binary_labels.sum(axis=0) + + # Get sorted indices in descending order + sorted_indices = np.argsort(-activation_counts) + + # Sort the activation counts accordingly + sorted_counts = activation_counts[sorted_indices] + + + plt.figure(figsize=figsize) + plt.bar(range(len(sorted_counts)), sorted_counts, color=color) + plt.xlabel('Neurons (sorted by activation frequency)') + plt.ylabel('Activation Frequency') + plt.title(title) + plt.xticks(ticks=range(0, len(sorted_counts), max(1, len(sorted_counts)//20)), + labels=sorted_indices[::max(1, len(sorted_counts)//20)], + rotation=90) # Adjust ticks for better readability + plt.tight_layout() + plt.show() + +# Argument parser for jupyter notebook +from unittest.mock import patch +import argparse +import sys + +def arg_parser_notebook(test_args=[None]): + parser = argparse.ArgumentParser(description="PyTorch OPT Full Model") + parser.add_argument("--model", type=str, default="6_7b", choices=MODEL_CHOICES) + parser.add_argument("--model_name", type=str, default="opt", help="model name") + parser.add_argument("--dataset", type=str, default="c4", choices=DATA_CHOICES) + parser.add_argument("--L", type=int, default=0, help="which layer") + parser.add_argument("--D", type=int, default=1024, help="low rank dimension") + parser.add_argument("--batch_size", type=int, default=64, help="batch size") # a smaller batch size results in better precision and recall of the model + parser.add_argument("--epochs", type=int, default=20, help="epochs") + parser.add_argument("--lr", type=float, default=0.0001, help="learning rate") + parser.add_argument("--ckpt_dir", type=str, default="/home/grads/s//nvme/HybridTensor/checkpoint", help="checkpoint directory") # add a argument for checkpoint dir + parser.add_argument("--gpu", type=int, default=0, help="which gpu to use") # add a argument for which gpu to use + parser.add_argument("--alpha", type=float, default=1) + parser.add_argument("--gamma", type=float, default=1) + parser.add_argument("--threshold", type=float, default=0.5) + parser.add_argument("--loss_fn", type=str, default="focal") + parser.add_argument("--data_augmentation", type=bool, default=False) # use cold neurons for routing + + with patch.object(sys, 'argv', test_args): + parser = arg_parser_notebook() + args = parser.parse_args() + args_dict = vars(args) # Convert args to a dictionary for display + + return args_dict \ No newline at end of file diff --git a/HybridTensor/triton/__init__.py b/HybridTensor/triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/HybridTensor/triton/__pycache__/__init__.cpython-310.pyc b/HybridTensor/triton/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9846dc1dfe5ca30d7186e524fef033fe598f2da3 Binary files /dev/null and b/HybridTensor/triton/__pycache__/__init__.cpython-310.pyc differ diff --git a/HybridTensor/triton/__pycache__/__init__.cpython-39.pyc b/HybridTensor/triton/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4052a7219187982857fa48eede194c91ed7eb54a Binary files /dev/null and b/HybridTensor/triton/__pycache__/__init__.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/activations.cpython-39.pyc b/HybridTensor/triton/__pycache__/activations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d27b51b9fed908a45d24169845d3052e93a09a66 Binary files /dev/null and b/HybridTensor/triton/__pycache__/activations.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/attention_proj_sparse.cpython-310.pyc b/HybridTensor/triton/__pycache__/attention_proj_sparse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cc85f85e7cda6f8f2c1c0393cfdabe5559a38be Binary files /dev/null and b/HybridTensor/triton/__pycache__/attention_proj_sparse.cpython-310.pyc differ diff --git a/HybridTensor/triton/__pycache__/attention_proj_sparse.cpython-39.pyc b/HybridTensor/triton/__pycache__/attention_proj_sparse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c114e9a55a384410e1fc9e68b8331cd6cf38b85 Binary files /dev/null and b/HybridTensor/triton/__pycache__/attention_proj_sparse.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/attn_decode.cpython-39.pyc b/HybridTensor/triton/__pycache__/attn_decode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f7e053a8f8be0aa62458e190f0ff84d07ba3dc3 Binary files /dev/null and b/HybridTensor/triton/__pycache__/attn_decode.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/attn_interface.cpython-39.pyc b/HybridTensor/triton/__pycache__/attn_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b452393cc233f7b9bbad2bcd19806819b93a04a2 Binary files /dev/null and b/HybridTensor/triton/__pycache__/attn_interface.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/autotune_configs.cpython-39.pyc b/HybridTensor/triton/__pycache__/autotune_configs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a7b6c7065ed919b250c600cd4f268318d837acb Binary files /dev/null and b/HybridTensor/triton/__pycache__/autotune_configs.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/fused_attn.cpython-39.pyc b/HybridTensor/triton/__pycache__/fused_attn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..109e339a1adaf75b2b4f379771764e17f4901fa0 Binary files /dev/null and b/HybridTensor/triton/__pycache__/fused_attn.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/fused_attn_decode.cpython-39.pyc b/HybridTensor/triton/__pycache__/fused_attn_decode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd4d77ec24bf479b9d28215f119cd474af0126e0 Binary files /dev/null and b/HybridTensor/triton/__pycache__/fused_attn_decode.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/gather_gemm_col.cpython-310.pyc b/HybridTensor/triton/__pycache__/gather_gemm_col.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee08dbf8b590bacbf7d104efae98c67b5c942f19 Binary files /dev/null and b/HybridTensor/triton/__pycache__/gather_gemm_col.cpython-310.pyc differ diff --git a/HybridTensor/triton/__pycache__/gather_gemm_col.cpython-39.pyc b/HybridTensor/triton/__pycache__/gather_gemm_col.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..823a76a7d305d6143d183cbb0d78b0be13669691 Binary files /dev/null and b/HybridTensor/triton/__pycache__/gather_gemm_col.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/gather_gemm_row.cpython-310.pyc b/HybridTensor/triton/__pycache__/gather_gemm_row.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82cbe03c0f2567bc49241dff93ff1c491b39bfb9 Binary files /dev/null and b/HybridTensor/triton/__pycache__/gather_gemm_row.cpython-310.pyc differ diff --git a/HybridTensor/triton/__pycache__/gather_gemm_row.cpython-39.pyc b/HybridTensor/triton/__pycache__/gather_gemm_row.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9774861f8ce8054bfddf72a5fba864921e7229e4 Binary files /dev/null and b/HybridTensor/triton/__pycache__/gather_gemm_row.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/gemm.cpython-39.pyc b/HybridTensor/triton/__pycache__/gemm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39cf44f8da2a402068b31390774fe4ce5d2c1050 Binary files /dev/null and b/HybridTensor/triton/__pycache__/gemm.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/heurisitcs.cpython-39.pyc b/HybridTensor/triton/__pycache__/heurisitcs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5b506432a9642654284d9bfa142a6d1998baeeb Binary files /dev/null and b/HybridTensor/triton/__pycache__/heurisitcs.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/profile_gather_gemm.cpython-39.pyc b/HybridTensor/triton/__pycache__/profile_gather_gemm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1d65a9353eb35f1b3a1808248001e1351c9818b Binary files /dev/null and b/HybridTensor/triton/__pycache__/profile_gather_gemm.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/select_attn.cpython-310.pyc b/HybridTensor/triton/__pycache__/select_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..348eb9e7ecc7bb5202aa1f39028cc733fc9a28d7 Binary files /dev/null and b/HybridTensor/triton/__pycache__/select_attn.cpython-310.pyc differ diff --git a/HybridTensor/triton/__pycache__/select_attn.cpython-39.pyc b/HybridTensor/triton/__pycache__/select_attn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64e028327cd0b7081f2897d179f2a45bcbc790c3 Binary files /dev/null and b/HybridTensor/triton/__pycache__/select_attn.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/select_attn_64b_kernel.cpython-310.pyc b/HybridTensor/triton/__pycache__/select_attn_64b_kernel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..955b15836f28ee191c0425e0a97bba90d542d66b Binary files /dev/null and b/HybridTensor/triton/__pycache__/select_attn_64b_kernel.cpython-310.pyc differ diff --git a/HybridTensor/triton/__pycache__/select_attn_64b_kernel.cpython-39.pyc b/HybridTensor/triton/__pycache__/select_attn_64b_kernel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..471cceecf795551bee930de75c5cb63411238e01 Binary files /dev/null and b/HybridTensor/triton/__pycache__/select_attn_64b_kernel.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/select_attn_v1.cpython-310.pyc b/HybridTensor/triton/__pycache__/select_attn_v1.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d86f3b00d55cade091a4a2d8af4f39dc45b4ad53 Binary files /dev/null and b/HybridTensor/triton/__pycache__/select_attn_v1.cpython-310.pyc differ diff --git a/HybridTensor/triton/__pycache__/select_attn_v1.cpython-39.pyc b/HybridTensor/triton/__pycache__/select_attn_v1.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ad1d0a05cee88573021ef82a1b3f72229884750 Binary files /dev/null and b/HybridTensor/triton/__pycache__/select_attn_v1.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/select_gqa.cpython-39.pyc b/HybridTensor/triton/__pycache__/select_gqa.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..348c833052c7d887cc8f1296a1feb02f10e4e977 Binary files /dev/null and b/HybridTensor/triton/__pycache__/select_gqa.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/triton_flashattn_decode.cpython-39.pyc b/HybridTensor/triton/__pycache__/triton_flashattn_decode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da4976fb5e457f27f6a6cf88e4dee57c0eaef55f Binary files /dev/null and b/HybridTensor/triton/__pycache__/triton_flashattn_decode.cpython-39.pyc differ diff --git a/HybridTensor/triton/__pycache__/triton_utils.cpython-310.pyc b/HybridTensor/triton/__pycache__/triton_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..856b43ef42fc0771a939859d51aff6624be54229 Binary files /dev/null and b/HybridTensor/triton/__pycache__/triton_utils.cpython-310.pyc differ diff --git a/HybridTensor/triton/__pycache__/triton_utils.cpython-39.pyc b/HybridTensor/triton/__pycache__/triton_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eb45287177b7e386b227efc052cc00bf8ea0844 Binary files /dev/null and b/HybridTensor/triton/__pycache__/triton_utils.cpython-39.pyc differ diff --git a/HybridTensor/triton/activations.py b/HybridTensor/triton/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..4f767b4e0c7069e9694c95225c428b9c8ac7b1dc --- /dev/null +++ b/HybridTensor/triton/activations.py @@ -0,0 +1,169 @@ +import math +from enum import Enum +from typing import Optional + + +import torch +import triton +import triton.language as tl + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. +@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + +# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +_sqrt2pi = math.sqrt(2.0 / math.pi) +_sqrt1_2 = math.sqrt(1.0 / 2) +_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) + + +class Activation(str, Enum): + SquaredReLU = "squared_relu" + GeLU = "gelu" + GeLUApprox = "gelu_approx" + LeakyReLU = "leaky_relu" + ReLU = "relu" + + +def get_triton_activation_kernel(activation: Optional[Activation]): + return ( + { + Activation.ReLU: relu, + Activation.LeakyReLU: leaky_relu, + Activation.GeLU: gelu, + Activation.GeLUApprox: gelu_approx, + Activation.SquaredReLU: squared_relu, + }[activation] + if activation + else None + ) + + +def get_triton_activation_bwd_kernel(activation: Optional[Activation]): + return ( + { + Activation.ReLU: relu_grad, + Activation.LeakyReLU: leaky_relu_grad, + Activation.GeLU: gelu_grad, + Activation.GeLUApprox: gelu_approx_grad, + Activation.SquaredReLU: squared_relu_grad, + }[activation] + if activation + else None + ) + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def cosh(x): + exp_x = tl.exp(x) + return (exp_x + 1.0 / exp_x) * 0.5 + + +# a Triton implementation of the most used activations +# See for instance http://arxiv.org/abs/1606.08415 for an overview + +# ReLU +@triton.jit +def relu(x): + """ + ReLU_ activation function + + .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html + """ + zero = 0.0 + return tl.where(x >= 0, x, zero.to(x.dtype)) + + +@triton.jit +def relu_grad(x): + # ReLU is different from other activations + # in that it does not require the input to retrospectively compute its gradient + # here the input is the downstream gradient, and we return the upstream gradient directly + zero = 0.0 + one = 1.0 + return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype)) + + +@triton.jit +def squared_relu(x): + """ + Squared ReLU activation, as proposed in the Primer_ paper. + + .. _Primer: https://arxiv.org/abs/2109.08668 + """ + x_ = relu(x) + return (x_ * x_).to(x.dtype) + + +@triton.jit +def squared_relu_grad(x): + return tl.where(x >= 0, 2.0 * x, 0.0) + + +# Leaky ReLU +@triton.jit +def leaky_relu(x): + """ + LeakyReLU_ activation + + .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html + """ + scale = 0.01 + 0.0 + scale = scale.to(x.dtype) + return tl.where(x >= 0, x, scale * x) + + +@triton.jit +def leaky_relu_grad(x): + min_grad = 0.01 + max_grad = 1 + + min_grad = min_grad.to(x.dtype) + max_grad = max_grad.to(x.dtype) + + return tl.where(x >= 0, max_grad, min_grad) + + +@triton.jit +def gelu(x): + """Gaussian Error Linear Unit (GELU)""" + return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) + + +@triton.jit +def gelu_grad(x): + cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) + pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization + return cdf + x * pdf + +@triton.jit +def gelu_approx(x): + """ + GeLU_ activation - Gaussian error linear unit, with tanh approximation + + .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf + """ + return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x))) + + +@triton.jit +def gelu_approx_grad(x): + # CREDITS: Fast implementation proposed in + # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 + tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + return 0.5 * x * ( + (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) + ) + 0.5 * (1 + tanh_out) diff --git a/HybridTensor/triton/attn_interface.py b/HybridTensor/triton/attn_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..c0778eead7dc6a006ead93baf86ef3377b209d2d --- /dev/null +++ b/HybridTensor/triton/attn_interface.py @@ -0,0 +1,193 @@ +import torch + +from HybridTensor.triton.select_gqa import select_gqa +from HybridTensor.triton.select_attn_v1 import select_attn +from HybridTensor.triton.triton_flashattn_decode import attention_decode_forward_triton_impl +from typing import Optional, Sequence, Tuple, Union + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + +def flash_attn_with_kvcache_triton( + q, + k_cache, + v_cache, + k=None, + v=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + alibi_slopes=None, + num_splits=0, + return_softmax_lse=False, + # head index + batch_head_idx=None, +): + """ + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. This is useful for incremental decoding: you can pass in the cached keys/values from + the previous step, and update them with the new keys/values from the current step, and do + attention with the updated cache, all in 1 kernel. + + If you pass in k / v, you must make sure that the cache is large enough to hold the new values. + For example, the KV cache could be pre-allocated with the max sequence length, and you can use + cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. + + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at + indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + + See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. + + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) + page_block_size must be a multiple of 256. + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. + block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. + + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" + assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + batch_size, q_seqlen, nheads, head_dim = q.shape + batch_size, seqlen_cache, nheads_k, head_dim = k_cache.shape + + # the kvcache has been updated, the cache_seqlens = cache_seqlens + q_seqlen + if cache_seqlens is not None and isinstance(cache_seqlens, int): + cache_seqlens = torch.full( + (k_cache.shape[0],), cache_seqlens + q_seqlen, dtype=torch.int32, device=k_cache.device + ) + cache_seqlens = maybe_contiguous(cache_seqlens) + cache_batch_idx = maybe_contiguous(cache_batch_idx) + block_table = maybe_contiguous(block_table) + + if batch_head_idx is None: + # dense attention + out, _ = attention_decode_forward_triton_impl( + q=q, + k = k_cache, + v = v_cache, + sm_scale=softmax_scale, + causal=causal, + alibi_slopes=alibi_slopes, + layout="bshd", + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + new_kv=False, + k_new=None, + v_new=None, + ) + else: + if nheads != nheads_k: + # select GQA + out, _ = select_gqa( + q=q, + k = k_cache, + v = v_cache, + sm_scale=softmax_scale, + causal=causal, + alibi_slopes=alibi_slopes, + layout="bshd", + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + new_kv=False, # make sure to update the kv cache outside + k_new=None, + v_new=None, + batch_group_index=batch_head_idx, + ) + + else: + # select MHA + # TODO: select MHA requires addtional operations to have the group dimensions before using the kernel, + # unsqueeze(2) and remove the group dim after the kernel + q = q.unsqueeze(2) + k_cache = k_cache.unsqueeze(2) + v_cache = v_cache.unsqueeze(2) + + out = select_attn( + q, + k_cache, + v_cache, + softmax_scale, + batch_head_idx, + cache_seqlens, + ) + out = out.view(batch_size, q_seqlen, nheads, head_dim) + + return out diff --git a/HybridTensor/triton/cg_safe/gather_gemm_col_cg.py b/HybridTensor/triton/cg_safe/gather_gemm_col_cg.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2a97988255b33d209db56256e68aefdae9de3b --- /dev/null +++ b/HybridTensor/triton/cg_safe/gather_gemm_col_cg.py @@ -0,0 +1,216 @@ +import torch +import triton +import triton.language as tl +from HybridTensor.triton.triton_utils import get_autotune_config, benchmark_fwd, torch_sel_gemm_col +from HybridTensor.triton.activations import leaky_relu, relu +from HybridTensor.utils.utils import sparse_index, arg_parser +from HybridTensor.utils.profiling import cuda_profiler + +@triton.autotune( + configs=get_autotune_config(), + # key=['M', 'gather_N', 'K'], + key=['M', 'K'], +) +@triton.jit +def matmul_gather_kernel_col( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + index_vec_ptr, # Pointer to the index vector for selected columns in B + bias_ptr, + # Matrix dimensions + M, N, K, index_size, + # Strides + stride_am, stride_ak, # + stride_bk, stride_bn, # B is column-major + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr, # + USE_BIAS: tl.constexpr +): + """Kernel for computing the matmul C = A x B_select, + where B_select contains columns of B selected using index_vec. + A has shape (M, K) in row-major and B has shape (K, N) in column-major. + index_vec is used to select the columns of B for the matmul. + """ + # ----------------------------------------------------------- + # Map program ids to block of C + pid = tl.program_id(axis=0) + # in_size = tl.load(index_size) + in_size = index_size + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(in_size, BLOCK_SIZE_N) # Adjusted for the gathered columns + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # Exit early if this block's column start exceeds active columns. + if pid_n * BLOCK_SIZE_N >= in_size: + return + + # ----------------------------------------------------------- + # Create pointers for the first blocks of A and B + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % in_size + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Pointer arithmetic for A (row-major, K contiguous) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + + # Gather selected column indices from B using the index_vec + gather_idx = tl.load(index_vec_ptr + offs_bn) # Gather indices from the index_vec + b_ptrs = b_ptr + (gather_idx[None, :] * stride_bk + offs_k[:, None] * stride_bn) + + if USE_BIAS: + bias = tl.load(bias_ptr + gather_idx, mask=offs_bn < in_size, other=0.0) + + # ----------------------------------------------------------- + # Initialize the accumulator for C in FP32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load blocks of A and B with masking for out-of-bounds K + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # Perform dot product and accumulate + accumulator = tl.dot(a, b, accumulator) + # Advance pointers for next iteration over K + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bn # Move along K dimension + + if USE_BIAS: + accumulator += bias[None, :] + + # Optional activation + accumulator = relu(accumulator) + # if ACTIVATION == "leaky_relu": + # accumulator = leaky_relu(accumulator) + # elif ACTIVATION == "relu": + # accumulator = relu(accumulator) + + # Convert the accumulator back to FP16 + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the result into matrix C + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < in_size) + tl.store(c_ptrs, c, mask=c_mask) + +def gather_matmul_col(a, b, index_vec, index_size, bias=None, activations="", out=None): + # a must be contiguous; b is expected in column-major format. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + if bias is not None: + assert bias.size(0) == b.shape[0], "Incompatible bias dimensions" + use_bias = bias is not None + + M, K = a.shape + N, K = b.shape + # Static gather_N equals the length of index_vec (total number of columns) + # gather_N = index_vec.numel() + + + if out is None: + out = torch.empty((M, N), device=a.device, dtype=torch.float16) + + # Use a static grid based on the total number of columns. + # grid = lambda META: (tl.cdiv(M, META['BLOCK_SIZE_M']) * tl.cdiv(gather_N, META['BLOCK_SIZE_N']), ) + # Compute grid dimensions on the host. + grid = lambda META: ( + ((M + META['BLOCK_SIZE_M'] - 1) // META['BLOCK_SIZE_M']) * + ((N + META['BLOCK_SIZE_N'] - 1) // META['BLOCK_SIZE_N']), + ) + + matmul_gather_kernel_col[grid]( + a, b, out, index_vec, bias, + M, N, K, index_size, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + out.stride(0), out.stride(1), + ACTIVATION=activations, + USE_BIAS=use_bias + ) + return out + +def torch_sel_gemm_col(a, b, index_vec, index_size, bias=None, out=None): + """ + Computes the selective GEMM in torch to match the Triton kernel. + Only the first `index_size` columns (as given by index_vec) are computed; + the remaining columns in the output (of shape determined by index_vec) are zeros. + Assumes ReLU activation. + """ + # Total static number of columns (gather_N) + gather_N = index_vec.numel() + M = a.shape[0] + if out is None: + out = torch.zeros((M, gather_N), device=a.device, dtype=a.dtype) + + # Compute only for the active indices (first `index_size` elements of index_vec) + selected_indices = index_vec[:index_size] + active_result = torch.matmul(a, b[:, selected_indices]) + if bias is not None: + active_result = active_result + bias[selected_indices] + + active_result = torch.nn.functional.relu(active_result) + + # Place the computed results in the first index_size columns; the rest remain zero. + out[:, :index_size] = active_result + return out + + +if __name__ == '__main__': + args = arg_parser() + torch.manual_seed(0) + A = torch.randn(args.batch_size, args.in_features, device='cuda', dtype=torch.float16) + B = torch.randn(args.in_features, args.hidden_features, device='cuda', dtype=torch.float16) + # Ensure B is column-major. + B_col_major = B.t().contiguous() + + # Create a static index vector (length equals total columns in B). + index_vec = torch.empty((args.hidden_features,), device='cuda', dtype=torch.int32) + active_indices = sparse_index(args.index_size, args.hidden_features)[0] + index_vec[:args.index_size] = active_indices + if args.index_size < args.hidden_features: + index_vec[args.index_size:] = 0 # Fill the rest with dummy values. + + if args.bias: + print("Using bias") + bias = torch.randn(args.hidden_features, device='cuda', dtype=torch.float16) + else: + print("No bias") + bias = None + + print(f"args: {args}") + print(f"Index size: {args.index_size}, Activated {args.index_size/(args.in_features * 4)*100}% neurons") + index_size = torch.tensor(args.index_size, device='cuda', dtype=torch.int32) + + # Benchmark standard matmul. + out_b, cublass_time = cuda_profiler(torch.matmul, A, B) + print(f"Cublass time: {cublass_time:.2f} ms") + # Benchmark the CUDA Graph safe selective GEMM. + out_a_gather, gather_gemm_time = cuda_profiler(gather_matmul_col, A, B_col_major, index_vec, index_size, bias=bias) + print(f"Gather kernel time: {gather_gemm_time:.2f} ms") + # Benchmark the torch-based selective GEMM. + out_a_gather_torch, torch_sel_gemm_time = cuda_profiler(torch_sel_gemm_col, A, B, index_vec, index_size, bias=bias) + print(f"Torch gather time: {torch_sel_gemm_time:.2f} ms") + + speedup = cublass_time / gather_gemm_time + + if args.check_results: + print("Checking results") + print("Kernel gather output: ", out_a_gather) + print("Torch gather output: ", out_a_gather_torch) + if torch.allclose(out_a_gather, out_a_gather_torch, atol=0.5): + print("Results match ✅") + else: + print("Results do not match ❌") + print("Max difference: ", torch.max(torch.abs(out_a_gather - out_a_gather_torch))) + + print(f"Speedup: {speedup:.2f}") diff --git a/HybridTensor/triton/cg_safe/gather_gemm_row_cg.py b/HybridTensor/triton/cg_safe/gather_gemm_row_cg.py new file mode 100644 index 0000000000000000000000000000000000000000..0e47f6bf9f8797e6589d1758106112b920fc0cfa --- /dev/null +++ b/HybridTensor/triton/cg_safe/gather_gemm_row_cg.py @@ -0,0 +1,302 @@ +import torch + +import triton +import triton.language as tl +from HybridTensor.triton.triton_utils import get_autotune_config, benchmark_fwd, torch_sel_gemm_row +from HybridTensor.triton.activations import leaky_relu, relu + +from HybridTensor.utils.utils import sparse_index, arg_parser +from HybridTensor.utils.profiling import cuda_profiler + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=get_autotune_config(), + # key=['M', 'N', 'index_size'], + key=['M', 'N'], +) + +@triton.jit +def matmul_gather_kernel_row( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, index_vec_ptr, bias_ptr, + # Matrix dimensions + M, N, K, index_size, + # The stride variables + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + USE_BIAS: tl.constexpr +): + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + pid = tl.program_id(axis=0) + # in_size = index_size + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Compute block indices + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + # Initialize accumulator + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + if USE_BIAS: + # Load the bias vector + bias = tl.load(bias_ptr + offs_bn, mask=offs_bn < N, other=0.0) + + # Loop over the K dimension in blocks + for k in range(0, tl.cdiv(index_size, BLOCK_SIZE_K)): + offs_k = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + # Load indices for the current block + index_vec = tl.load(index_vec_ptr + offs_k, mask=offs_k < index_size, other=0) + + # Compute pointers to A and B + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (index_vec[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # Load blocks of A and B + a = tl.load(a_ptrs, mask=(offs_am[:, None] < M) & (offs_k[None, :] < index_size), other=0.0) + b = tl.load(b_ptrs, mask=(index_vec[:, None] < K) & (offs_bn[None, :] < N), other=0.0) + + # Accumulate partial results + accumulator += tl.dot(a, b) + + if USE_BIAS: + # Add the bias + accumulator += bias[None, :] # Broadcast bias over rows + + # Cast accumulator to desired type and store the result + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def gather_matmul_row(a, b, index_vec, index_size, bias = None, out = None): + # Check constraints. + # a and b are expected to be in row major format. + # assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + + M, _ = a.shape + K, N = b.shape + # if index_size is torch.tensor, convert to int + index_size = index_size.item() if torch.is_tensor(index_size) else index_size + # index_size = index_size.item() # potentially not CG safe + # index_size = index_vec.shape[0] + # Allocates output. + if out is None: + out = torch.empty((M, N), device=a.device, dtype=torch.float16) + if bias is not None: + assert bias.size(0) == N, "Incompatible bias dimensions" + + use_bias = bias is not None + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + + matmul_gather_kernel_row[grid]( + a, b, out, index_vec, bias, # + M, N, K, index_size, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + out.stride(0), out.stride(1), # + USE_BIAS = use_bias, + ) + return out + + +# This implementation uses a MAX_SIZE parameter to unroll the loop over the K dimension. +# @triton.jit +# def matmul_gather_kernel_row( +# # Pointers to matrices +# a_ptr, b_ptr, c_ptr, index_vec_ptr, bias_ptr, +# # Matrix dimensions +# M, N, K, index_size, # index_size is now a pointer to a device tensor +# # The stride variables +# stride_am, stride_ak, +# stride_bk, stride_bn, +# stride_cm, stride_cn, +# # Meta-parameters +# BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +# GROUP_SIZE_M: tl.constexpr, +# USE_BIAS: tl.constexpr, +# MAX_SIZE: tl.constexpr # Maximum possible index_size (compile-time constant) +# ): +# # Map program ids to the block of C to compute. +# pid = tl.program_id(axis=0) +# # Load the runtime value; index_size must be a device pointer. +# in_size = tl.load(index_size) +# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) +# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) +# num_pid_in_group = GROUP_SIZE_M * num_pid_n +# group_id = pid // num_pid_in_group +# first_pid_m = group_id * GROUP_SIZE_M +# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) +# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) +# pid_n = (pid % num_pid_in_group) // group_size_m + +# # Compute block indices +# offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) +# offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + +# # Initialize accumulator +# accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + +# if USE_BIAS: +# bias = tl.load(bias_ptr + offs_bn, mask=offs_bn < N, other=0.0) + +# # Unroll loop over the K dimension using MAX_SIZE (a compile-time constant) +# # max_k = (MAX_SIZE + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K +# max_k = tl.cdiv(MAX_SIZE, BLOCK_SIZE_K) +# for k in range(0, max_k): +# offs_k = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) +# # Mask to ensure we only process valid entries +# # valid_mask = offs_k < in_size +# valid_mask = offs_k < MAX_SIZE + +# # Load indices for the current block with masking +# index_vec = tl.load(index_vec_ptr + offs_k, mask=valid_mask, other=0) + +# # Compute pointers to A and B +# a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) +# b_ptrs = b_ptr + (index_vec[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + +# # Load blocks of A and B, using in_size to mask out-of-bound accesses +# # a = tl.load(a_ptrs, mask=(offs_am[:, None] < M) & (offs_k[None, :] < in_size), other=0.0) +# a = tl.load(a_ptrs, mask=(offs_am[:, None] < M) & (offs_k[None, :] < MAX_SIZE), other=0.0) +# b = tl.load(b_ptrs, mask=(index_vec[:, None] < K) & (offs_bn[None, :] < N), other=0.0) + +# # Accumulate partial results +# accumulator += tl.dot(a, b) + +# if USE_BIAS: +# accumulator += bias[None, :] + +# # Cast accumulator and store the results +# c = accumulator.to(tl.float16) +# offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) +# offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) +# c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] +# c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) +# tl.store(c_ptrs, c, mask=c_mask) + +# def gather_matmul_row(a, b, index_vec, index_size, bias = None, MAX_SIZE=None, out = None): +# # Check constraints. +# # a and b are expected to be in row major format. +# # assert a.shape[1] == b.shape[0], "Incompatible dimensions" +# assert a.is_contiguous(), "Matrix A must be contiguous" +# assert MAX_SIZE is not None, "MAX_SIZE () must be provided" + +# M, _ = a.shape +# K, N = b.shape +# # index_size = index_vec.shape[0] +# # Allocates output. +# if out is None: +# out = torch.empty((M, N), device=a.device, dtype=torch.float16) +# if bias is not None: +# assert bias.size(0) == N, "Incompatible bias dimensions" + +# use_bias = bias is not None +# # 1D launch kernel where each block gets its own program. +# grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + +# # Ensure index_size is a device tensor (0-D tensor) +# if not torch.is_tensor(index_size): +# index_size = torch.tensor(index_size, device=a.device, dtype=torch.int32) + +# matmul_gather_kernel_row[grid]( +# a, b, out, index_vec, bias, # +# M, N, K, index_size, # +# a.stride(0), a.stride(1), # +# b.stride(0), b.stride(1), # +# out.stride(0), out.stride(1), # +# USE_BIAS = use_bias, +# MAX_SIZE = MAX_SIZE +# ) +# return out + + + +if __name__ == '__main__': + args = arg_parser() + args.hidden_features = args.in_features * 4 + torch.manual_seed(0) + + # modeling the ffn2 kernel + A = torch.randn(args.batch_size, args.hidden_features, device='cuda', dtype=torch.float16) + B = torch.randn(args.hidden_features, args.in_features, device='cuda', dtype=torch.float16) + + # Construct index_vec: total_neurons is hidden_features. + total_neurons = args.hidden_features + index_vec = torch.empty((total_neurons,), device='cuda', dtype=torch.int32) + active_indices = sparse_index(args.index_size, total_neurons)[0] + index_vec[:args.index_size] = active_indices + if args.index_size < total_neurons: + index_vec[args.index_size:] = 0 + index_size = torch.tensor(args.index_size, device='cuda', dtype=torch.int32) + + # A_select = A[:, index_vec] + # Instead of A_select being [batch_size, index_size] as before, + # we now create a tensor of shape [batch_size, total_neurons] + # where the first args.index_size columns are the gathered values. + A_select = torch.zeros((args.batch_size, total_neurons), device='cuda', dtype=A.dtype) + A_select[:, :args.index_size] = A[:, index_vec[:args.index_size]] + A_select_torch = A[:, active_indices] + + if args.bias: + print("Using bias") + bias = torch.randn(args.in_features, device='cuda', dtype=torch.float16) + else: + print("Not using bias") + bias = None + + print(f"args: {args}") + + # out_b, cublass_time = benchmark_fwd(A, B, bias = None, function = torch.matmul, print_result = args.print_results) + # out_a_gather, gather_gemm_time = benchmark_fwd(A_select, B, bias = bias, index_vec=index_vec, function=gather_matmul_row, print_result = args.print_results) + # out_a_gather_torch,torch_sel_gemm_time = benchmark_fwd(A_select, B, bias = bias, index_vec=index_vec, function=torch_sel_gemm_row, print_result = args.print_results) + + + out_b, cublass_time = cuda_profiler(torch.matmul, A, B) + print(f"Cublass time: {cublass_time:.2f} ms") + out_a_gather, gather_gemm_time = cuda_profiler(gather_matmul_row, A_select, B, bias = bias, index_vec=index_vec, index_size=index_size) + print(f"Gather kernel time: {gather_gemm_time:.2f} ms") + out_a_gather_torch,torch_sel_gemm_time = cuda_profiler(torch_sel_gemm_row, A_select_torch, B, bias = bias, index_vec=active_indices) + print(f"Torch gather time: {torch_sel_gemm_time:.2f} ms") + # C = torch.empty((args.batch_size, args.in_features), device=A.device, dtype=torch.float16) + # out_a_gather_alloc, gather_gemm_time_alloc = benchmark_fwd(A_select, B, index_vec=index_vec, function=gather_matmul_row, print_result = args.print_results, out = C) + + + speedup = cublass_time / gather_gemm_time + + # check results + if args.check_results: + print("Checking results") + # print("Cublass output: ", out_b) + # print("Kernel gather output: ", out_a_gather) + # print("Torch gather output: ", out_a_gather_torch) + + # assert torch.allclose(out_b, out_a_gather, atol=1e-3), "Results do not match" + if torch.allclose(out_a_gather, out_a_gather_torch, atol=0.5): #, "Gathered output does not match torch.matmul output" + print("Results match ✅") + else: + print("Results do not match ❌") + print(f"Speedup: {speedup:.2f}") \ No newline at end of file diff --git a/HybridTensor/triton/gather_gemm_col.py b/HybridTensor/triton/gather_gemm_col.py new file mode 100644 index 0000000000000000000000000000000000000000..18440085562e08b520d27af5b73849b9e2c1a8cf --- /dev/null +++ b/HybridTensor/triton/gather_gemm_col.py @@ -0,0 +1,170 @@ +import torch + +import triton +import triton.language as tl +from .triton_utils import get_autotune_config, benchmark_fwd, torch_sel_gemm_col +from .activations import leaky_relu, relu + +from HybridTensor.utils.utils import sparse_index, arg_parser + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'gather_N', 'K'], +) +@triton.jit +def matmul_gather_kernel_col( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + index_vec_ptr, # Pointer to the index vector for selected columns in B + bias_ptr, + # Matrix dimensions + M, N, K, gather_N, + # Strides + stride_am, stride_ak, # + stride_bk, stride_bn, # B is column-major + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr, # + USE_BIAS: tl.constexpr +): + """Kernel for computing the matmul C = A x B_select, + where B_select contains columns of B selected using index_vec. + A has shape (M, K) in row-major and B has shape (K, N) in column-major. + index_vec is used to select the columns of B for the matmul. + """ + # ----------------------------------------------------------- + # Map program ids to block of C + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(gather_N, BLOCK_SIZE_N) # Adjusted for the gathered columns + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ----------------------------------------------------------- + # Create pointers for the first blocks of A and B + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % gather_N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Pointer arithmetic for A (row-major, K contiguous) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + + # Gather selected column indices from B using the index_vec + gather_idx = tl.load(index_vec_ptr + offs_bn) # Gather indices from the index_vec + b_ptrs = b_ptr + (gather_idx[None, :] * stride_bk + offs_k[:, None] * stride_bn) + + if USE_BIAS: + bias = tl.load(bias_ptr + gather_idx, mask=offs_bn < gather_N, other=0.0) + + # ----------------------------------------------------------- + # Initialize the accumulator for C in FP32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load blocks of A and B with masking for out-of-bounds K + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # Perform dot product and accumulate + accumulator = tl.dot(a, b, accumulator) + # Advance pointers for next iteration over K + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bn # Move along K dimension + + if USE_BIAS: + accumulator += bias[None, :] + + # Optional activation + accumulator = relu(accumulator) + # if ACTIVATION == "leaky_relu": + # accumulator = leaky_relu(accumulator) + # elif ACTIVATION == "relu": + # accumulator = relu(accumulator) + + # Convert the accumulator back to FP16 + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the result into matrix C + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < gather_N) + tl.store(c_ptrs, c, mask=c_mask) + +def gather_matmul_col(a, b, index_vec, bias = None, activations="", out=None): + # Check constraints. + # b is expected to be in column major format. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + if bias is not None: + assert bias.size(0) == b.shape[0], "Incompatible bias dimensions" + + use_bias = bias is not None + + M, K = a.shape + N, K = b.shape + gather_N = index_vec.shape[0] # Number of columns to gather from B, total neuron activations + # Allocates output. + if out is None: + out = torch.empty((M, gather_N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(gather_N, META['BLOCK_SIZE_N']), ) + matmul_gather_kernel_col[grid]( + a, b, out, index_vec, bias, # + M, N, K, gather_N, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + out.stride(0), out.stride(1), # + ACTIVATION='', # + USE_BIAS = use_bias + ) + + return out + +if __name__ == '__main__': + args = arg_parser() + torch.manual_seed(0) + A = torch.randn(args.batch_size, args.in_features, device='cuda', dtype=torch.float16) + B = torch.randn(args.in_features, args.hidden_features, device='cuda', dtype=torch.float16) + B_col_major = B.t().contiguous() + index_vec = sparse_index(args.index_size, args.hidden_features)[0] + if args.bias: + print("Using bias") + bias = torch.randn(args.hidden_features, device='cuda', dtype=torch.float16) + else: + print("No bias") + bias = None + print(f"args: {args}") + print(f"Index size: {args.index_size}, Activated {args.index_size/(args.in_features * 4)*100}% neurons") + + out_b, cublass_time = benchmark_fwd(A, B, bias = None, function = torch.matmul, print_result = args.print_results) + out_a_gather, gather_gemm_time = benchmark_fwd(A, B_col_major, bias = bias, index_vec=index_vec, function=gather_matmul_col, print_result = args.print_results) + out_a_gather_torch,torch_sel_gemm_time = benchmark_fwd(A, B, bias = bias, index_vec=index_vec, function=torch_sel_gemm_col, print_result = args.print_results) + + speedup = cublass_time / gather_gemm_time + + # check results + if args.check_results: + print("Checking results") + print("Kernel gather output: ", out_a_gather) + print("Torch gather output: ", out_a_gather_torch) + # Check if the results are the same + if torch.allclose(out_a_gather, out_a_gather_torch, atol=0.5): #, "Gathered output does not match torch.matmul output" + print("Results match ✅") + else: + print("Results do not match ❌") + # max difference + print("Max difference: ", torch.max(torch.abs(out_a_gather - out_a_gather_torch))) + + print(f"Speedup: {speedup:.2f}") \ No newline at end of file diff --git a/HybridTensor/triton/gather_gemm_row.py b/HybridTensor/triton/gather_gemm_row.py new file mode 100644 index 0000000000000000000000000000000000000000..b5234d68af5c56b65915ed12d0588ae9945ddbef --- /dev/null +++ b/HybridTensor/triton/gather_gemm_row.py @@ -0,0 +1,167 @@ +import torch + +import triton +import triton.language as tl +from .triton_utils import get_autotune_config, benchmark_fwd, torch_sel_gemm_row +from .activations import leaky_relu, relu + +from HybridTensor.utils.utils import sparse_index, arg_parser + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'N', 'index_size'], +) + +@triton.jit +def matmul_gather_kernel_row( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, index_vec_ptr, bias_ptr, + # Matrix dimensions + M, N, K, index_size, + # The stride variables + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, + USE_BIAS: tl.constexpr +): + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Compute block indices + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + # Initialize accumulator + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + if USE_BIAS: + # Load the bias vector + bias = tl.load(bias_ptr + offs_bn, mask=offs_bn < N, other=0.0) + + # Loop over the K dimension in blocks + for k in range(0, tl.cdiv(index_size, BLOCK_SIZE_K)): + offs_k = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + # Load indices for the current block + index_vec = tl.load(index_vec_ptr + offs_k, mask=offs_k < index_size, other=0) + + # Compute pointers to A and B + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (index_vec[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # Load blocks of A and B + a = tl.load(a_ptrs, mask=(offs_am[:, None] < M) & (offs_k[None, :] < index_size), other=0.0) + b = tl.load(b_ptrs, mask=(index_vec[:, None] < K) & (offs_bn[None, :] < N), other=0.0) + + # Accumulate partial results + accumulator += tl.dot(a, b) + + if USE_BIAS: + # Add the bias + accumulator += bias[None, :] # Broadcast bias over rows + + # Cast accumulator to desired type and store the result + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. +@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + +def gather_matmul_row(a, b, index_vec, bias = None, activations="", out = None): + # Check constraints. + # a and b are expected to be in row major format. + # assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + + M, _ = a.shape + K, N = b.shape + index_size = index_vec.shape[0] + # Allocates output. + if out is None: + out = torch.empty((M, N), device=a.device, dtype=torch.float16) + if bias is not None: + assert bias.size(0) == N, "Incompatible bias dimensions" + + use_bias = bias is not None + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_gather_kernel_row[grid]( + a, b, out, index_vec, bias, # + M, N, K, index_size, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + out.stride(0), out.stride(1), # + ACTIVATION='', # + USE_BIAS = use_bias + ) + + return out + +if __name__ == '__main__': + args = arg_parser() + torch.manual_seed(0) + + # modeling the ffn2 kernel + A = torch.randn(args.batch_size, args.hidden_features, device='cuda', dtype=torch.float16) + B = torch.randn(args.hidden_features, args.in_features, device='cuda', dtype=torch.float16) + + index_vec = sparse_index(args.index_size, args.hidden_features)[0] + A_select = A[:, index_vec] + + if args.bias: + print("Using bias") + bias = torch.randn(args.in_features, device='cuda', dtype=torch.float16) + else: + print("Not using bias") + bias = None + + print(f"args: {args}") + + out_b, cublass_time = benchmark_fwd(A, B, bias = None, function = torch.matmul, print_result = args.print_results) + out_a_gather, gather_gemm_time = benchmark_fwd(A_select, B, bias = bias, index_vec=index_vec, function=gather_matmul_row, print_result = args.print_results) + out_a_gather_torch,torch_sel_gemm_time = benchmark_fwd(A_select, B, bias = bias, index_vec=index_vec, function=torch_sel_gemm_row, print_result = args.print_results) + + # C = torch.empty((args.batch_size, args.in_features), device=A.device, dtype=torch.float16) + # out_a_gather_alloc, gather_gemm_time_alloc = benchmark_fwd(A_select, B, index_vec=index_vec, function=gather_matmul_row, print_result = args.print_results, out = C) + + + speedup = cublass_time / gather_gemm_time + + # check results + if args.check_results: + print("Checking results") + # print("Cublass output: ", out_b) + # print("Kernel gather output: ", out_a_gather) + # print("Torch gather output: ", out_a_gather_torch) + + # assert torch.allclose(out_b, out_a_gather, atol=1e-3), "Results do not match" + if torch.allclose(out_a_gather, out_a_gather_torch, atol=0.5): #, "Gathered output does not match torch.matmul output" + print("Results match ✅") + else: + print("Results do not match ❌") + print(f"Speedup: {speedup:.2f}") \ No newline at end of file diff --git a/HybridTensor/triton/heuristics/__pycache__/autotune_configs.cpython-39.pyc b/HybridTensor/triton/heuristics/__pycache__/autotune_configs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cb1be8cc5cf839c8ee34a58fb0782ce2e64cccb Binary files /dev/null and b/HybridTensor/triton/heuristics/__pycache__/autotune_configs.cpython-39.pyc differ diff --git a/HybridTensor/triton/heuristics/__pycache__/gather_gemm_col_h.cpython-39.pyc b/HybridTensor/triton/heuristics/__pycache__/gather_gemm_col_h.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85b37d4e4a4a191b9ec550b5746c99414a229fe7 Binary files /dev/null and b/HybridTensor/triton/heuristics/__pycache__/gather_gemm_col_h.cpython-39.pyc differ diff --git a/HybridTensor/triton/heuristics/__pycache__/gather_gemm_row_h.cpython-39.pyc b/HybridTensor/triton/heuristics/__pycache__/gather_gemm_row_h.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fdcc711b28078e66c1f3f063a886dc420e99424 Binary files /dev/null and b/HybridTensor/triton/heuristics/__pycache__/gather_gemm_row_h.cpython-39.pyc differ diff --git a/HybridTensor/triton/heuristics/__pycache__/heuristics.cpython-39.pyc b/HybridTensor/triton/heuristics/__pycache__/heuristics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f85dcbef97851b5b13bb0c2967164352d1cc3833 Binary files /dev/null and b/HybridTensor/triton/heuristics/__pycache__/heuristics.cpython-39.pyc differ diff --git a/HybridTensor/triton/heuristics/autotune_configs.py b/HybridTensor/triton/heuristics/autotune_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..c217c9029afd99546e60c0e900c537058d6e1692 --- /dev/null +++ b/HybridTensor/triton/heuristics/autotune_configs.py @@ -0,0 +1,254 @@ +# python -m HybridTensor.triton.heuristics.autotune_configs --in_features 4096 --mode auto + +import os +import re +import json +import subprocess +from pathlib import Path +from tqdm import tqdm +from HybridTensor.utils.utils import arg_parser + + +os.environ["TRITON_PRINT_AUTOTUNING"] = "1" + +# Updated regex to capture all parameters: BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, +# num_warps, num_ctas, num_stages, maxnreg +config_line_pattern = re.compile( + r"best config selected:\s*BLOCK_SIZE_M:\s*(\d+),\s*BLOCK_SIZE_N:\s*(\d+),\s*BLOCK_SIZE_K:\s*(\d+),\s*GROUP_SIZE_M:\s*(\d+),\s*num_warps:\s*(\d+),\s*num_ctas:\s*(\d+),\s*num_stages:\s*(\d+),\s*maxnreg:\s*(\w+);" +) + +args_line_pattern = re.compile( + r"args: Namespace\(batch_size=(\d+), hidden_features=(\d+), in_features=(\d+), seq_len=(\d+), index_size=(\d+)," +) + +kernel_name_pattern = re.compile( + r"Triton autotuning for function (\w+) finished after" +) + +def run_offline_autotune_and_record_1(in_features, mode = "row"): + ''' + Records the best configurations for the gather_gemm_{row/col} kernel for different batch sizes and index sizes. + Old function, does not append to existing configs. + ''' + + batch_sizes = [1, 2, 4, 8, 16, 32, 48, 64] + # batch_sizes = [8] + + + # Index sizes from 10% to 100% of hidden_features in 10% increments + num_neurons = in_features * 4 # hidden_features + index_sizes = [int(num_neurons * pct / 100) for pct in range(5, 101, 5)] + # index_sizes = [1024, 4096, 8192] + + test_cases = [] + for bs in batch_sizes: + for idx_size in index_sizes: + test_cases.append((bs, in_features, idx_size)) + + best_configs = {} + kernel_name = None + + for (batch_size, in_features, index_size) in tqdm(test_cases, desc="Autotuning cases"): + if mode == "row": + cmd = [ + "python", "-m", "HybridTensor.triton.gather_gemm_row", + "--batch_size", str(batch_size), + "--in_features", str(in_features), + "--index_size", str(index_size) + ] + elif mode == "col": + cmd = [ + "python", "-m", "HybridTensor.triton.gather_gemm_col", + "--batch_size", str(batch_size), + "--in_features", str(in_features), + "--index_size", str(index_size) + ] + + result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # M = K = N = idx_size = None + M = batch_size + K = in_features + N = in_features * 4 + idx_size = index_size + chosen_config = None + lines = result.stdout.split("\n") + + # Extract M, K, N, idx_size from arguments line + # for line in lines: + # args_match = args_line_pattern.search(line) + # if args_match: + # parsed_batch_size = int(args_match.group(1)) + # parsed_hidden_features = int(args_match.group(2)) + # parsed_in_features = int(args_match.group(3)) + # parsed_seq_len = int(args_match.group(4)) + # parsed_index_size = int(args_match.group(5)) + + # # M = parsed_batch_size + # # K = parsed_in_features + # # N = in_features * 4 + # M = batch_size + # K = in_features + # N = in_features * 4 + # idx_size = index_size + # # N = parsed_hidden_features + # # idx_size = parsed_index_size + # break + + # Extract kernel name and config details + for line in lines: + if kernel_name is None: + kn_match = kernel_name_pattern.search(line) + if kn_match: + kernel_name = kn_match.group(1) + + match = config_line_pattern.search(line) + if match: + BLOCK_SIZE_M = int(match.group(1)) + BLOCK_SIZE_N = int(match.group(2)) + BLOCK_SIZE_K = int(match.group(3)) + GROUP_SIZE_M = int(match.group(4)) + num_warps = int(match.group(5)) + num_ctas = int(match.group(6)) + num_stages = int(match.group(7)) + maxnreg_str = match.group(8) + # Convert maxnreg to None if it's the word 'None' + maxnreg = None if maxnreg_str == "None" else maxnreg_str + + chosen_config = { + "BLOCK_SIZE_M": BLOCK_SIZE_M, + "BLOCK_SIZE_N": BLOCK_SIZE_N, + "BLOCK_SIZE_K": BLOCK_SIZE_K, + "GROUP_SIZE_M": GROUP_SIZE_M, + "num_warps": num_warps, + "num_ctas": num_ctas, + "num_stages": num_stages, + "maxnreg": maxnreg + } + + if chosen_config is None: + print(f"Warning: Could not find chosen config for batch_size={batch_size}, in_features={in_features}, index_size={index_size}") + else: + best_configs[(M, K, N, idx_size)] = chosen_config + print(f"Best config for batch_size={batch_size}, in_features={in_features}, index_size={index_size}: {chosen_config}") + + if kernel_name is None: + kernel_name = "unknown_kernel" + + # Convert tuple keys to string + saveable_configs = {str(k): v for k, v in best_configs.items()} + output_path = Path(f"configs/gemm/best_configs_{kernel_name}_{in_features}.json") + with open(output_path, "w") as f: + json.dump(saveable_configs, f, indent=4) + + print(f"Best configs saved to {output_path}") + +def run_offline_autotune_and_record(in_features, mode="row"): + """ + Records the best configurations for the gather_gemm_{row/col} kernel + for different batch sizes and index sizes. + """ + batch_sizes = [1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 256, 512] + # batch_sizes = [1] + + + # batch_sizes = [2] + num_neurons = in_features * 4 # hidden_features + index_sizes = [int(num_neurons * pct / 100) for pct in range(5, 101, 5)] + + test_cases = [(bs, in_features, idx_size) for bs in batch_sizes for idx_size in index_sizes] + + # Determine output file path. If file exists, load previous configs. + # kernel_name might be unknown initially; use a temporary kernel tag for filename. + temp_kernel_tag = "matmul_gather_kernel_" + mode + output_dir = Path("configs/gemm") + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / f"best_configs_{temp_kernel_tag}_{in_features}.json" + + best_configs = {} + if output_path.exists(): + with open(output_path, "r") as f: + best_configs = json.load(f) + print(f"Loaded existing configs from {output_path}") + + kernel_name = None + + for (batch_size, in_features, index_size) in tqdm(test_cases, desc="Autotuning cases"): + # Define test key and check if exists + M = batch_size + K = in_features + N = in_features * 4 + key_str = str((M, K, N, index_size)) + if key_str in best_configs: + print(f"Skipping batch_size={batch_size}, in_features={in_features}, index_size={index_size} (config exists)") + continue + + if mode == "row": + cmd = [ + "python", "-m", "HybridTensor.triton.gather_gemm_row", + "--batch_size", str(batch_size), + "--in_features", str(in_features), + "--index_size", str(index_size) + ] + elif mode == "col": + cmd = [ + "python", "-m", "HybridTensor.triton.gather_gemm_col", + "--batch_size", str(batch_size), + "--in_features", str(in_features), + "--index_size", str(index_size) + ] + + result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + chosen_config = None + lines = result.stdout.split("\n") + + for line in lines: + if kernel_name is None: + kn_match = kernel_name_pattern.search(line) + if kn_match: + kernel_name = kn_match.group(1) + match = config_line_pattern.search(line) + if match: + chosen_config = { + "BLOCK_SIZE_M": int(match.group(1)), + "BLOCK_SIZE_N": int(match.group(2)), + "BLOCK_SIZE_K": int(match.group(3)), + "GROUP_SIZE_M": int(match.group(4)), + "num_warps": int(match.group(5)), + "num_ctas": int(match.group(6)), + "num_stages": int(match.group(7)), + "maxnreg": None if match.group(8) == "None" else match.group(8) + } + + if chosen_config is None: + print(f"Warning: No config found for batch_size={batch_size}, in_features={in_features}, index_size={index_size}") + else: + best_configs[key_str] = chosen_config + print(f"Saved config for batch_size={batch_size}, in_features={in_features}, index_size={index_size}: {chosen_config}") + + if kernel_name is None: + kernel_name = "unknown_kernel" + + # Update output_path to reflect discovered kernel_name if necessary. + output_path = output_dir / f"best_configs_{kernel_name}_{in_features}.json" + with open(output_path, "w") as f: + json.dump(best_configs, f, indent=4) + + print(f"Best configs saved to {output_path}") + + +if __name__ == "__main__": + # add in_features as an argument + args = arg_parser() + in_features = args.in_features # Adjust as needed + mode = args.mode + + if mode != "auto": + print(f"Running autotuning for in_features={in_features}, mode = {mode}") + run_offline_autotune_and_record(in_features, mode=mode) + else: + print(f"Running autotuning for in_features={in_features}, mode = row") + run_offline_autotune_and_record(in_features, mode="row") + + print(f"Running autotuning for in_features={in_features}, mode = col") + run_offline_autotune_and_record(in_features, mode="col") \ No newline at end of file diff --git a/HybridTensor/triton/heuristics/gather_gemm_col_h.py b/HybridTensor/triton/heuristics/gather_gemm_col_h.py new file mode 100644 index 0000000000000000000000000000000000000000..363f3fe06c6ac7f44720d8677b7fa57061cbdd48 --- /dev/null +++ b/HybridTensor/triton/heuristics/gather_gemm_col_h.py @@ -0,0 +1,276 @@ +# python -m HybridTensor.triton.heuristics.gather_gemm_col_h --batch_size 8 --in_features 4096 --hidden_features 16384 --index_size 512 + + +import torch + +import triton +import triton.language as tl +from HybridTensor.triton.triton_utils import get_autotune_config, benchmark_fwd, torch_sel_gemm_col +from HybridTensor.triton.activations import leaky_relu, relu + +from HybridTensor.utils.utils import sparse_index, arg_parser +from HybridTensor.triton.heuristics.heuristics import HeuristicSelector + +heuristic_dir = "configs/gemm" +selector = HeuristicSelector(heuristic_dir, type="col") + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs + +@triton.jit +def matmul_gather_kernel_col( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + index_vec_ptr, # Pointer to the index vector for selected columns in B + bias_ptr, + # Matrix dimensions + M, N, K, gather_N, + # Strides + stride_am, stride_ak, # + stride_bk, stride_bn, # B is column-major + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr, # + USE_BIAS: tl.constexpr +): + """Kernel for computing the matmul C = A x B_select, + where B_select contains columns of B selected using index_vec. + A has shape (M, K) in row-major and B has shape (K, N) in column-major. + index_vec is used to select the columns of B for the matmul. + """ + # ----------------------------------------------------------- + # Map program ids to block of C + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(gather_N, BLOCK_SIZE_N) # Adjusted for the gathered columns + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ----------------------------------------------------------- + # Create pointers for the first blocks of A and B + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % gather_N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Pointer arithmetic for A (row-major, K contiguous) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + + # Gather selected column indices from B using the index_vec + gather_idx = tl.load(index_vec_ptr + offs_bn) # Gather indices from the index_vec + b_ptrs = b_ptr + (gather_idx[None, :] * stride_bk + offs_k[:, None] * stride_bn) + + if USE_BIAS: + bias = tl.load(bias_ptr + gather_idx, mask=offs_bn < gather_N, other=0.0) + + # ----------------------------------------------------------- + # Initialize the accumulator for C in FP32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load blocks of A and B with masking for out-of-bounds K + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # Perform dot product and accumulate + accumulator = tl.dot(a, b, accumulator) + # Advance pointers for next iteration over K + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bn # Move along K dimension + + if USE_BIAS: + accumulator += bias[None, :] + + # Optional activation + accumulator = relu(accumulator) + # if ACTIVATION == "leaky_relu": + # accumulator = leaky_relu(accumulator) + # elif ACTIVATION == "relu": + # accumulator = relu(accumulator) + + # Convert the accumulator back to FP16 + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the result into matrix C + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < gather_N) + tl.store(c_ptrs, c, mask=c_mask) + +def gather_matmul_col(a, b, index_vec, bias = None, activations="", out=None): + # Check constraints. + # b is expected to be in column major format. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + if bias is not None: + assert bias.size(0) == b.shape[0], "Incompatible bias dimensions" + + use_bias = bias is not None + + M, K = a.shape + N, K = b.shape + gather_N = index_vec.shape[0] # Number of columns to gather from B, total neuron activations + # Allocates output. + if out is None: + out = torch.empty((M, gather_N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. + cfg = selector.get_config(M, K, N, gather_N) + + # grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(gather_N, META['BLOCK_SIZE_N']), ) + grid = lambda META: (triton.cdiv(M, cfg['BLOCK_SIZE_M']) * triton.cdiv(gather_N, cfg['BLOCK_SIZE_N']), ) + + matmul_gather_kernel_col[grid]( + a, b, out, index_vec, bias, # + M, N, K, gather_N, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + out.stride(0), out.stride(1), # + ACTIVATION='', # + USE_BIAS = use_bias, + BLOCK_SIZE_M=cfg['BLOCK_SIZE_M'], + BLOCK_SIZE_N=cfg['BLOCK_SIZE_N'], + BLOCK_SIZE_K=cfg['BLOCK_SIZE_K'], + GROUP_SIZE_M=cfg['GROUP_SIZE_M'], + num_warps=cfg['num_warps'], + num_ctas=cfg['num_ctas'], + num_stages=cfg['num_stages'] + ) + + return out + +if __name__ == '__main__': + args = arg_parser() + A = torch.randn(args.batch_size, args.in_features, device='cuda', dtype=torch.float16) + B = torch.randn(args.in_features, args.hidden_features, device='cuda', dtype=torch.float16) + B_col_major = B.t().contiguous() + + + # index_vec = sparse_index(args.index_size, args.hidden_features)[0] + + router_output = torch.rand([args.hidden_features], device='cuda', dtype=torch.float16) + _, index_vec = router_output.topk(args.index_size, dim=0, sorted=False) + + if args.bias: + print("Using bias") + bias = torch.randn(args.hidden_features, device='cuda', dtype=torch.float16) + else: + print("No bias") + bias = None + print(f"args: {args}") + print(f"Index size: {args.index_size}, Activated {args.index_size/(args.in_features * 4)*100}% neurons") + + # Without CUDA Graph + out_b, cublass_time = benchmark_fwd(A, B, bias = None, function = torch.matmul, print_result = args.print_results) + out_a_gather, gather_gemm_time = benchmark_fwd(A, B_col_major, bias = bias, index_vec=index_vec, function=gather_matmul_col, print_result = args.print_results) + out_a_gather_torch,torch_sel_gemm_time = benchmark_fwd(A, B, bias = bias, index_vec=index_vec, function=torch_sel_gemm_col, print_result = args.print_results) + + speedup = cublass_time / gather_gemm_time + + # check results + if args.check_results: + print("Checking results") + print("Kernel gather output: ", out_a_gather) + print("Torch gather output: ", out_a_gather_torch) + # Check if the results are the same + if torch.allclose(out_a_gather, out_a_gather_torch, atol=0.5): #, "Gathered output does not match torch.matmul output" + print("Results match ✅") + else: + print("Results do not match ❌") + # max difference + print("Max difference: ", torch.max(torch.abs(out_a_gather - out_a_gather_torch))) + + print(f"Speedup: {speedup:.2f}") + + # ---------------------------- + # CUDA Graph capture testing for gather_matmul_col + # ---------------------------- + print("Testing CUDA Graph capture...") + + # Warm-up run to compile the kernel + out_cuda = gather_matmul_col(A, B_col_major, index_vec, bias=bias) + torch.cuda.synchronize() + + # Allocate static buffers for CUDA Graph capture + static_out = torch.empty_like(out_cuda) + static_A = A.clone() + static_B = B_col_major.clone() + static_index_vec = index_vec.clone() + if bias is not None: + static_bias = bias.clone() + else: + static_bias = None + + # Warm up static buffers + _ = gather_matmul_col(static_A, static_B, static_index_vec, bias=static_bias, out=static_out) + torch.cuda.synchronize() + + # Capture on a non-default stream + capture_stream = torch.cuda.Stream() + with torch.cuda.stream(capture_stream): + g = torch.cuda.CUDAGraph() + g.capture_begin() + gather_matmul_col(static_A, static_B, static_index_vec, bias=static_bias, out=static_out) + g.capture_end() + torch.cuda.synchronize() + + # Replay the graph and compare with a regular run + g.replay() + torch.cuda.synchronize() + cuda_graph_out = static_out.clone() + + regular_out = gather_matmul_col(A, B_col_major, index_vec, bias=bias) + if torch.allclose(cuda_graph_out, regular_out, atol=1e-3): + print("CUDA Graph output matches regular output") + else: + print("CUDA Graph output does not match regular output") + + # Test with updated inputs + print("Testing CUDA Graph with updated inputs...") + new_A = torch.randn_like(A) + new_B = torch.randn_like(B) + new_B_col_major = new_B.t().contiguous() + if bias is not None: + new_bias = torch.randn_like(bias) + else: + new_bias = None + router_output_new = torch.rand([args.hidden_features], device='cuda', dtype=torch.float16) + _, new_index_vec = router_output_new.topk(args.index_size, dim=0, sorted=False) + + static_A.copy_(new_A) + static_B.copy_(new_B_col_major) + static_index_vec.copy_(new_index_vec) + if static_bias is not None: + static_bias.copy_(new_bias) + + g.replay() + torch.cuda.synchronize() + cuda_graph_out_new = static_out.clone() + + regular_out_new = gather_matmul_col(new_A, new_B_col_major, new_index_vec, bias=new_bias) + if torch.allclose(cuda_graph_out_new, regular_out_new, atol=1e-3): + print("Updated CUDA Graph output matches regular output") + else: + print("Updated CUDA Graph output does not match regular output") + + # Execute the graph + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for _ in range(args.iterations): + g.replay() + torch.cuda.synchronize() + end.record() + torch.cuda.synchronize() + + elapsed_time = start.elapsed_time(end) / args.iterations + print(f"Average time per genearation (CUDA GRAPH): {elapsed_time:.4f} ms") \ No newline at end of file diff --git a/HybridTensor/triton/heuristics/gather_gemm_row_h.py b/HybridTensor/triton/heuristics/gather_gemm_row_h.py new file mode 100644 index 0000000000000000000000000000000000000000..95dbe009e2e836bb3e4f30dcc26c371612aff781 --- /dev/null +++ b/HybridTensor/triton/heuristics/gather_gemm_row_h.py @@ -0,0 +1,267 @@ +# python -m HybridTensor.triton.heuristics.gather_gemm_row_h --batch_size 8 --in_features 4096 --hidden_features 16384 --index_size 512 + +import torch + +import triton +import triton.language as tl +from HybridTensor.triton.triton_utils import get_autotune_config, benchmark_fwd, torch_sel_gemm_row + +from HybridTensor.utils.utils import sparse_index, arg_parser +from HybridTensor.triton.heuristics.heuristics import HeuristicSelector + +heuristic_dir = "configs/gemm" +selector = HeuristicSelector(heuristic_dir, type="row") + +@triton.jit +def matmul_gather_kernel_row( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, index_vec_ptr, bias_ptr, + # Matrix dimensions + M, N, K, index_size, + # The stride variables + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, # + USE_BIAS: tl.constexpr +): + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Compute block indices + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + # Initialize accumulator + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + if USE_BIAS: + # Load the bias vector + bias = tl.load(bias_ptr + offs_bn, mask=offs_bn < N, other=0.0) + + # Loop over the K dimension in blocks + for k in range(0, tl.cdiv(index_size, BLOCK_SIZE_K)): + offs_k = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + # Load indices for the current block + index_vec = tl.load(index_vec_ptr + offs_k, mask=offs_k < index_size, other=0) + + # Compute pointers to A and B + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (index_vec[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # Load blocks of A and B + a = tl.load(a_ptrs, mask=(offs_am[:, None] < M) & (offs_k[None, :] < index_size), other=0.0) + b = tl.load(b_ptrs, mask=(index_vec[:, None] < K) & (offs_bn[None, :] < N), other=0.0) + + # Accumulate partial results + accumulator += tl.dot(a, b) + + if USE_BIAS: + # Add the bias + accumulator += bias[None, :] # Broadcast bias over rows + + # Cast accumulator to desired type and store the result + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. +@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + +def gather_matmul_row(a, b, index_vec, bias = None, activations="", out = None): + # Check constraints. + # a and b are expected to be in row major format. + # assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + + M, _ = a.shape + K, N = b.shape + index_size = index_vec.shape[0] + # Allocates output. + if out is None: + out = torch.empty((M, N), device=a.device, dtype=torch.float16) + if bias is not None: + assert bias.size(0) == N, "Incompatible bias dimensions" + + use_bias = bias is not None + + # 1D launch kernel where each block gets its own program. + cfg = selector.get_config(M, N, K, index_size) + + # grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + grid = lambda META: (triton.cdiv(M, cfg['BLOCK_SIZE_M']) * triton.cdiv(N, cfg['BLOCK_SIZE_N']), ) + matmul_gather_kernel_row[grid]( + a, b, out, index_vec, bias, # + M, N, K, index_size, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + out.stride(0), out.stride(1), # + ACTIVATION='', # + USE_BIAS = use_bias, + BLOCK_SIZE_M=cfg['BLOCK_SIZE_M'], + BLOCK_SIZE_N=cfg['BLOCK_SIZE_N'], + BLOCK_SIZE_K=cfg['BLOCK_SIZE_K'], + GROUP_SIZE_M=cfg['GROUP_SIZE_M'], + num_warps=cfg['num_warps'], + num_ctas=cfg['num_ctas'], + num_stages=cfg['num_stages'] + ) + + return out + +if __name__ == '__main__': + args = arg_parser() + # modeling the ffn2 kernel + A = torch.randn(args.batch_size, args.hidden_features, device='cuda', dtype=torch.float16) + B = torch.randn(args.hidden_features, args.in_features, device='cuda', dtype=torch.float16) + + # index_vec = sparse_index(args.index_size, args.hidden_features)[0] + router_output = torch.rand([args.hidden_features], device='cuda', dtype=torch.float16) + _, index_vec = router_output.topk(args.index_size, dim=0, sorted=False) + + A_select = A[:, index_vec] + + if args.bias: + print("Using bias") + bias = torch.randn(args.in_features, device='cuda', dtype=torch.float16) + else: + print("Not using bias") + bias = None + + + out_b, cublass_time = benchmark_fwd(A, B, bias = None, function = torch.matmul, print_result = args.print_results) + out_a_gather, gather_gemm_time = benchmark_fwd(A_select, B, bias = bias, index_vec=index_vec, function=gather_matmul_row, print_result = args.print_results) + out_a_gather_torch,torch_sel_gemm_time = benchmark_fwd(A_select, B, bias = bias, index_vec=index_vec, function=torch_sel_gemm_row, print_result = args.print_results) + + # C = torch.empty((args.batch_size, args.in_features), device=A.device, dtype=torch.float16) + # out_a_gather_alloc, gather_gemm_time_alloc = benchmark_fwd(A_select, B, index_vec=index_vec, function=gather_matmul_row, print_result = args.print_results, out = C) + + + speedup = cublass_time / gather_gemm_time + + # check results + if args.check_results: + print("Checking results") + print("Cublass output: ", out_b) + print("Kernel gather output: ", out_a_gather) + print("Torch gather output: ", out_a_gather_torch) + + # assert torch.allclose(out_b, out_a_gather, atol=1e-3), "Results do not match" + if torch.allclose(out_a_gather, out_a_gather_torch, atol=0.5): #, "Gathered output does not match torch.matmul output" + print("Results match ✅") + else: + print("Results do not match ❌") + print(f"Speedup: {speedup:.2f}") + + + # ---------------------------- + # CUDA Graph capture testing + # ---------------------------- + print("Testing CUDA Graph capture...") + + # Warm-up run to compile the kernel + out_cuda = gather_matmul_row(A_select, B, bias=bias, index_vec=index_vec) + torch.cuda.synchronize() + + # Allocate static buffers for CUDA Graph capture + static_out = torch.empty_like(out_cuda) + static_A = A_select.clone() + static_B = B.clone() + static_index_vec = index_vec.clone() + if bias is not None: + static_bias = bias.clone() + else: + static_bias = None + + # Warm up static buffers + _ = gather_matmul_row(static_A, static_B, bias=static_bias, index_vec=static_index_vec, out=static_out) + torch.cuda.synchronize() + + # Create a non-default stream for CUDA Graph capture. + capture_stream = torch.cuda.Stream() + with torch.cuda.stream(capture_stream): + g = torch.cuda.CUDAGraph() + g.capture_begin() + if static_bias is not None: + gather_matmul_row(static_A, static_B, bias=static_bias, index_vec=static_index_vec, out=static_out) + else: + gather_matmul_row(static_A, static_B, bias=None, index_vec=static_index_vec, out=static_out) + g.capture_end() + # Ensure capture stream is synchronized before replaying on the default stream. + torch.cuda.synchronize() + + # Replay the graph and compare with a regular run + g.replay() + torch.cuda.synchronize() + cuda_graph_out = static_out.clone() + + if args.check_results: + regular_out = gather_matmul_row(A_select, B, bias=bias, index_vec=index_vec) + if torch.allclose(cuda_graph_out, regular_out, atol=1e-3): + print("CUDA Graph output matches regular output") + else: + print("CUDA Graph output does not match regular output") + + + if args.check_results: + # check results with new input + # Now update the static inputs with new data and test again + print("Testing CUDA Graph with updated inputs...") + new_A = torch.randn_like(A_select) + new_B = torch.randn_like(B) + if bias is not None: + new_bias = torch.randn_like(bias) + else: + new_bias = None + + router_output_new = torch.rand([args.hidden_features], device='cuda', dtype=torch.float16) + _, new_index_vec = router_output_new.topk(args.index_size, dim=0, sorted=False) + + static_A.copy_(new_A) + static_B.copy_(new_B) + if static_bias is not None: + static_bias.copy_(new_bias) + static_index_vec.copy_(new_index_vec) + + g.replay() + torch.cuda.synchronize() + cuda_graph_out_new = static_out.clone() + + regular_out_new = gather_matmul_row(new_A, new_B, bias=new_bias, index_vec=new_index_vec) + if torch.allclose(cuda_graph_out_new, regular_out_new, atol=1e-3): + print("Updated CUDA Graph output matches regular output") + else: + print("Updated CUDA Graph output does not match regular output") + + # Execute the graph + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for _ in range(args.iterations): + g.replay() + torch.cuda.synchronize() + end.record() + torch.cuda.synchronize() + + elapsed_time = start.elapsed_time(end) / args.iterations + print(f"Average time per genearation (CUDA GRAPH): {elapsed_time:.4f} ms") diff --git a/HybridTensor/triton/heuristics/heuristics.py b/HybridTensor/triton/heuristics/heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..73b519ea6001f761174f349fee2efead4fc9018f --- /dev/null +++ b/HybridTensor/triton/heuristics/heuristics.py @@ -0,0 +1,109 @@ +import json +import bisect +import os +import glob +import ast + +class HeuristicSelector: + def __init__(self, config_dir, type = 'row'): + """ + Initialize the HeuristicSelector by loading all configuration files in the specified directory. + + Args: + config_dir (str): Path to the directory containing configuration JSON files. + """ + self.configs = {} # in_features -> (M, K, N) -> sorted list of (index_size, config) + self.type = type + self._load_configs(config_dir) + + + def _load_configs(self, config_dir): + """ + Load all JSON configuration files from the specified directory. + + Args: + config_dir (str): Path to the directory containing configuration JSON files. + """ + # Define the pattern to match configuration files + # pattern = os.path.join(config_dir, 'best_configs_matmul_gather_kernel_col_*.json') + pattern = os.path.join(config_dir, f'best_configs_matmul_gather_kernel_{self.type}_*.json') + config_files = glob.glob(pattern) + + if not config_files: + # raise ValueError(f"No configuration files found in directory: {config_dir}") + print(f"No configuration files found in directory: {config_dir}. Precompile the kernels for faster inference.") + return + + for file_path in config_files: + # Extract in_features from the file name + basename = os.path.basename(file_path) + try: + # Assuming the file name pattern includes in_features as an integer + in_features_str = basename.split('_')[-1].replace('.json', '') + in_features = int(in_features_str) + except (IndexError, ValueError) as e: + raise ValueError(f"Filename {basename} does not match the expected pattern.") from e + + with open(file_path, 'r') as f: + saved_configs = json.load(f) + + # Convert string keys back to tuples and organize by (M, K, N) + temp_dict = {} + for key_str, cfg in saved_configs.items(): + try: + # Safely parse the tuple string + M, K, N, idx_size = ast.literal_eval(key_str) + except (ValueError, SyntaxError) as e: + raise ValueError(f"Invalid key format: {key_str}") from e + + temp_dict.setdefault((M, K, N), []).append((idx_size, cfg)) + + # Sort the index_size for each (M, K, N) + for mkn, entries in temp_dict.items(): + entries.sort(key=lambda x: x[0]) # Sort by index_size + + # Initialize the in_features dictionary if not present + if in_features not in self.configs: + self.configs[in_features] = {} + + self.configs[in_features][mkn] = entries + + def get_config(self, M, K, N, index_size): + """ + Retrieve the best configuration based on the provided parameters. + + Args: + in_features (int): The input features parameter. + M (int): Parameter M. Batch size + K (int): Parameter K. in_features + N (int): Parameter N. total neurons or hidden_features + index_size (int): The desired index size. + + Returns: + dict: The best configuration dictionary. + + Raises: + ValueError: If no configurations are found for the given parameters. + """ + in_features = K + N = in_features * 4 + + # Check if the specified in_features is available + if in_features not in self.configs: + raise ValueError(f"No configurations found for in_features={in_features}") + + mkn = (M, K, N) + if mkn not in self.configs[in_features]: + raise ValueError(f"No configurations found for (M={M}, K={K}, N={N}) with in_features={in_features}") + + entries = self.configs[in_features][mkn] + # entries is sorted by index_size. Find the smallest index_size >= requested index_size. + index_sizes = [e[0] for e in entries] + pos = bisect.bisect_left(index_sizes, index_size) + + if pos < len(entries): + # Found an index_size >= requested index_size + return entries[pos][1] + else: + # No larger index_size found. Fallback to the largest available index_size. + return entries[-1][1] \ No newline at end of file diff --git a/HybridTensor/triton/profile_gather_gemm.py b/HybridTensor/triton/profile_gather_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..afb18a7f2eac6063e7874cf9ca4c06a90e8f3cff --- /dev/null +++ b/HybridTensor/triton/profile_gather_gemm.py @@ -0,0 +1,64 @@ +import torch +import os +import csv + +# from HybridTensor.triton.heuristics.gather_gemm_col_h import gather_matmul_col +from HybridTensor.triton.gather_gemm_col import gather_matmul_col +from HybridTensor.triton.references.gemm import triton_matmul +from HybridTensor.triton.triton_utils import get_autotune_config, benchmark_fwd, torch_sel_gemm_col +from HybridTensor.utils.utils import arg_parser, sparse_index, get_gpu_name +from HybridTensor.utils.profiling import cuda_profiler + +if __name__ == "__main__": + args = arg_parser() + torch.manual_seed(0) + args.hidden_features = args.in_features * 4 + A = torch.randn(args.batch_size, args.in_features, device='cuda', dtype=torch.float16) + B = torch.randn(args.in_features, args.hidden_features, device='cuda', dtype=torch.float16) + B_col_major = B.t().contiguous() + index_vec = sparse_index(args.index_size, args.hidden_features)[0] + if args.bias: + print("Using bias") + bias = torch.randn(args.hidden_features, device='cuda', dtype=torch.float16) + else: + print("No bias") + bias = None + print(f"args: {args}") + print(f"Index size: {args.index_size}, Activated {args.index_size/(args.in_features * 4)*100}% neurons") + + index_sizes_ratio = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + index_sizes = [int(args.hidden_features * ratio) for ratio in index_sizes_ratio] + # index_sizes = [((size + 127) // 128) * 128 for size in index_sizes] # Round up to the nearest multiple of 128 + results = [] + print(f"Index sizes: {index_sizes}") + + linear_layer = torch.nn.Linear(args.in_features, args.hidden_features, bias = args.bias, device = 'cuda', dtype = torch.float16) + out_b, dense_triton_time = cuda_profiler(triton_matmul, A, B, bias=None, warmup_runs=1, timed_runs=10) + out_b, cublass_time = cuda_profiler(linear_layer, A, warmup_runs=1, timed_runs=10) + # out_b, cublass_time = benchmark_fwd(A, B, bias = None, function = torch.linear, print_result = args.print_results) + for idx, index_size in enumerate(index_sizes): + + index_vec = sparse_index(index_size, args.hidden_features)[0] + + out_a_gather, gather_gemm_time = benchmark_fwd(A, B_col_major, bias = bias, index_vec=index_vec, function=gather_matmul_col, print_result = args.print_results) + out_a_gather_torch,torch_sel_gemm_time = benchmark_fwd(A, B, bias = bias, index_vec=index_vec, function=torch_sel_gemm_col, print_result = args.print_results) + + speedup = cublass_time / gather_gemm_time + print(f"Speedup: {speedup:.2f}") + results.append({ + "Neuron Activation": index_sizes_ratio[idx], + "Dense (Cublass) Time (ms)": round(cublass_time, 4), + "Dense (Triton) Time (ms)": round(dense_triton_time, 4), + "Gather GEMM Time (ms)": round(gather_gemm_time, 4), + "Naive(Torch Gather) Time (ms)": round(torch_sel_gemm_time, 4), + "Speedup": round(speedup, 4) + }) + args.results_dir = f"results/triton_kernels/{get_gpu_name()}/" + os.makedirs(args.results_dir, exist_ok=True) + csv_file_path = os.path.join(args.results_dir, f"gather_gemm_profiling_{args.in_features}_{args.batch_size}.csv") + with open(csv_file_path, mode='w', newline='') as file: + writer = csv.DictWriter(file, fieldnames=results[0].keys()) + writer.writeheader() + writer.writerows(results) + + print(f"Results saved to {csv_file_path}") \ No newline at end of file diff --git a/HybridTensor/triton/profile_sel_attn.py b/HybridTensor/triton/profile_sel_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a6aff58c7c6bc11b9d2ddba6ca5c28d5562768 --- /dev/null +++ b/HybridTensor/triton/profile_sel_attn.py @@ -0,0 +1,176 @@ +import torch +import numpy as np +import os +import csv + +from HybridTensor.triton.select_attn_v1 import select_attn +from HybridTensor.triton.references.attn_decode import attention as triton_attn +from flash_attn import flash_attn_with_kvcache +from HybridTensor.utils.utils import arg_parser, generate_random_BH_index, get_gpu_name +from HybridTensor.utils.profiling import cuda_profiler + + + +def perform_attention_with_selection(q, k, v, k_cache, v_cache, batch_head_index, causal, cache_seqlens): + # Expand head_idx so it matches the dimensions of the tensor + # For q: shape (B, 1, num_heads, head_dim) + head_idx = batch_head_index.to(torch.int64) + q_selected = torch.gather( + q, + dim=2, + index=head_idx.unsqueeze(1).unsqueeze(-1).expand(q.size(0), q.size(1), head_idx.size(1), q.size(-1)) + ) + + # For k_cache and v_cache: shape (B, seq_len, num_heads, head_dim) + k_cache_selected = torch.gather( + k_cache, + dim=2, + index=head_idx.unsqueeze(1).unsqueeze(-1).expand(k_cache.size(0), k_cache.size(1), head_idx.size(1), k_cache.size(-1)) + ) + v_cache_selected = torch.gather( + v_cache, + dim=2, + index=head_idx.unsqueeze(1).unsqueeze(-1).expand(v_cache.size(0), v_cache.size(1), head_idx.size(1), v_cache.size(-1)) + ) + + # Similarly for k and v: shape (B, 1, num_heads, head_dim) + k_selected = torch.gather( + k, + dim=2, + index=head_idx.unsqueeze(1).unsqueeze(-1).expand(k.size(0), k.size(1), head_idx.size(1), k.size(-1)) + ) + v_selected = torch.gather( + v, + dim=2, + index=head_idx.unsqueeze(1).unsqueeze(-1).expand(v.size(0), v.size(1), head_idx.size(1), v.size(-1)) + ) + + # Perform attention + out = flash_attn_with_kvcache( + q=q_selected, + k=k_selected, + v=v_selected, + k_cache=k_cache_selected, + v_cache=v_cache_selected, + causal=causal, + cache_seqlens=cache_seqlens + ) + + return out + +def triton_attention_with_selection(q, k_cache, v_cache, batch_head_index, scale_float): + # Expand head_idx so it matches the dimensions of the tensor + # For q: shape (B, 1, num_heads, head_dim) + head_idx = batch_head_index.to(torch.int64) + q_selected = torch.gather( + q, + dim=2, + index=head_idx.unsqueeze(1).unsqueeze(-1).expand(q.size(0), q.size(1), head_idx.size(1), q.size(-1)) + ) + + # For k_cache and v_cache: shape (B, seq_len, num_heads, head_dim) + k_cache_selected = torch.gather( + k_cache, + dim=2, + index=head_idx.unsqueeze(1).unsqueeze(-1).expand(k_cache.size(0), k_cache.size(1), head_idx.size(1), k_cache.size(-1)) + ) + v_cache_selected = torch.gather( + v_cache, + dim=2, + index=head_idx.unsqueeze(1).unsqueeze(-1).expand(v_cache.size(0), v_cache.size(1), head_idx.size(1), v_cache.size(-1)) + ) + + out = triton_attn(q_selected.unsqueeze(2), k_cache_selected.unsqueeze(2), v_cache_selected.unsqueeze(2), scale_float) + # triton_attn_out, dense_triton_time = cuda_profiler(triton_attn, q_s, k_s, v_s, scale_float, warmup_runs=1, timed_runs=10) + + return out + + + +if __name__ == "__main__": + args = arg_parser() + gpu_name = get_gpu_name() + print(f"Starting profiling for select attention") + args.embed_dim = args.in_features + args.results_dir = f"results/triton_kernels/{gpu_name}/" + batch_size = args.batch_size + seq_len = args.seq_len + + attn_topk_list = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] + + device = torch.device(args.device) + num_heads = args.embed_dim // 128 + head_dim = args.embed_dim // num_heads + q = torch.randn(args.batch_size, 1, num_heads, head_dim, dtype = torch.float16).to(device) + k_cache = torch.randn(args.batch_size, args.seq_len, num_heads, head_dim, dtype = torch.float16).to(device) + v_cache = torch.randn(args.batch_size, args.seq_len, num_heads, head_dim, dtype = torch.float16).to(device) + k = torch.randn(args.batch_size, 1, num_heads, head_dim, dtype = torch.float16).to(device) + v = torch.randn(args.batch_size, 1, num_heads, head_dim, dtype = torch.float16).to(device) + print(f"KV cache built") + + # dense flash attention + cache_seqlens = args.seq_len - 1 + causal = False + + attn_topk = 0.5 + print(f"Number of heads selected = {int(num_heads * attn_topk)}") + # # naive select attention + batch_head_index = generate_random_BH_index(args.batch_size, num_heads, int(num_heads * attn_topk), device = device) + + # triton select attention + q_s = q.unsqueeze(2) #torch.randn(B, M, G, H, Kq, dtype=dtype, device=device) + k_s = k_cache.unsqueeze(2) #torch.randn(B, Mk, G, H, Kkv, dtype=dtype, device=device) + v_s = v_cache.unsqueeze(2) #torch.randn(B, Mk, G, H, Kkv, dtype=dtype, device=device) + # Scale for attention + scale_float = 1.0 / np.sqrt(128) + + # Dense triton flash attention + triton_attn_out, dense_triton_time = cuda_profiler(triton_attn, q_s, k_s, v_s, scale_float, warmup_runs=1, timed_runs=10) + print(f'Dense triton attention time: {dense_triton_time} ms') + + + results = [] + + # triton select attention + # naive_out, naive_time = cuda_profiler(triton_attention_with_selection, q, k_cache, v_cache, batch_head_index, scale_float) + # print(f"Naive select attention time: {naive_time:.3f} ms") + + # sel_attn_out, sel_attn_time = cuda_profiler(select_attn, q_s, k_s, v_s, scale_float, batch_head_index, args.seq_len, warmup_runs=1, timed_runs=10) + # print(f'Average time taken for triton selected attention: {sel_attn_time} ms') + + + # Iterate over attn_topk_list + for attn_topk in attn_topk_list: + print(f"Running for attn_topk = {attn_topk}") + num_selected_heads = int(num_heads * attn_topk) + print(f"Number of heads selected = {num_selected_heads}") + + # Generate batch head index for selection + batch_head_index = generate_random_BH_index(args.batch_size, num_heads, num_selected_heads, device=device) + + # Naive select attention + naive_out, naive_time = cuda_profiler(triton_attention_with_selection, q, k_cache, v_cache, batch_head_index, scale_float) + print(f"Naive select attention time: {naive_time:.3f} ms") + + # Triton selected attention + sel_attn_out, sel_attn_time = cuda_profiler(select_attn, q_s, k_s, v_s, scale_float, batch_head_index, args.seq_len, warmup_runs=1, timed_runs=10) + print(f'Average time taken for triton selected attention: {sel_attn_time} ms') + + # Append results + results.append({ + "attn_topk": attn_topk, + "dense_time": dense_triton_time, + "naive_time": naive_time, + "sel_attn_time": sel_attn_time + }) + + # Write results to CSV + os.makedirs(args.results_dir, exist_ok=True) + file_name = f"select_attention_profiling_{args.in_features}_{args.batch_size}_{args.seq_len}.csv" + csv_file = os.path.join(args.results_dir, file_name) + with open(csv_file, mode="w", newline="") as file: + writer = csv.DictWriter(file, fieldnames=["attn_topk", "dense_time", "naive_time", "sel_attn_time"]) + writer.writeheader() + writer.writerows(results) + + print(f"Results saved to {csv_file}") \ No newline at end of file diff --git a/HybridTensor/triton/references/__pycache__/attention_proj_sparse.cpython-310.pyc b/HybridTensor/triton/references/__pycache__/attention_proj_sparse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2b312b7e2f02fa4925a3c604d0dcc4425164dd1 Binary files /dev/null and b/HybridTensor/triton/references/__pycache__/attention_proj_sparse.cpython-310.pyc differ diff --git a/HybridTensor/triton/references/__pycache__/attention_proj_sparse.cpython-39.pyc b/HybridTensor/triton/references/__pycache__/attention_proj_sparse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8963b4efe20a01d013cd4b97b1497f961a118af0 Binary files /dev/null and b/HybridTensor/triton/references/__pycache__/attention_proj_sparse.cpython-39.pyc differ diff --git a/HybridTensor/triton/references/__pycache__/attn_decode.cpython-39.pyc b/HybridTensor/triton/references/__pycache__/attn_decode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5f62e662d1fb1696e08b0c0887688b8111b8575 Binary files /dev/null and b/HybridTensor/triton/references/__pycache__/attn_decode.cpython-39.pyc differ diff --git a/HybridTensor/triton/references/__pycache__/gemm.cpython-39.pyc b/HybridTensor/triton/references/__pycache__/gemm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da3082a1d8bf4a2e89a855b39a1dcf2b82ce40c5 Binary files /dev/null and b/HybridTensor/triton/references/__pycache__/gemm.cpython-39.pyc differ diff --git a/HybridTensor/triton/references/attention_proj_sparse.py b/HybridTensor/triton/references/attention_proj_sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..ea75f4dfc1a960fbea9dfd240d90bb5f871d901a --- /dev/null +++ b/HybridTensor/triton/references/attention_proj_sparse.py @@ -0,0 +1,673 @@ +# Deja Vu code + +from typing import Optional + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=8), + triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=16), + triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=8), + triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=16), + triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=8), + triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=16), + triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=8), + triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=16), + triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=8), + triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=16), + triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=8), + triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=16), + ], + key=["CACHE_KEY_N"], +) +@triton.jit +def qkv_proj_sparse_kernel( + Y, # Pointers to matrices + A, + X, + HEAD_IDX, + BATCH_IDX, + BIAS, + # Matrix dimensions + N, + batch_size, + CACHE_KEY_N, + stride_a_three, + stride_a_nheads, + stride_a_headdim, + stride_x_batch, + stride_bias_three, + stride_y_batch, + stride_y_three, + # Meta-parameters + BLOCK_B: tl.constexpr, + HEADDIM: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + """ + We will not check that the indices are valid, for performance reason. + - Input X has shape (batch_size, N) + - Weight has shape (3, NHEADS, HEADDIM, N) + - HEAD_IDX has shape (NNZ) + - BATCH_IDX has shape (NNZ, batch_size) + - BIAS has shape (3, NHEADS, HEADDIM) + - Output has shape (batch_size, 3, NHEADS, HEADDIM) + """ + head_id = tl.load(HEAD_IDX + tl.program_id(0)) + qkv_id = tl.program_id(2) + rh = tl.program_id(1) * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + rn = tl.arange(0, BLOCK_N) + rb = tl.arange(0, BLOCK_B) + + BATCH_IDX = BATCH_IDX + tl.program_id(0) * batch_size + rb + batch_idx = tl.load(BATCH_IDX, mask=rb < batch_size, other=-1) + + A = ( + A + + qkv_id * stride_a_three + + head_id * stride_a_nheads + + (rh[None, :] * stride_a_headdim + rn[:, None]) + ) + X = X + (batch_idx[:, None] * stride_x_batch + rn[None, :]) + + acc = tl.zeros((BLOCK_B, BLOCK_HEAD), dtype=tl.float32) + for n in range(N, 0, -BLOCK_N): + x = tl.load(X, mask=batch_idx[:, None] >= 0, other=0.0) + a = tl.load(A) + acc += tl.dot(x, a) + A += BLOCK_N + X += BLOCK_N + if HAS_BIAS: + bias = tl.load(BIAS + qkv_id * stride_bias_three + head_id * HEADDIM + rh).to( + tl.float32 + ) + acc += bias[None, :] + + # rematerialize rb to save registers + rb = tl.arange(0, BLOCK_B) + # write back result + Y = ( + Y + + qkv_id * stride_y_three + + head_id * HEADDIM + + (batch_idx[:, None] * stride_y_batch + rh[None, :]) + ) + tl.store(Y, acc, mask=batch_idx[:, None] >= 0) + + +def qkv_proj_sparse( + x: torch.Tensor, + weight: torch.Tensor, + head_idx: torch.Tensor, + batch_idx: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Compute y = torch.einsum("bm,thdm->bthd", x, weight), but only for the active heads in @head_idx + and only for the batch indices in @batch_idx. Negative indices in batch_idx will be ignored. + :param x: input tensor, (batch_size, hidden_dim) + :param weight: weight matrix, (3, nheads, head_dim, hidden_dim) + :param head_idx: int32, (nnz,) + :param batch_idx: int32, (nnz, batch_size). Negative indices are ignored. + :param bias: indices, (3, nheads, head_dim) + :return: result tensor, (batch_size, 3, nheads, head_dim) + """ + three, nheads, head_dim, hidden_dim = weight.shape + assert three == 3 + assert head_dim in [32, 64, 128] + assert hidden_dim % 128 == 0 + batch_size, _ = x.shape + # assert batch_size <= 16 + assert x.shape == (batch_size, hidden_dim) + n_active = head_idx.shape[0] + assert n_active <= nheads + assert head_idx.shape == (n_active,) + assert head_idx.dtype == torch.int32 + assert batch_idx.shape == (n_active, batch_size) + assert batch_idx.dtype == torch.int32 + x = x.contiguous() + if weight.stride(-1) > 1: + weight = weight.contiguous() + head_idx = head_idx.contiguous() + batch_idx = batch_idx.contiguous() + assert ( + x.dtype == weight.dtype + ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" + if bias is not None: + bias = bias.contiguous() + assert bias.shape == ( + 3, + nheads, + head_dim, + ), "Incompatible dimensions in between weight and bias" + assert ( + x.dtype == bias.dtype + ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" + + output = torch.empty( + batch_size, 3, nheads, head_dim, device=x.device, dtype=x.dtype + ) + + BLOCK_HEAD = 32 + # 1D launch kernel where each head gets its own program. + grid = lambda META: (n_active, head_dim // BLOCK_HEAD, 3) # noqa + + qkv_proj_sparse_kernel[grid]( + output, # data ptrs + weight, + x, + head_idx, + batch_idx, + bias, + hidden_dim, # shapes + batch_size, + hidden_dim // 256, # key for triton cache (limit number of compilations) + weight.stride(0), # strides + weight.stride(1), + weight.stride(2), + x.stride(0), + bias.stride(0), + output.stride(0), + output.stride(1), + 16, # Can't use kwargs because auto-tuner requires args + head_dim, + # 64, + BLOCK_HEAD=BLOCK_HEAD, + HAS_BIAS=bias is not None, + # num_warps=2, + # num_stages=16 + # num_stages=16 + ) + + return output + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=8), + triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=16), + triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=8), + triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=16), + triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=8), + triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=16), + triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=8), + triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=16), + triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=8), + triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=16), + triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=8), + triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=16), + ], + key=["CACHE_KEY_N"], +) +@triton.jit +def out_proj_sparse_kernel( + Y, # Pointers to matrices + A, + X, + HEAD_IDX, + BIAS, + # Matrix dimensions + N, + n_active_heads, + batch_size, # actual batch size + CACHE_KEY_N, + stride_a_n, + stride_x_batch, + stride_y_batch, + # Meta-parameters + BLOCK_B: tl.constexpr, + HEADDIM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + """ + We will not check that the indices are valid, for performance reason. + - Input X has shape (BLOCK_B, NHEADS, HEADDIM) + - Weight has shape (N, NHEADS, HEADDIM) + - HEAD_IDX has shape (NNZ) + - BIAS has shape (N) + - Output has shape (BLOCK_B, N) + """ + start_n = tl.program_id(0) * BLOCK_N + rh = tl.arange(0, HEADDIM) + rn = tl.arange(0, BLOCK_N) + rb = tl.arange(0, BLOCK_B) + + A = A + start_n * stride_a_n + (rn[None, :] * stride_a_n + rh[:, None]) + X = X + (rb[:, None] * stride_x_batch + rh[None, :]) + + acc = tl.zeros((BLOCK_B, BLOCK_N), dtype=tl.float32) + for h in range(n_active_heads): + head_id = tl.load(HEAD_IDX + h) + x = tl.load(X + head_id * HEADDIM, mask=rb[:, None] < batch_size, other=0.0) + a = tl.load(A + head_id * HEADDIM) + acc += tl.dot(x, a) + if HAS_BIAS: + bias = tl.load(BIAS + start_n + rn).to(tl.float32) + acc += bias[None, :] + + # rematerialize rn and rn to save registers + rb = tl.arange(0, BLOCK_B) + # write back result + Y = Y + start_n + (rb[:, None] * stride_y_batch + rn[None, :]) + tl.store(Y, acc, mask=rb[:, None] < batch_size) + + +def out_proj_sparse( + x: torch.Tensor, + weight: torch.Tensor, + head_idx: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Compute y = torch.einsum("bhd,nhd->bn", x, weight), but only for the active heads in @head_idx. + :param x: input tensor, (batch, nheads, head_dim) + :param weight: weight matrix, (hidden_dim, nheads, head_dim) + :param head_idx: int32, (nnz,) + :param bias: indices, (hidden_dim) + :return: result tensor, (batch, hidden_dim) + """ + hidden_dim, nheads, head_dim = weight.shape + assert head_dim in [32, 64, 128] + assert hidden_dim % 128 == 0 + batch, _, _ = x.shape + assert x.shape == (batch, nheads, head_dim) + n_active = head_idx.shape[0] + assert head_idx.shape == (n_active,) + assert head_idx.dtype == torch.int32 + assert batch <= 16 + x = x.contiguous() + weight = weight.contiguous() + assert ( + x.dtype == weight.dtype + ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" + if bias is not None: + bias = bias.contiguous() + assert bias.shape == ( + hidden_dim, + ), "Incompatible dimensions in between weight and bias" + assert ( + x.dtype == bias.dtype + ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" + + output = torch.empty(batch, hidden_dim, device=x.device, dtype=x.dtype) + + # 1D launch kernel where each head gets its own program. + grid = lambda META: (triton.cdiv(hidden_dim, META["BLOCK_N"]),) # noqa + + out_proj_sparse_kernel[grid]( + output, # data ptrs + weight, + x, + head_idx, + bias, + hidden_dim, # shapes + n_active, + batch, + hidden_dim // 256, # key for triton cache (limit number of compilations) + weight.stride(0), # strides + x.stride(0), + output.stride(0), + 16, # Can't use kwargs because auto-tuner requires args + head_dim, + # 64, + HAS_BIAS=bias is not None, + # num_warps=2, + # num_stages=10 + # num_stages=8 + ) + + return output + + +@triton.jit +def v_cache_copy_sparse_kernel( + V, # Pointers to matrices + V_CACHE, + HEAD_IDX, + SEQLEN_IDX, + PADDING, + # Matrix dimensions + nnz_seqlen, + nheads, + seqlen, + head_dim, + stride_v_cache_nheads, + stride_v_cache_seqlen, + stride_v_seqlen, + stride_v_nheads, + # Meta-parameters + BLOCK_B: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + HAS_PADDING: tl.constexpr, +): + """ + We will not check that the indices are valid, for performance reason. + - Input V has shape (nnz_seqlen, nheads, headdim) + - V_CACHE has shape (1, nheads, seqlen, headdim) + - HEAD_IDX has shape (NNZ) + - SEQLEN_IDX has shape (NNZ, nnz_seqlen) + """ + head_id = tl.load(HEAD_IDX + tl.program_id(0)) + rh = tl.program_id(1) * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + rb = tl.arange(0, BLOCK_B) + + if HAS_PADDING: + padding = tl.load(PADDING) + else: + padding = 0 + + SEQLEN_IDX = SEQLEN_IDX + tl.program_id(0) * nnz_seqlen + rb + seqlen_idx = tl.load(SEQLEN_IDX, mask=rb < nnz_seqlen, other=-1) + + V = ( + V + + head_id * stride_v_nheads + + (seqlen_idx[:, None] * stride_v_seqlen + rh[None, :]) + ) + V_CACHE = ( + V_CACHE + + head_id * stride_v_cache_nheads + + ((seqlen_idx + padding)[:, None] * stride_v_cache_seqlen + rh[None, :]) + ) + + v = tl.load(V, mask=seqlen_idx[:, None] >= 0, other=0.0) + tl.store(V_CACHE, v, mask=seqlen_idx[:, None] >= 0) + + +def v_cache_copy_sparse( + v: torch.Tensor, + v_cache: torch.Tensor, + head_idx: torch.Tensor, + seqlen_idx: torch.Tensor, + padding: torch.Tensor = None, +) -> None: + """ + :param v: input tensor, (nnz_seqlen, nheads, head_dim) + :param v_cache: input tensor, (1, nheads, seqlen, head_dim) + :param head_idx: int32, (nnz,) + :param seqlen_idx: int32, (nnz, nnz_seqlen). Negative indices are ignored. + :param padding: int32, (1). Padding is added to indices in seqlen_idx before writing to v_cache + """ + nnz_seqlen, nheads, head_dim = v.shape + _, _, seqlen, _ = v_cache.shape + assert head_dim in [32, 64, 128] + assert nnz_seqlen <= 16 + assert v_cache.shape == (1, nheads, seqlen, head_dim) + n_active = head_idx.shape[0] + assert n_active <= nheads + assert head_idx.shape == (n_active,) + assert head_idx.dtype == torch.int32 + assert seqlen_idx.shape == (n_active, nnz_seqlen) + assert seqlen_idx.dtype == torch.int32 + v = v.contiguous() + assert v_cache.stride(-1) == 1 + head_idx = head_idx.contiguous() + seqlen_idx = seqlen_idx.contiguous() + assert ( + v.dtype == v_cache.dtype + ), f"v and v_cache must have the same dtype, got {v.dtype} and {v_cache.dtype}" + if padding is not None: + assert padding.shape == (1,) + assert padding.dtype == torch.int32 + + BLOCK_HEAD = 32 + # 1D launch kernel where each head gets its own program. + grid = lambda META: (n_active, head_dim // BLOCK_HEAD) # noqa + + v_cache_copy_sparse_kernel[grid]( + v, # data ptrs + v_cache, + head_idx, + seqlen_idx, + padding, + nnz_seqlen, # shapes + nheads, + seqlen, + head_dim, + v_cache.stride(1), # strides + v_cache.stride(2), + v.stride(0), + v.stride(1), + 16, # Can't use kwargs because auto-tuner requires args + BLOCK_HEAD=BLOCK_HEAD, + HAS_PADDING=padding is not None + # num_warps=2, + # num_stages=16 + # num_stages=16 + ) + + +@triton.jit +def k_cache_copy_sparse_kernel( + K, # Pointers to matrices + K_CACHE, + HEAD_IDX, + SEQLEN_IDX, + PADDING, + # Matrix dimensions + nnz_seqlen, + nheads, + seqlen, + head_dim, + stride_k_cache_nheads, + stride_k_cache_headdimpack, + stride_k_cache_seqlen, + stride_k_seqlen, + stride_k_nheads, + # Meta-parameters + PACKSIZE: tl.constexpr, + BLOCK_B: tl.constexpr, + BLOCK_HEAD_PACK: tl.constexpr, + HAS_PADDING: tl.constexpr, +): + """ + We will not check that the indices are valid, for performance reason. + - Input K has shape (nnz_seqlen, nheads, headdim) + - K_CACHE has shape (1, nheads, headdim / PACKSIZE, seqlen, PACKSIZE) + - HEAD_IDX has shape (NNZ) + - SEQLEN_IDX has shape (NNZ, nnz_seqlen) + """ + head_id = tl.load(HEAD_IDX + tl.program_id(0)) + rh = tl.program_id(1) * BLOCK_HEAD_PACK + tl.arange(0, BLOCK_HEAD_PACK) + rp = tl.arange(0, PACKSIZE) + rb = tl.arange(0, BLOCK_B) + + if HAS_PADDING: + padding = tl.load(PADDING) + else: + padding = 0 + + SEQLEN_IDX = SEQLEN_IDX + tl.program_id(0) * nnz_seqlen + rb + seqlen_idx = tl.load(SEQLEN_IDX, mask=rb < nnz_seqlen, other=-1) + + K = ( + K + + head_id * stride_k_nheads + + ( + seqlen_idx[:, None, None] * stride_k_seqlen + + rh[None, :, None] * PACKSIZE + + rp[None, None, :] + ) + ) + K_CACHE = ( + K_CACHE + + head_id * stride_k_cache_nheads + + ( + (seqlen_idx + padding)[:, None, None] * stride_k_cache_seqlen + + rh[None, :, None] * stride_k_cache_headdimpack + + rp[None, None, :] + ) + ) + + k = tl.load(K, mask=seqlen_idx[:, None, None] >= 0, other=0.0) + tl.store(K_CACHE, k, mask=seqlen_idx[:, None, None] >= 0) + + +def k_cache_copy_sparse( + k: torch.Tensor, + k_cache: torch.Tensor, + head_idx: torch.Tensor, + seqlen_idx: torch.Tensor, + padding: torch.Tensor = None, +) -> None: + """ + :param k: input tensor, (nnz_seqlen, nheads, head_dim) + :param k_cache: input tensor, (1, nheads, headdim / PACKSIZE, seqlen, PACKSIZE), where + PACKSIZE = 8 if fp16/bf16 and 4 if fp32. + :param head_idx: int32, (nnz,) + :param seqlen_idx: int32, (nnz, nnz_seqlen). Negative indices are ignored. + :param padding: int32, (1). Padding is added to indices in seqlen_idx before writing to v_cache + """ + assert k.dtype in [torch.float32, torch.float16, torch.bfloat16] + packsize = 4 if k.dtype == torch.float32 else 8 + nnz_seqlen, nheads, head_dim = k.shape + _, _, _, seqlen, _ = k_cache.shape + assert head_dim in [32, 64, 128] + assert nnz_seqlen <= 16 + assert k_cache.shape == (1, nheads, head_dim // packsize, seqlen, packsize) + n_active = head_idx.shape[0] + assert n_active <= nheads + assert head_idx.shape == (n_active,) + assert head_idx.dtype == torch.int32 + assert seqlen_idx.shape == (n_active, nnz_seqlen) + assert seqlen_idx.dtype == torch.int32 + k = k.contiguous() + assert k_cache.stride(-1) == 1 + head_idx = head_idx.contiguous() + seqlen_idx = seqlen_idx.contiguous() + assert ( + k.dtype == k_cache.dtype + ), f"k and k_cache must have the same dtype, got {k.dtype} and {k_cache.dtype}" + if padding is not None: + assert padding.shape == (1,) + assert padding.dtype == torch.int32 + + BLOCK_HEAD = 32 + # 1D launch kernel where each head gets its own program. + grid = lambda META: (n_active, head_dim // BLOCK_HEAD) # noqa + + k_cache_copy_sparse_kernel[grid]( + k, # data ptrs + k_cache, + head_idx, + seqlen_idx, + padding, + nnz_seqlen, # shapes + nheads, + seqlen, + head_dim, + k_cache.stride(1), # strides + k_cache.stride(2), + k_cache.stride(3), + k.stride(0), + k.stride(1), + packsize, # Can't use kwargs because auto-tuner requires args + 16, + BLOCK_HEAD_PACK=BLOCK_HEAD // packsize, + HAS_PADDING=padding is not None + # num_warps=2, + # num_stages=16 + # num_stages=16 + ) + + +if __name__ == "__main__": + from src.utils.benchmark import pytorch_profiler + from einops import rearrange, repeat + + torch.manual_seed(0) + device = "cuda" + dtype = torch.float16 + nheads = 12 + hidden_dim = 12 * 1024 + + head_dim = 128 + batch_size = 16 + # n_active_heads = nheads + n_active_heads = nheads // 3 + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype) + weight = torch.randn(3, nheads, head_dim, hidden_dim, device=device, dtype=dtype) + bias = torch.randn(3, nheads, head_dim, device=device, dtype=dtype) + # head_idx = torch.arange(nheads, dtype=torch.int32, device=device) + # head_idx = torch.arange(nheads // 2, dtype=torch.int32, device=device) + nheads // 2 + head_idx = torch.randperm(nheads, device=device, dtype=torch.int32)[:n_active_heads] + # batch_idx = repeat(torch.arange(batch_size, dtype=torch.int32, device=device), + # "b -> h b", h=n_active_heads).contiguous() + # batch_idx = batch_idx.flip(-1).contiguous() + batch_idx = torch.stack( + [ + torch.randperm(batch_size, dtype=torch.int32, device=device) + for _ in range(n_active_heads) + ] + ) + # batch_mask = torch.randint(0, 2, (n_active_heads, batch_size), dtype=torch.bool, device=device) + # batch_idx[batch_mask] = -1 + out = qkv_proj_sparse(x, weight, head_idx, batch_idx, bias) + out_ref = torch.einsum("thdn,bn->bthd", weight, x) + bias + # out_tt = rearrange(triton.ops.matmul(x, rearrange(weight, "t h d n -> (t h d) n").t()), "b (t h d) -> b t h d", t=3, h=nheads) + bias + print((out - out_ref)[:, :, head_idx].abs().max()) + # print((out - out_tt)[:, :, head_idx].abs().max()) + # breakpoint() + + # pytorch_profiler(qkv_proj_sparse, x, weight, head_idx, batch_idx, bias) + # x_tmp = torch.randn(1, hidden_dim, device=device, dtype=dtype) + # pytorch_profiler(torch.einsum, 'thdm,bm->bthd', weight, x_tmp) + # # pytorch_profiler(torch.einsum, 'bn,hdn->bhd', x, weight) + # # pytorch_profiler(torch.sum, weight, dim=-1) + # # pytorch_profiler(triton.ops.matmul, x, rearrange(weight, "t h d n -> (t h d) n").t()) + + nheads = 12 + head_dim = 128 + x = torch.randn(2, nheads, head_dim, device=device, dtype=dtype) + weight = torch.randn(hidden_dim, nheads, head_dim, device=device, dtype=dtype) + bias = torch.randn(hidden_dim, device=device, dtype=dtype) + # head_idx = torch.arange(nheads, dtype=torch.int32, device=device) + head_idx = torch.randperm(nheads, device=device, dtype=torch.int32)[:n_active_heads] + out = out_proj_sparse(x, weight, head_idx, bias) + out_ref = torch.einsum("nhd,bhd->bn", weight[:, head_idx], x[:, head_idx]) + bias + print((out - out_ref).abs().max()) + + # pytorch_profiler(out_proj_sparse, x, weight, head_idx, bias) + # # pytorch_profiler(torch.einsum, 'bhd,nhd->bn', x, weight) + # pytorch_profiler(torch.einsum, 'bd,nd->bn', + # rearrange(x, 'b h d -> b (h d)'), rearrange(weight, 'n h d -> n (h d)')) + + batch_size = 16 + seqlen = 32 + v = torch.randn(batch_size, nheads, head_dim, device=device, dtype=dtype) + v_cache = torch.zeros(1, nheads, seqlen, head_dim, device=device, dtype=dtype) + padding = torch.tensor([16], dtype=torch.int32, device=device) + v_cache_copy_sparse(v, v_cache, head_idx, batch_idx, padding) + # pytorch_profiler(v_cache_copy_sparse, v, v_cache, head_idx, batch_idx) + v[batch_idx[0], 0, 0] + v_cache[0, 0, batch_idx[0] + padding, 0] + + packsize = 4 if dtype == torch.float32 else 8 + k = torch.randn(batch_size, nheads, head_dim, device=device, dtype=dtype) + k_cache = torch.zeros( + 1, nheads, head_dim // packsize, seqlen, packsize, device=device, dtype=dtype + ) + k_cache_copy_sparse(k, k_cache, head_idx, batch_idx, padding) + k_cache_og = rearrange(k_cache, "1 h d s p -> 1 h s (d p)") + pytorch_profiler(k_cache_copy_sparse, k, k_cache, head_idx, batch_idx) + k[batch_idx[0], 0, 0] + k_cache_og[0, 0, batch_idx[0] + padding, 0] diff --git a/HybridTensor/triton/references/attn_decode.py b/HybridTensor/triton/references/attn_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..89af1d8220e8852fec7bdc5b1f141346220ee096 --- /dev/null +++ b/HybridTensor/triton/references/attn_decode.py @@ -0,0 +1,819 @@ +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple +import pytest +import torch +import sys + +import triton +import triton.language as tl + +def _strides(x: torch.Tensor, *stride_names: str): + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Seq_len, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qk, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kk, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vk, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + Z, + N_CTX_Q, + N_CTX_K, + BLOCK_N_PER_SPLIT, + H: tl.constexpr, + G: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_SEQ_LEN: tl.constexpr, + PACKED_PER_VAL: tl.constexpr = 1, + N_GROUPS: tl.constexpr = 1, +): + """This kernel can accept non-quantized or int4-quantized keys/values. + PACKED_PER_VAL determines the quantization type: + - PACKED_PER_VAL == 1 means no quantization + - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) + For the quantized case K/V should be int32 tensors. + Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. + Quantization coefficients are stored at the beginning of the row along the last dimension of K/V + So K[B, H, M, :] has a form + [ quant_coef0, quant_coef1, ...| + group0_quant_value0, group0_quant_value1,... | + group1_quant_value0, group1_quant_value1,...] + where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. + + """ + tl.static_assert( + (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) + or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)), + f"Only 4-bit quantization is supported, K/V should have dtype int32 in " + f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", + ) + tl.static_assert( + (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), + "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", + ) + + QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 + PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS + D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + off_z = off_zhg // (H * G) # batch + off_h = (off_zhg // G) % H # head + off_g = off_zhg % G # group + splitk_idx = tl.program_id(2) + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_SEQ_LEN: + kv_len = tl.load(Seq_len + off_z) + else: + kv_len = N_CTX_K + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + + k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg + # Additional shift by 1 along the last dimension in the quantized case, since + # the first element along that dim contains packed quantization coefficients. + K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * QUANTIZED * N_GROUPS, + shape=(hi, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + + if QUANTIZED: + # Pointers to quantization coefficients + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(1, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, 1), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + else: + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0,)) + q = (q * qk_scale).to(q.dtype) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k, v = load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + PACKED_PER_VAL, + PACKED_D_PER_GROUP, + Q.dtype.element_ty, + 0, + ) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # noqa: F821 + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p, v) + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + if PACKED_PER_VAL > 1: + K_scale_shift_block_ptr = tl.advance( + K_scale_shift_block_ptr, (0, BLOCK_N) + ) + V_scale_shift_block_ptr = tl.advance( + V_scale_shift_block_ptr, (BLOCK_N, 0) + ) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, + boundary_check=(0,), + ) + # Write metadata for split-K reduction + Metadata_ptr = ( + Metadata + + off_zhg * stride_mzhg + + splitk_idx * stride_ms + + start_m * BLOCK_M + + tl.arange(0, BLOCK_M) + ) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + +@triton.jit +def load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + PACKED_D_PER_GROUP: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block. In case of int4-quantized K/V, + # dequantize them after loading. If quantization is group-wise, + # use group_id to advance the pointers to the current group. + + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()) + + if PACKED_PER_VAL > 1: + # K/V are quantized, load quantization coefficients and dequantize + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + + k_scale_shift = tl.load( + K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else () + ) + v_scale_shift = tl.load( + V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else () + ) + + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + k_t = dequantize( + tl.trans(k), + tl.trans(k_scale), + tl.trans(k_shift), + PACKED_PER_VAL, + ).to(dtype) + k = tl.trans(k_t) + return k, v + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. + + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = ( + x_[:, None, :] >> offsets[None, :, None] + ) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + M_ceil:tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k:tl.constexpr, + splitK_pow2:tl.constexpr, + use_mask:tl.constexpr, +): + off_zhg = tl.program_id(0) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1) + off_k = tl.program_id(2) + + # read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = ( + Metadata + + stride_mzhg * off_zhg + + spk_idx * stride_ms + + off_m * stride_mm + ) + + o_ptr = ( + Out_splitK + + off_zhg * stride_osk_zhg + + stride_osk_m * off_m + + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + + kidx[None, :] * stride_osk_k + ) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:,None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + acc_out = tl.sum(acc, axis=0) / g_sum + Out_ptr = ( + Out + + stride_oz * off_z + + stride_oh * off_h + + stride_og * off_g + + stride_om * off_m + + off_k * BLOCK_SIZE + + tl.arange(0, BLOCK_SIZE) + ) + tl.store(Out_ptr, acc_out) + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + # Scale and shift are such that quantization linearly maps + # int4 values range [0..15] to input values range min(k)..max(k) + # individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: torch.Tensor = (max_vals - min_vals) / 15 + + shift_k = torch.min(k, dim=-1, keepdim=True).values + scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat( + [scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1 + ) + k_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + + +def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + k_i16 = quant_k.view(torch.int16) + k_ui8 = k_i16.view(torch.uint8) + + ss_size = num_groups * 4 + scale_shift_ui8 = k_ui8[...,0:ss_size] + scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale = scale_shift_ui8[...,0:2].view(torch.float16) + shift = scale_shift_ui8[...,2:4].view(torch.float16) + + kv_ui8 = k_ui8[...,ss_size:] + k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1) + k1_i4 = k_ui8 & 0xF + k2_i4 = (k_ui8 & 0xF0) >> 4 + k_shape = k1_i4.shape + k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + + out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out[...,::2] = k1_f16 + out[...,1::2] = k2_f16 + out = out.reshape(*k_shape[:-2], -1) + + return out + + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + + +class _attention(torch.autograd.Function): + + OPERATOR = _fwd_kernel_splitK + SUPPORTED_DEVICES = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + } + SUPPORTED_MAX_K = 128 + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "triton_splitKF" + + @staticmethod + def forward(cls, q, k, v, scale_float): + + cls.SPLIT_K: Optional[int] = None + cls.BLOCK_M = 16 + cls.BLOCK_N = 64 + + cls.NUM_GROUPS = 1 # Default quantization is row-wise + + # attn_bias = inp.attn_bias + seq_len = None + + # Transpose in the case of MQA/GQA + mqa_swap_seqlen_head = False + if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0: + mqa_swap_seqlen_head = True + assert q.shape[1] == 1 + q = q.transpose(1, 3) + k = k[:, :, :, :1] + v = v[:, :, :, :1] + + if k.dtype == torch.int32: + # Quantized K/V + PACKED_PER_VAL = 8 + Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8 + else: + Lk = k.shape[-1] + PACKED_PER_VAL = 1 + + B, Mk, G, H, Kkv = k.shape + B, M, G, H, Kq = q.shape + assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" + # print(f"B = {B}, M = {M}, G = {G}, H = {H}, Kkv = {Kkv}, Kq = {Kq}") + + BLOCK_M = cls.BLOCK_M + BLOCK_N = cls.BLOCK_N + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = get_split_k(B, G, H, Mk) + + M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M + o_splitk = torch.empty( + [B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device + ) + metadata = torch.empty( + [B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device + ) + lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32) + grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k) + + num_warps = 1 + split_size = (Mk + split_k - 1) // split_k + use_seq_len = seq_len is not None + + # print(f"B = {B}, G = {G}, H = {H}, split_k = {split_k}, M_ceil = {M_ceil}, Kq = {Kq}, num_of_wgs = {G * G * H * split_k}") + + _fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=scale_float, + Out_splitK=o_splitk, + Metadata=metadata, + Seq_len=seq_len, + **_strides(q, "qz", "qm", "qg", "qh", "qk"), + **_strides(k, "kz", "kn", "kg", "kh", "kk"), + **_strides(v, "vz", "vn", "vg", "vh", "vk"), + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + Z=B, + H=H, + G=G, + N_CTX_Q=M, + N_CTX_K=Mk, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len, + USE_SEQ_LEN=use_seq_len, + num_warps=num_warps, + num_stages=1, + PACKED_PER_VAL=PACKED_PER_VAL, + N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1, + ) + + if mqa_swap_seqlen_head: + out = torch.empty( + (B, H, G, M, Kq), device=q.device, dtype=q.dtype + ).transpose(1, 3) + else: + out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if B * G * H * M >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert out.shape[-1] % k_block_num == 0 + k_block_size = out.shape[-1] // k_block_num + grid = (B * G * H, M, k_block_num) + _splitK_reduce[grid]( + o_splitk, + metadata, + out, + lse, + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + M_ceil=M_ceil, + BLOCK_SIZE=k_block_size, + G=G, + H=H, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + use_mask=use_mask, + num_warps=4 + ) + + lse = lse.reshape([B, G, H, M]) + if mqa_swap_seqlen_head: + # H/M dimensions have been swapped + out = out.transpose(1, 3) + lse = lse.transpose(2, 3) + if q.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if Mk == 0: + out.zero_() + if mqa_swap_seqlen_head: + out = out.reshape(B, -1, M * G, Kq).transpose(1, 2).contiguous() + else: + out = out.reshape(B, H * G, -1, Kq).contiguous() + + return out + +attention = _attention.apply + +def get_input_shapes(): + cases = [ + (max(1, 2 ** (16 - i)), 1, 2**i, 16, 1, 128) + for i in range(8, 18) + ] + [ + (max(1, 2 ** (16 - i)), 1, 2**i, 16, 2, 128) + for i in range(8, 18) + ] + + return cases + + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', + get_input_shapes()) +def test_op_fwd(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + torch.manual_seed(20) + q = ( + torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + scale = 1 / K**0.5 + tri_out = attention(q, k, v, scale) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-3, rtol=0) + + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', + get_input_shapes()) +def test_op_fwd_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + torch.manual_seed(2) + q = ( + torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, device="cuda") + .normal_(mean=1.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=1.0, std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=1.0, std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + + num_groups = 1 + quant_k = ( + quantize_kv_int4(k, num_groups=num_groups) + .contiguous() + .view(torch.int32) + ) + quant_v = ( + quantize_kv_int4(v, num_groups=num_groups) + .contiguous() + .view(torch.int32) + ) + scale = 1 / K**0.5 + tri_out = attention(q, quant_k, quant_v, scale) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2.1e-2, rtol=0) + + # since quantization introduces rounding error, use the + # dequantized kv as inputs to the ref implementation to reduce + # the tolerance to 1e-3 + dqk = dequantize_kv_fp16(quant_k, num_groups=num_groups) + dqv = dequantize_kv_fp16(quant_v, num_groups=num_groups) + dqk = dqk.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dqv = dqv.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dq_attn = (q @ dqk.transpose(-1, -2) * scale).softmax(-1) + dq_ref_out = dq_attn @ dqv + torch.testing.assert_close(dq_ref_out, tri_out, atol=1e-3, rtol=0) + + +def test_quantization(): + a = torch.randn((2, 4, 32), dtype=torch.float16, device='cuda') + qa = quantize_kv_int4(a, num_groups=4) + dqa = dequantize_kv_fp16(qa, num_groups=4) + torch.testing.assert_close(a, dqa, atol=1.5e-1, rtol=1e-1) + + +try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + FLASH_VER = 2 +except BaseException: + try: + from flash_attn.flash_attn_interface import flash_attn_func + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None + +configs = [] +for mode in ['fwd']: + # for D_HEAD in [128]: + for causal in [False]: + configs.append(triton.testing.Benchmark( + x_names=['B', 'Mq','Mkv', 'Hq', 'Hkv', 'K'], + x_vals=get_input_shapes(), + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-d{128}-{mode}-causal={causal}', + args={ + # 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'causal': causal}) + ) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(B, Mq, Mkv, Hq, Hkv, K, causal, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 100 + rep = 400 + ms = 0 + if provider == "triton": + q = torch.randn( + [B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=False + ) + k = torch.randn( + [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=False + ).expand(-1, -1, -1, Hq // Hkv, -1) + v = torch.randn( + [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=False + ).expand(-1, -1, -1, Hq // Hkv, -1) + + sm_scale = 1.3 + fn = lambda: attention(q, k, v, sm_scale) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + + flops_per_matmul = 2 * B * Hq * (Mq * K * Mkv + Mq * Mkv * K) + total_flops = 2 * flops_per_matmul + totalBytes = ((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2 + + #return totalBytes / ms * 1e-9 + return ms + + +def main(): + bench_flash_attention.run(save_path='.', print_data=True) + +if __name__ == '__main__': + sys.exit(main()) \ No newline at end of file diff --git a/HybridTensor/triton/references/flash_attn_triton.py b/HybridTensor/triton/references/flash_attn_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..5af9fa54ff78f8abb3085d7b9738e14f202a1d79 --- /dev/null +++ b/HybridTensor/triton/references/flash_attn_triton.py @@ -0,0 +1,1162 @@ +""" +*Experimental* implementation of FlashAttention in Triton. +Tested with triton==2.0.0.dev20221202. +Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions +other than 64: +https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 +We'll update this implementation with the new Triton backend once this is fixed. + +We use the FlashAttention implementation from Phil Tillet a starting point. +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +Changes: +- Implement both causal and non-causal attention. +- Implement both self-attention and cross-attention. +- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. +- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. +- Support attention bias. +- Speed up the forward pass a bit, and only store the LSE instead of m and l. +- Make the backward for d=128 much faster by reducing register spilling. +- Optionally parallelize the backward pass across seqlen_k, to deal with the case of +small batch size * nheads. + +Caution: +- This is an *experimental* implementation. The forward pass should be quite robust but +I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). +- This implementation has only been tested on A100. +- If you plan to use headdim other than 64 and 128, you should test for race conditions +(due to the Triton compiler), as done in tests/test_flash_attn.py +"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions +for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident +that there are none left for other head dimensions. + +Differences between this Triton version and the CUDA version: +- Triton version doesn't support dropout. +- Triton forward is generally faster than CUDA forward, while Triton backward is +generally slower than CUDA backward. Overall Triton forward + backward is slightly slower +than CUDA forward + backward. +- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). +- Triton version supports attention bias, while CUDA version doesn't. +""" + +import math + +import torch +import triton +import triton.language as tl + + +# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), +# # This config has a race condition when EVEN_M == False, disabling it for now. +# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), +# ], +# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +# ) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Bias, + Out, + Lse, + TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_ob, + stride_oh, + stride_om, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # off_b = tl.program_id(1) + # off_h = tl.program_id(2) + # off_hb = off_b * nheads + off_h + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V + # Adding parenthesis around indexing might use int32 math instead of int64 math? + # https://github.com/openai/triton/issues/741 + # I'm seeing a tiny bit of difference (5-7us) + q_ptrs = ( + Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + ) + k_ptrs = ( + K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + ) + v_ptrs = ( + V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + ) + if BIAS_TYPE == "vector": + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n + elif BIAS_TYPE == "matrix": + b_ptrs = ( + Bias + + off_b * stride_bb + + off_h * stride_bh + + (offs_m[:, None] * stride_bm + offs_n[None, :]) + ) + # initialize pointer to m and l + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call + # tl.load(q_ptrs), we get the wrong output! + if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 + ) + # loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # qk += tl.dot(q, k, trans_b=True) + qk += tl.dot(q, tl.trans(k)) + + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if BIAS_TYPE != "none": + if BIAS_TYPE == "vector": + if EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load( + b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0 + ).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == "matrix": + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load( + b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) + & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler + # can then fuse the mult and add into an fma instruction. But if we have bias we need to + # to multiply with softmax_scale here. + qk = qk * softmax_scale + bias + m_ij = tl.maximum(tl.max(qk, 1), lse_i) + p = tl.exp(qk - m_ij[:, None]) + else: + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # # -- update output accumulator -- + # BUG: have to store and immediately load + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + # update acc_o + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # -- update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - lse_i) + # BUG: have to store and immediately load + tl.store(t_ptrs, o_scale) + o_scale = tl.load(t_ptrs) + acc_o = acc_o * o_scale[:, None] + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = ( + Out + + off_b * stride_ob + + off_h * stride_oh + + (offs_m[:, None] * stride_om + offs_d[None, :]) + ) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store( + out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) + ) + + +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, + DO, + Delta, + stride_ob, + stride_oh, + stride_om, + stride_dob, + stride_doh, + stride_dom, + nheads, + seqlen_q, + seqlen_q_rounded, + headdim, + BLOCK_M: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o = tl.load( + Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + do = tl.load( + DO + + off_b * stride_dob + + off_h * stride_doh + + offs_m[:, None] * stride_dom + + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, +): + # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.store(dv_ptrs), there's a race condition + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD: tl.constexpr, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) + if BIAS_TYPE == "vector": + b_ptrs = Bias + offs_n + elif BIAS_TYPE == "matrix": + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # There seems to be some problem with Triton pipelining that makes results wrong for + # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop + # may have zero step, and pipelining with the bias matrix could screw it up. + # So we just exit early. + if begin_m >= seqlen_q: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + return + # k and v stay in SRAM throughout + # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.load(k_ptrs), we get the wrong output! + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + else: + k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + else: + k = tl.load( + k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 + ) + v = tl.load( + v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 + ) + # loop over rows + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + # recompute p = softmax(qk, dim=-1).T + qk = tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) + if IS_CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + if BIAS_TYPE != "none": + tl.debug_barrier() # Race condition otherwise + if BIAS_TYPE == "vector": + if EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == "matrix": + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load( + b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + qk = qk * softmax_scale + bias + # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. + # Also wrong for headdim=64. + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + lse_i = tl.load(LSE + offs_m_curr) + if BIAS_TYPE == "none": + p = tl.exp(qk * softmax_scale - lse_i[:, None]) + else: + p = tl.exp(qk - lse_i[:, None]) + # compute dv + # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs + # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, + # the output is correct. + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. + do = tl.load( + do_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + # if EVEN_M: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs) + # else: + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + # else: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + # else: + # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + # & (offs_d[None, :] < headdim), other=0.0) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. + # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True + # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + dp = tl.dot(do, v, trans_b=True) + # There's a race condition for headdim=48 + if not EVEN_HEADDIM: + tl.debug_barrier() + # compute ds = p * (dp - delta[:, None]) + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + # compute dk = dot(ds.T, q) + dk += tl.dot(ds, q, trans_a=True) + # compute dq + if not ( + EVEN_M & EVEN_HEADDIM + ): # Otherewise there's a race condition when BIAS_TYPE='matrix' + tl.debug_barrier() + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load( + dq_ptrs, + mask=offs_m_curr[:, None] < seqlen_q, + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last", + ) + else: + dq = tl.load( + dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last", + ) + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + ) + # increment pointers + dq_ptrs += BLOCK_M * stride_dqm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + if BIAS_TYPE == "matrix": + b_ptrs += BLOCK_M * stride_bm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now + # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + ], + key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_dob, + stride_doh, + stride_dom, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dvb, + stride_dvh, + stride_dvn, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + if BIAS_TYPE != "none": + Bias += off_b * stride_bb + off_h * stride_bh + # pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=False, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=True, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + +def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert d <= 128, "FlashAttention only support head dimensions up to 128" + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" + assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" + assert q.is_cuda and k.is_cuda and v.is_cuda + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + + has_bias = bias is not None + bias_type = "none" + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + if bias.stride(-1) != 1: + bias = bias.contiguous() + if bias.shape[2:] == (1, seqlen_k): + bias_type = "vector" + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = "matrix" + else: + raise RuntimeError( + "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)" + ) + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + BLOCK = 128 + num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, + k, + v, + bias, + o, + lse, + tmp, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + *bias_strides, + o.stride(0), + o.stride(2), + o.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, + causal, + BLOCK_HEADDIM, + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return o, lse, softmax_scale # softmax_scale could have been updated + + +def _flash_attn_backward( + do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None +): + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + # assert d in {16, 32, 64, 128} + assert d <= 128 + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 + assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, + do, + delta, + o.stride(0), + o.stride(2), + o.stride(1), + do.stride(0), + do.stride(2), + do.stride(1), + nheads, + seqlen_q, + seqlen_q_rounded, + d, + BLOCK_M=128, + BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + has_bias = bias is not None + bias_type = "none" + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.stride(-1) == 1 + if bias.shape[2:] == (1, seqlen_k): + bias_type = "vector" + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = "matrix" + else: + raise RuntimeError( + "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)" + ) + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: ( + triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads, + ) + _bwd_kernel[grid]( + q, + k, + v, + bias, + do, + dq_accum, + dk, + dv, + lse, + delta, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + *bias_strides, + do.stride(0), + do.stride(2), + do.stride(1), + dq_accum.stride(0), + dq_accum.stride(2), + dq_accum.stride(1), + dk.stride(0), + dk.stride(2), + dk.stride(1), + dv.stride(0), + dv.stride(2), + dv.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, + causal, + BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): + """ + qkv: (batch, seqlen, 3, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen). + ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen) + """ + # Make sure that the last dimension is contiguous + if qkv.stride(-1) != 1: + qkv = qkv.contiguous() + o, lse, ctx.softmax_scale = _flash_attn_forward( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + bias=bias, + causal=causal, + softmax_scale=softmax_scale, + ) + ctx.save_for_backward(qkv, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + qkv, o, lse, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet" + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dqkv = torch.empty_like(qkv) + _flash_attn_backward( + do, + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + o, + lse, + dqkv[:, :, 0], + dqkv[:, :, 1], + dqkv[:, :, 2], + bias=bias, + causal=ctx.causal, + softmax_scale=ctx.softmax_scale, + ) + return dqkv, None, None, None + + +flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply + + +class FlashAttnKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None): + """ + q: (batch, seqlen_q, nheads, headdim) + kv: (batch, seqlen_k, 2, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). + ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) + """ + # Make sure that the last dimension is contiguous + q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] + o, lse, ctx.softmax_scale = _flash_attn_forward( + q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale + ) + ctx.save_for_backward(q, kv, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, kv, o, lse, bias = ctx.saved_tensors + if len(ctx.needs_input_grad) >= 3: + assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet" + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dq = torch.empty_like(q) + dkv = torch.empty_like(kv) + _flash_attn_backward( + do, + q, + kv[:, :, 0], + kv[:, :, 1], + o, + lse, + dq, + dkv[:, :, 0], + dkv[:, :, 1], + bias=bias, + causal=ctx.causal, + softmax_scale=ctx.softmax_scale, + ) + return dq, dkv, None, None, None + + +flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply + + +class FlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): + """ + q: (batch_size, seqlen_q, nheads, headdim) + k, v: (batch_size, seqlen_k, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). + ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) + """ + # Make sure that the last dimension is contiguous + q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] + o, lse, ctx.softmax_scale = _flash_attn_forward( + q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale + ) + ctx.save_for_backward(q, k, v, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet" + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + _flash_attn_backward( + do, + q, + k, + v, + o, + lse, + dq, + dk, + dv, + bias=bias, + causal=ctx.causal, + softmax_scale=ctx.softmax_scale, + ) + return dq, dk, dv, None, None, None + + +flash_attn_func = FlashAttnFunc.apply diff --git a/HybridTensor/triton/references/fused_attn.py b/HybridTensor/triton/references/fused_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..4b01c70ef0a33a22d544e54e7fa4e5aa9c0ad511 --- /dev/null +++ b/HybridTensor/triton/references/fused_attn.py @@ -0,0 +1,641 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm (https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team + +Extra Credits: +- Original flash attention paper (https://arxiv.org/abs/2205.14135) + +""" + +import pytest +import torch + +import triton +import triton.language as tl + + +def is_hip(): + # return triton.runtime.driver.active.get_current_target().backend == "hip" + return False + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, # + K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # + N_CTX: tl.constexpr, fp8_v: tl.constexpr): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + qk = tl.dot(q, k) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + v = tl.load(V_block_ptr) + if fp8_v: + p = p.to(tl.float8e5) + else: + p = p.to(tl.float16) + acc = tl.dot(p, v, acc) + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + + +# We don't run auto-tuning every time to keep the tutorial fast. Keeping +# the code below and commenting out the equivalent parameters is convenient for +# re-tuning. +configs = [ + triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ + for BM in [64, 128]\ + for BN in [32, 64]\ + for s in ([1] if is_hip() else [3, 4, 7])\ + for w in [4, 8]\ +] + + +def keep(conf): + BLOCK_M = conf.kwargs["BLOCK_M"] + BLOCK_N = conf.kwargs["BLOCK_N"] + if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: + return False + return True + + +@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) +@triton.jit +def _attn_fwd(Q, K, V, sm_scale, M, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr # + ): + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H # batch index + off_h = off_hz % H # head index + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + + # block pointers + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=v_order, + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(HEAD_DIM, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compielr to schedule the + # two loops independently + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # epilogue + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_bwd_preprocess(O, DO, # + Delta, # + Z, H, N_CTX, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # + ): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_hz = tl.program_id(1) + off_n = tl.arange(0, HEAD_DIM) + # load + o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) + do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hz * N_CTX + off_m, delta) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + # Filled in by the wrapper. + start_n, start_m, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, HEAD_DIM) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + qkT = tl.dot(k, qT) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq(dq, q, K, V, # + do, m, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok + return dq + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, # + DO, # + DQ, DK, DV, # + M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, HEAD_DIM) + + start_n = pid * BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=True # + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. + dk, dv = _attn_bwd_dkdv( # + dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=False # + ) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # + MASK=True # + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * BLOCK_N2, num_steps, # + MASK=False # + ) + # Write back dQ. + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale): + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + o = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + # Tuning for AMD target + if is_hip(): + waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 + extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} + + grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + _attn_fwd[grid]( + q, k, v, sm_scale, M, o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + **extra_kern_args) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + NUM_WARPS, NUM_STAGES = 4, 5 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 + assert N_CTX % PRE_BLOCK == 0 + pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) + delta = torch.empty_like(M) + _attn_bwd_preprocess[pre_grid]( + o, do, # + delta, # + BATCH, N_HEAD, N_CTX, # + BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # + ) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # + M, delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + N_HEAD, N_CTX, # + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES # + ) + + return dq, dk, dv, None, None + + +attention = _attention.apply + + +@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)]) +@pytest.mark.parametrize("causal", [True]) +def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): + torch.manual_seed(20) + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + # p = torch.exp(p) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # triton implementation + tri_out = attention(q, k, v, causal, sm_scale).half() + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) + rtol = 0.0 + # Relative tolerance workaround for known hardware limitation of MI200 GPU. + # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": + rtol = 1e-2 + assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol) + assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol) + assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol) + + +try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') +BATCH, N_HEADS, HEAD_DIM = 1, 64, 128 +# vary seq length for fixed head and batch=4 +configs = [] +for mode in ["fwd", "bwd"]: + for causal in [True, False]: + if mode == "bwd" and not causal: + continue + configs.append( + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(10, 15)], + line_arg="provider", + line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + + (["flash"] if HAS_FLASH else []), + line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + + (["Flash-2"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-"), ("green", "-")], + ylabel="TFLOPS", + plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}", + args={ + "H": N_HEADS, + "BATCH": BATCH, + "HEAD_DIM": HEAD_DIM, + "mode": mode, + "causal": causal, + }, + )) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"): + assert mode in ["fwd", "bwd"] + warmup = 25 + rep = 100 + dtype = torch.float16 + if "triton" in provider: + q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + if mode == "fwd" and "fp8" in provider: + q = q.to(torch.float8_e5m2) + k = k.to(torch.float8_e5m2) + v = v.permute(0, 1, 3, 2).contiguous() + v = v.permute(0, 1, 3, 2) + v = v.to(torch.float8_e5m2) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, causal, sm_scale) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + if provider == "flash": + qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + fn = lambda: flash_attn_func(qkv, causal=causal) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM + total_flops = 2 * flops_per_matmul + if causal: + total_flops *= 0.5 + if mode == "bwd": + total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) + return total_flops / ms * 1e-9 + + +if __name__ == "__main__": + # only works on post-Ampere GPUs right now + bench_flash_attention.run(save_path=".", print_data=True) diff --git a/HybridTensor/triton/references/gemm.py b/HybridTensor/triton/references/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..b4731fab49acfb0b8050f5e4778a448e6a0b1442 --- /dev/null +++ b/HybridTensor/triton/references/gemm.py @@ -0,0 +1,268 @@ +import torch + +import triton +import triton.language as tl +from HybridTensor.triton.triton_utils import get_autotune_config, benchmark_fwd +from HybridTensor.triton.activations import leaky_relu, relu + +from HybridTensor.utils.utils import sparse_index, arg_parser + + +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, BIAS, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr, # + HAS_BIAS: tl.constexpr # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + if HAS_BIAS: + # Load the bias vector + bias = tl.load(BIAS + offs_bn, mask=offs_bn < N, other=0.0) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if HAS_BIAS: + # add the bias it to the accumulator + # accumulator += bias + accumulator += bias[None, :] # Broadcast bias over rows + + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def triton_matmul(a, b, bias = None, activation=""): + # Check constraints. + # a and b are expected to be in row-major format. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + has_bias = bias is not None + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_kernel[grid]( + a, b, c, bias, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ACTIVATION=activation, # + HAS_BIAS = has_bias + ) + return c + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'N', 'K'], +) + +@triton.jit +def matmul_row_col( + # Pointers to matrices + # A is in Row Major, B is in Column Major + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # Strides + stride_am, stride_ak, # + stride_bk, stride_bn, # Note: B is now column-major + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K) in row-major and B has shape (K, N) in column-major. + C has shape (M, N). + """ + # ----------------------------------------------------------- + # Map program ids to block of C + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ----------------------------------------------------------- + # Create pointers for the first blocks of A and B + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Pointer arithmetic for A (row-major, K contiguous) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + + # Pointer arithmetic for B (column-major, K contiguous) + # The stride_bn is now used for K (since B is column-major) + # and stride_bk is used for N. + b_ptrs = b_ptr + (offs_bn[None, :] * stride_bk + offs_k[:, None] * stride_bn) + + # ----------------------------------------------------------- + # Initialize the accumulator for C in FP32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load blocks of A and B with masking for out-of-bounds K + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # Perform dot product and accumulate + accumulator = tl.dot(a, b, accumulator) + # Advance pointers for next iteration over K + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bn # Note: stride_bn moves along the K dimension now + + # Optional activation + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + elif ACTIVATION == "relu": + accumulator = relu(accumulator) + + # Convert the accumulator back to FP16 + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the result into matrix C + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul_rowcol(a, b, activation=""): + # Check constraints. + # b is expected to be in column major format. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + + M, K = a.shape + N, K = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_row_col[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ACTIVATION='' # + ) + + return c + + +if __name__ == '__main__': + args = arg_parser() + A = torch.randn(args.batch_size, args.in_features, device='cuda', dtype=torch.float16) + B = torch.randn(args.in_features, args.hidden_features, device='cuda', dtype=torch.float16) + if args.bias: + print("Using bias") + bias = torch.randn(args.hidden_features, device='cuda', dtype=torch.float16) + else: + print("Not using bias") + bias = None + + out_b, cublass_time = benchmark_fwd(A, B, bias = None, function = torch.matmul, print_result = args.print_results) + out_a_gather, gather_gemm_time = benchmark_fwd(A, B, bias = bias, function=triton_matmul, print_result = args.print_results) + # out_a_gather_torch,torch_sel_gemm_time = benchmark_fwd(A, B, index_vec=index_vec, function=torch_sel_gemm_col, print_result = args.print_results) + + speedup = cublass_time / gather_gemm_time + + # check results + if args.check_results: + print("Checking results") + print("Kernel gather output: ", out_a_gather) + print("Torch gather output: ", out_b + bias) + + print(f"Speedup: {speedup:.2f}") + \ No newline at end of file diff --git a/HybridTensor/triton/references/select_attn.py b/HybridTensor/triton/references/select_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..51ce29dc1147c94f873893a0f37b8e0baaf81fa3 --- /dev/null +++ b/HybridTensor/triton/references/select_attn.py @@ -0,0 +1,636 @@ +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple +import pytest +import torch +import sys + +import triton +import triton.language as tl + +def _strides(x: torch.Tensor, *stride_names: str): + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + +@triton.jit +def load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + PACKED_D_PER_GROUP: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block. In case of int4-quantized K/V, + # dequantize them after loading. If quantization is group-wise, + # use group_id to advance the pointers to the current group. + + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()) + + if PACKED_PER_VAL > 1: + # K/V are quantized, load quantization coefficients and dequantize + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + + k_scale_shift = tl.load( + K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else () + ) + v_scale_shift = tl.load( + V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else () + ) + + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + k_t = dequantize( + tl.trans(k), + tl.trans(k_scale), + tl.trans(k_shift), + PACKED_PER_VAL, + ).to(dtype) + k = tl.trans(k_t) + return k, v + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. + + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = ( + x_[:, None, :] >> offsets[None, :, None] + ) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Seq_len, + batch_head_index, # {[B0, H1], [B1, H1], ...}, -> [N_selected, 2]; new argument ## if GQA, we select from groups: {[B0, G1], [B1, G1], ...}, -> [N_selected, 2] + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qk, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kk, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vk, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + Z, + N_CTX_Q, + N_CTX_K, + BLOCK_N_PER_SPLIT, + H: tl.constexpr, + G: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_SEQ_LEN: tl.constexpr, + PACKED_PER_VAL: tl.constexpr = 1, + N_GROUPS: tl.constexpr = 1, +): + """This kernel can accept non-quantized or int4-quantized keys/values. + PACKED_PER_VAL determines the quantization type: + - PACKED_PER_VAL == 1 means no quantization + - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) + For the quantized case K/V should be int32 tensors. + Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. + Quantization coefficients are stored at the beginning of the row along the last dimension of K/V + So K[B, H, M, :] has a form + [ quant_coef0, quant_coef1, ...| + group0_quant_value0, group0_quant_value1,... | + group1_quant_value0, group1_quant_value1,...] + where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. + + """ + tl.static_assert( + (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) + or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)), + f"Only 4-bit quantization is supported, K/V should have dtype int32 in " + f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", + ) + tl.static_assert( + (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), + "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", + ) + + QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 + PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS + D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + splitk_idx = tl.program_id(2) + + # off_z = off_zhg // (H * G) # batch + # off_h = (off_zhg // G) % H # head + off_g = off_zhg % G # group, when G == 1, off_g = 0 + off_idx = off_zhg // G # when G == 1, off_idx = off_zhg + + # Compute the pointer to the batch_head_index + batch_head_index_ptr = batch_head_index + off_idx * 2 + + # Load off_z and off_h separately + off_z = tl.load(batch_head_index_ptr + 0, mask=True, other=0).to(tl.int32) + off_h = tl.load(batch_head_index_ptr + 1, mask=True, other=0).to(tl.int32) + + # if we want to update kv cache inside the kernel, load the new key and value from the arguments (pointers, needs to be added) and update the cache in K and V by writing to seq_len position in the cache + # this will only update the kv cache for the selected heads, the cache will not be updated for all heads + # ... #TODO ? # + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_SEQ_LEN: + kv_len = tl.load(Seq_len + off_z) + else: + kv_len = N_CTX_K + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + + k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg + # Additional shift by 1 along the last dimension in the quantized case, since + # the first element along that dim contains packed quantization coefficients. + K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * QUANTIZED * N_GROUPS, + shape=(hi, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + + if QUANTIZED: + # Pointers to quantization coefficients + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(1, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, 1), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + else: + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0,)) + q = (q * qk_scale).to(q.dtype) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k, v = load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + PACKED_PER_VAL, + PACKED_D_PER_GROUP, + Q.dtype.element_ty, + 0, + ) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # noqa: F821 + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p, v) + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + if PACKED_PER_VAL > 1: + K_scale_shift_block_ptr = tl.advance( + K_scale_shift_block_ptr, (0, BLOCK_N) + ) + V_scale_shift_block_ptr = tl.advance( + V_scale_shift_block_ptr, (BLOCK_N, 0) + ) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, + boundary_check=(0,), + ) + # Write metadata for split-K reduction + Metadata_ptr = ( + Metadata + + off_zhg * stride_mzhg + + splitk_idx * stride_ms + + start_m * BLOCK_M + + tl.arange(0, BLOCK_M) + ) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + batch_head_index, # {[B0, H1], [B1, H1], ...}, -> [N_selected, 2]; new argument ## if GQA, we select from groups: {[B0, G1], [B1, G1], ...}, -> [N_selected, 2] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + M_ceil:tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k:tl.constexpr, + splitK_pow2:tl.constexpr, + use_mask:tl.constexpr, +): + off_zhg = tl.program_id(0) + # off_z = off_zhg // (H * G) + # off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1) + off_k = tl.program_id(2) + + off_idx = off_zhg // G # when G == 1, off_idx = off_zhg + off_g = off_zhg % G # when G == 1, off_g = 0 + + # Compute the pointer to the batch_head_index + batch_head_index_ptr = batch_head_index + off_idx * 2 + + # Load off_z and off_h separately + off_z = tl.load(batch_head_index_ptr + 0, mask=True, other=0).to(tl.int32) + off_h = tl.load(batch_head_index_ptr + 1, mask=True, other=0).to(tl.int32) + + # read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = ( + Metadata + + stride_mzhg * off_zhg + + spk_idx * stride_ms + + off_m * stride_mm + ) + + o_ptr = ( + Out_splitK + + off_zhg * stride_osk_zhg + + stride_osk_m * off_m + + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + + kidx[None, :] * stride_osk_k + ) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:,None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + acc_out = tl.sum(acc, axis=0) / g_sum + Out_ptr = ( + Out + + stride_oz * off_z + + stride_oh * off_h + + stride_og * off_g + + stride_om * off_m + + off_k * BLOCK_SIZE + + tl.arange(0, BLOCK_SIZE) + ) + tl.store(Out_ptr, acc_out) + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + +class select_attention(torch.autograd.Function): + + OPERATOR = _fwd_kernel_splitK + SUPPORTED_DEVICES = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + } + SUPPORTED_MAX_K = 128 + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "triton_splitKF" + + @staticmethod + def forward(cls, q, k, v, scale_float, batch_head_index, seq_len=None): + + cls.SPLIT_K: Optional[int] = None + cls.BLOCK_M = 16 + cls.BLOCK_N = 64 + + cls.NUM_GROUPS = 1 # Default quantization is row-wise + + # attn_bias = inp.attn_bias + # seq_len = None + + # Transpose in the case of MQA/GQA + mqa_swap_seqlen_head = False + if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0: + mqa_swap_seqlen_head = True + assert q.shape[1] == 1 + q = q.transpose(1, 3) + k = k[:, :, :, :1] + v = v[:, :, :, :1] + + if k.dtype == torch.int32: + # Quantized K/V + PACKED_PER_VAL = 8 + Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8 + else: + Lk = k.shape[-1] + PACKED_PER_VAL = 1 + + B, Mk_original, G, H, Kkv = k.shape + B, M, G, H, Kq = q.shape + assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" + # print(f"B = {B}, M = {M}, G = {G}, H = {H}, Kkv = {Kkv}, Kq = {Kq}") + + # Ensure batch_head_index is a torch tensor of dtype torch.int32 + batch_head_index = batch_head_index.to(torch.int32) + N_selected = batch_head_index.shape[0] + + # Determine if seq_len is an integer or tensor + if isinstance(seq_len, int): + Mk = seq_len # Fixed sequence length for all batches + use_seq_len = False + elif seq_len is not None: # seq_len is a tensor + Mk = Mk_original # Use original Mk from k.shape[1] + use_seq_len = True + else: # seq_len is None, use Mk_original for all batches + Mk = Mk_original + use_seq_len = False + + BLOCK_M = cls.BLOCK_M + BLOCK_N = cls.BLOCK_N + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = get_split_k(B, G, H, Mk) + + M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M + o_splitk = torch.zeros( + [B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device + ) + metadata = torch.empty( + [B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device + ) + lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32) + # grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k) + # grid = (triton.cdiv(M, BLOCK_M), N_selected * G, split_k) + grid = (triton.cdiv(M, BLOCK_M), N_selected, split_k) # TODO: fix this for GQA + + num_warps = 1 + split_size = (Mk + split_k - 1) // split_k + + # use_seq_len = seq_len is not None + + # print(f"B = {B}, G = {G}, H = {H}, split_k = {split_k}, M_ceil = {M_ceil}, Kq = {Kq}, num_of_wgs = {G * G * H * split_k}") + + _fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=scale_float, + Out_splitK=o_splitk, + Metadata=metadata, + Seq_len=seq_len, + batch_head_index=batch_head_index, + **_strides(q, "qz", "qm", "qg", "qh", "qk"), + **_strides(k, "kz", "kn", "kg", "kh", "kk"), + **_strides(v, "vz", "vn", "vg", "vh", "vk"), + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + Z=B, + H=H, + G=G, + N_CTX_Q=M, + N_CTX_K=Mk, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len, + USE_SEQ_LEN=use_seq_len, + num_warps=num_warps, + num_stages=1, + PACKED_PER_VAL=PACKED_PER_VAL, + N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1, + ) + + if mqa_swap_seqlen_head: + out = torch.zeros( + (B, H, G, M, Kq), device=q.device, dtype=q.dtype + ).transpose(1, 3) + else: + out = torch.zeros((B, M, G, H, Kq), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if B * G * H * M >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert out.shape[-1] % k_block_num == 0 + k_block_size = out.shape[-1] // k_block_num + # grid = (B * G * H, M, k_block_num) + grid = (N_selected, M, k_block_num) # TODO: fix this for GQA + + _splitK_reduce[grid]( + o_splitk, + metadata, + out, + lse, + batch_head_index=batch_head_index, + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + M_ceil=M_ceil, + BLOCK_SIZE=k_block_size, + G=G, + H=H, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + use_mask=use_mask, + num_warps=4 + ) + + lse = lse.reshape([B, G, H, M]) + if mqa_swap_seqlen_head: + # H/M dimensions have been swapped + out = out.transpose(1, 3) + lse = lse.transpose(2, 3) + if q.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if Mk == 0: + out.zero_() + if mqa_swap_seqlen_head: + out = out.reshape(B, -1, M * G, Kq).transpose(1, 2).contiguous() + else: + out = out.reshape(B, H * G, -1, Kq).contiguous() + + return out + +select_attn = select_attention.apply \ No newline at end of file diff --git a/HybridTensor/triton/references/select_attn_64b_kernel.py b/HybridTensor/triton/references/select_attn_64b_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..9882324a102a2445ff5b6a0e8592bf2e38f854d5 --- /dev/null +++ b/HybridTensor/triton/references/select_attn_64b_kernel.py @@ -0,0 +1,644 @@ +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple +import pytest +import torch +import sys + +import triton +import triton.language as tl + +def _strides(x: torch.Tensor, *stride_names: str): + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + +@triton.jit +def load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + PACKED_D_PER_GROUP: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block. In case of int4-quantized K/V, + # dequantize them after loading. If quantization is group-wise, + # use group_id to advance the pointers to the current group. + + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()) + + if PACKED_PER_VAL > 1: + # K/V are quantized, load quantization coefficients and dequantize + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + + k_scale_shift = tl.load( + K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else () + ) + v_scale_shift = tl.load( + V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else () + ) + + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + k_t = dequantize( + tl.trans(k), + tl.trans(k_scale), + tl.trans(k_shift), + PACKED_PER_VAL, + ).to(dtype) + k = tl.trans(k_t) + return k, v + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. + + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = ( + x_[:, None, :] >> offsets[None, :, None] + ) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Seq_len, + batch_head_index, # {[B0, H1], [B1, H1], ...}, -> [N_selected, 2]; new argument ## if GQA, we select from groups: {[B0, G1], [B1, G1], ...}, -> [N_selected, 2] + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qk, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kk, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vk, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + Z, + N_CTX_Q, + N_CTX_K, + BLOCK_N_PER_SPLIT, + H: tl.constexpr, + G: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_SEQ_LEN: tl.constexpr, + PACKED_PER_VAL: tl.constexpr = 1, + N_GROUPS: tl.constexpr = 1, +): + """This kernel can accept non-quantized or int4-quantized keys/values. + PACKED_PER_VAL determines the quantization type: + - PACKED_PER_VAL == 1 means no quantization + - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) + For the quantized case K/V should be int32 tensors. + Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. + Quantization coefficients are stored at the beginning of the row along the last dimension of K/V + So K[B, H, M, :] has a form + [ quant_coef0, quant_coef1, ...| + group0_quant_value0, group0_quant_value1,... | + group1_quant_value0, group1_quant_value1,...] + where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. + + """ + tl.static_assert( + (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) + or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)), + f"Only 4-bit quantization is supported, K/V should have dtype int32 in " + f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", + ) + tl.static_assert( + (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), + "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", + ) + + QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 + PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS + D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS + + # experimental + # Cast strides and indexing variables to tl.int64 to prevent integer overflows + + start_m = tl.program_id(0).to(tl.int32) + off_zhg = tl.program_id(1).to(tl.int64) + splitk_idx = tl.program_id(2).to(tl.int64) + + + # off_z = off_zhg // (H * G) # batch + # off_h = (off_zhg // G) % H # head + off_g = off_zhg % G # group, when G == 1, off_g = 0 + off_idx = off_zhg // G # when G == 1, off_idx = off_zhg + + # Compute the pointer to the batch_head_index + batch_head_index_ptr = batch_head_index + off_idx * 2 + + # Load off_z and off_h separately + off_z = tl.load(batch_head_index_ptr + 0, mask=True, other=0).to(tl.int64) + off_h = tl.load(batch_head_index_ptr + 1, mask=True, other=0).to(tl.int64) + + # if we want to update kv cache inside the kernel, load the new key and value from the arguments (pointers, needs to be added) and update the cache in K and V by writing to seq_len position in the cache + # this will only update the kv cache for the selected heads, the cache will not be updated for all heads + # ... #TODO ? # + + # Compute lo and hi + lo = (splitk_idx * BLOCK_N_PER_SPLIT).to(tl.int32) + if USE_SEQ_LEN: + kv_len = tl.load(Seq_len + off_z).to(tl.int32) + else: + kv_len = N_CTX_K.to(tl.int32) + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len).to(tl.int32) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + + k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg + # Additional shift by 1 along the last dimension in the quantized case, since + # the first element along that dim contains packed quantization coefficients. + K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * QUANTIZED * N_GROUPS, + shape=(hi, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + + if QUANTIZED: + # Pointers to quantization coefficients + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(1, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, 1), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + else: + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0,)) + q = (q * qk_scale).to(q.dtype) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k, v = load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + PACKED_PER_VAL, + PACKED_D_PER_GROUP, + Q.dtype.element_ty, + 0, + ) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # noqa: F821 + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p, v) + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + if PACKED_PER_VAL > 1: + K_scale_shift_block_ptr = tl.advance( + K_scale_shift_block_ptr, (0, BLOCK_N) + ) + V_scale_shift_block_ptr = tl.advance( + V_scale_shift_block_ptr, (BLOCK_N, 0) + ) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, + boundary_check=(0,), + ) + # Write metadata for split-K reduction + Metadata_ptr = ( + Metadata + + off_zhg * stride_mzhg + + splitk_idx * stride_ms + + start_m * BLOCK_M + + tl.arange(0, BLOCK_M) + ) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + batch_head_index, # {[B0, H1], [B1, H1], ...}, -> [N_selected, 2]; new argument ## if GQA, we select from groups: {[B0, G1], [B1, G1], ...}, -> [N_selected, 2] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + M_ceil:tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k:tl.constexpr, + splitK_pow2:tl.constexpr, + use_mask:tl.constexpr, +): + + off_zhg = tl.program_id(0).to(tl.int64) + off_m = tl.program_id(1).to(tl.int64) + off_k = tl.program_id(2).to(tl.int64) + + off_idx = off_zhg // G # when G == 1, off_idx = off_zhg + off_g = off_zhg % G # when G == 1, off_g = 0 + + # Compute the pointer to the batch_head_index + batch_head_index_ptr = batch_head_index + off_idx * 2 + + # Load off_z and off_h separately + off_z = tl.load(batch_head_index_ptr + 0, mask=True, other=0).to(tl.int64) + off_h = tl.load(batch_head_index_ptr + 1, mask=True, other=0).to(tl.int64) + + # Read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = ( + Metadata + + stride_mzhg * off_zhg + + spk_idx * stride_ms + + off_m * stride_mm + ) + + o_ptr = ( + Out_splitK + + off_zhg * stride_osk_zhg + + stride_osk_m * off_m + + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + + kidx[None, :] * stride_osk_k + ) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:,None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + acc_out = tl.sum(acc, axis=0) / g_sum + Out_ptr = ( + Out + + stride_oz * off_z + + stride_oh * off_h + + stride_og * off_g + + stride_om * off_m + + off_k * BLOCK_SIZE + + tl.arange(0, BLOCK_SIZE) + ) + tl.store(Out_ptr, acc_out) + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + +class select_attention(torch.autograd.Function): + + OPERATOR = _fwd_kernel_splitK + SUPPORTED_DEVICES = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + } + SUPPORTED_MAX_K = 128 + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "triton_splitKF" + + @staticmethod + def forward(cls, q, k, v, scale_float, batch_head_index, seq_len=None): + + cls.SPLIT_K: Optional[int] = None + cls.BLOCK_M = 16 + cls.BLOCK_N = 64 + + cls.NUM_GROUPS = 1 # Default quantization is row-wise + + # attn_bias = inp.attn_bias + # seq_len = None + + # Transpose in the case of MQA/GQA + mqa_swap_seqlen_head = False + if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0: + mqa_swap_seqlen_head = True + assert q.shape[1] == 1 + q = q.transpose(1, 3) + k = k[:, :, :, :1] + v = v[:, :, :, :1] + + if k.dtype == torch.int32: + # Quantized K/V + PACKED_PER_VAL = 8 + Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8 + else: + Lk = k.shape[-1] + PACKED_PER_VAL = 1 + + B, Mk_original, G, H, Kkv = k.shape + B, M, G, H, Kq = q.shape + assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" + # print(f"B = {B}, M = {M}, G = {G}, H = {H}, Kkv = {Kkv}, Kq = {Kq}") + + # Ensure batch_head_index is a torch tensor of dtype torch.int32 + batch_head_index = batch_head_index.to(q.device) + batch_head_index = batch_head_index.to(torch.int32) + N_selected = batch_head_index.shape[0] + + if seq_len is not None and isinstance(seq_len, torch.Tensor): + seq_len = seq_len.to(q.device) + + + # Determine if seq_len is an integer or tensor + if isinstance(seq_len, int): + Mk = seq_len # Fixed sequence length for all batches + use_seq_len = False + elif seq_len is not None: # seq_len is a tensor + Mk = Mk_original # Use original Mk from k.shape[1] + use_seq_len = True + else: # seq_len is None, use Mk_original for all batches + Mk = Mk_original + use_seq_len = False + + BLOCK_M = cls.BLOCK_M + BLOCK_N = cls.BLOCK_N + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = get_split_k(B, G, H, Mk) + + M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M + o_splitk = torch.zeros( + [B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device + ) + metadata = torch.empty( + [B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device + ) + lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32) + # grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k) + # grid = (triton.cdiv(M, BLOCK_M), N_selected * G, split_k) + grid = (triton.cdiv(M, BLOCK_M), N_selected, split_k) # TODO: fix this for GQA + + num_warps = 1 + split_size = (Mk + split_k - 1) // split_k + + # use_seq_len = seq_len is not None + + # print(f"B = {B}, G = {G}, H = {H}, split_k = {split_k}, M_ceil = {M_ceil}, Kq = {Kq}, num_of_wgs = {G * G * H * split_k}") + + _fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=scale_float, + Out_splitK=o_splitk, + Metadata=metadata, + Seq_len=seq_len, + batch_head_index=batch_head_index, + **_strides(q, "qz", "qm", "qg", "qh", "qk"), + **_strides(k, "kz", "kn", "kg", "kh", "kk"), + **_strides(v, "vz", "vn", "vg", "vh", "vk"), + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + Z=B, + H=H, + G=G, + N_CTX_Q=M, + N_CTX_K=Mk, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len, + USE_SEQ_LEN=use_seq_len, + num_warps=num_warps, + num_stages=1, + PACKED_PER_VAL=PACKED_PER_VAL, + N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1, + ) + + if mqa_swap_seqlen_head: + out = torch.zeros( + (B, H, G, M, Kq), device=q.device, dtype=q.dtype + ).transpose(1, 3) + else: + out = torch.zeros((B, M, G, H, Kq), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if B * G * H * M >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert out.shape[-1] % k_block_num == 0 + k_block_size = out.shape[-1] // k_block_num + # grid = (B * G * H, M, k_block_num) + grid = (N_selected, M, k_block_num) # TODO: fix this for GQA + + _splitK_reduce[grid]( + o_splitk, + metadata, + out, + lse, + batch_head_index=batch_head_index, + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + M_ceil=M_ceil, + BLOCK_SIZE=k_block_size, + G=G, + H=H, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + use_mask=use_mask, + num_warps=4 + ) + + lse = lse.reshape([B, G, H, M]) + if mqa_swap_seqlen_head: + # H/M dimensions have been swapped + out = out.transpose(1, 3) + lse = lse.transpose(2, 3) + if q.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if Mk == 0: + out.zero_() + if mqa_swap_seqlen_head: + out = out.reshape(B, -1, M * G, Kq).transpose(1, 2).contiguous() + else: + out = out.reshape(B, H * G, -1, Kq).contiguous() + + return out + +select_attn = select_attention.apply \ No newline at end of file diff --git a/HybridTensor/triton/select_attn_v1.py b/HybridTensor/triton/select_attn_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..5eecbe357ae2e871d175ac60e81264c993a63e1a --- /dev/null +++ b/HybridTensor/triton/select_attn_v1.py @@ -0,0 +1,656 @@ +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple +import pytest +import torch +import sys + +import triton +import triton.language as tl + +def _strides(x: torch.Tensor, *stride_names: str): + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + +@triton.jit +def load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + PACKED_D_PER_GROUP: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block. In case of int4-quantized K/V, + # dequantize them after loading. If quantization is group-wise, + # use group_id to advance the pointers to the current group. + + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()) + + if PACKED_PER_VAL > 1: + # K/V are quantized, load quantization coefficients and dequantize + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + + k_scale_shift = tl.load( + K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else () + ) + v_scale_shift = tl.load( + V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else () + ) + + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + k_t = dequantize( + tl.trans(k), + tl.trans(k_scale), + tl.trans(k_shift), + PACKED_PER_VAL, + ).to(dtype) + k = tl.trans(k_t) + return k, v + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. + + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = ( + x_[:, None, :] >> offsets[None, :, None] + ) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Seq_len, + batch_head_index, # [B, top_k] + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qk, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kk, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vk, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + Z, + N_CTX_Q, + N_CTX_K, + BLOCK_N_PER_SPLIT, + H: tl.constexpr, + G: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_SEQ_LEN: tl.constexpr, + PACKED_PER_VAL: tl.constexpr = 1, + N_GROUPS: tl.constexpr = 1, + top_k: tl.constexpr = 1, +): + """This kernel can accept non-quantized or int4-quantized keys/values. + PACKED_PER_VAL determines the quantization type: + - PACKED_PER_VAL == 1 means no quantization + - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) + For the quantized case K/V should be int32 tensors. + Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. + Quantization coefficients are stored at the beginning of the row along the last dimension of K/V + So K[B, H, M, :] has a form + [ quant_coef0, quant_coef1, ...| + group0_quant_value0, group0_quant_value1,... | + group1_quant_value0, group1_quant_value1,...] + where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. + + """ + tl.static_assert( + (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) + or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)), + f"Only 4-bit quantization is supported, K/V should have dtype int32 in " + f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", + ) + tl.static_assert( + (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), + "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", + ) + + QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 + PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS + D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS + + # experimental + # Cast strides and indexing variables to tl.int64 to prevent integer overflows + + start_m = tl.program_id(0).to(tl.int32) + off_zhg = tl.program_id(1).to(tl.int64) + splitk_idx = tl.program_id(2).to(tl.int64) + + off_g = off_zhg % G # group, when G == 1, off_g = 0 + off_btk = off_zhg // G # when G == 1, off_btk = off_zhg + + # batch_idx = off_btk // top_k + idx_in_topk = off_btk % top_k + off_z = off_btk // top_k # batch index + + # Load off_h (head index) from batch_head_index + batch_head_index_ptr = batch_head_index + off_z * top_k + idx_in_topk + off_h = tl.load(batch_head_index_ptr, mask=True, other=0).to(tl.int64) + + # if we want to update kv cache inside the kernel, load the new key and value from the arguments (pointers, needs to be added) and update the cache in K and V by writing to seq_len position in the cache + # this will only update the kv cache for the selected heads, the cache will not be updated for all heads + # ... #TODO ? # + + # Compute lo and hi + lo = (splitk_idx * BLOCK_N_PER_SPLIT).to(tl.int32) + if USE_SEQ_LEN: + kv_len = tl.load(Seq_len + off_z).to(tl.int32) + else: + kv_len = N_CTX_K.to(tl.int32) + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len).to(tl.int32) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + + k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg + # Additional shift by 1 along the last dimension in the quantized case, since + # the first element along that dim contains packed quantization coefficients. + K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * QUANTIZED * N_GROUPS, + shape=(hi, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + + if QUANTIZED: + # Pointers to quantization coefficients + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(1, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, 1), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + else: + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0,)) + q = (q * qk_scale).to(q.dtype) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k, v = load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + PACKED_PER_VAL, + PACKED_D_PER_GROUP, + Q.dtype.element_ty, + 0, + ) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # noqa: F821 + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p, v) + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + if PACKED_PER_VAL > 1: + K_scale_shift_block_ptr = tl.advance( + K_scale_shift_block_ptr, (0, BLOCK_N) + ) + V_scale_shift_block_ptr = tl.advance( + V_scale_shift_block_ptr, (BLOCK_N, 0) + ) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, + boundary_check=(0,), + ) + # Write metadata for split-K reduction + Metadata_ptr = ( + Metadata + + off_zhg * stride_mzhg + + splitk_idx * stride_ms + + start_m * BLOCK_M + + tl.arange(0, BLOCK_M) + ) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + batch_head_index, # [B, top_k] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + M_ceil:tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k:tl.constexpr, + splitK_pow2:tl.constexpr, + use_mask:tl.constexpr, + top_k:tl.constexpr, +): + + off_zhg = tl.program_id(0).to(tl.int64) + off_m = tl.program_id(1).to(tl.int64) + off_k = tl.program_id(2).to(tl.int64) + + off_g = off_zhg % G + off_btk = off_zhg // G + + idx_in_topk = off_btk % top_k + off_z = off_btk // top_k + + # Load off_h (head index) from batch_head_index + batch_head_index_ptr = batch_head_index + off_z * top_k + idx_in_topk + off_h = tl.load(batch_head_index_ptr, mask=True, other=0).to(tl.int64) + + # Read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = ( + Metadata + + stride_mzhg * off_zhg + + spk_idx * stride_ms + + off_m * stride_mm + ) + + o_ptr = ( + Out_splitK + + off_zhg * stride_osk_zhg + + stride_osk_m * off_m + + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + + kidx[None, :] * stride_osk_k + ) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:,None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + acc_out = tl.sum(acc, axis=0) / g_sum + Out_ptr = ( + Out + + stride_oz * off_z + + stride_oh * off_h + + stride_og * off_g + + stride_om * off_m + + off_k * BLOCK_SIZE + + tl.arange(0, BLOCK_SIZE) + ) + tl.store(Out_ptr, acc_out) + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + +class select_attention(torch.autograd.Function): + + OPERATOR = _fwd_kernel_splitK + SUPPORTED_DEVICES = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + } + SUPPORTED_MAX_K = 128 + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "triton_splitKF" + + @staticmethod + def forward(cls, q, k, v, scale_float, batch_head_index, seq_len=None): + + cls.SPLIT_K: Optional[int] = None + cls.BLOCK_M = 16 + cls.BLOCK_N = 64 + + cls.NUM_GROUPS = 1 # Default quantization is row-wise + + # attn_bias = inp.attn_bias + # seq_len = None + + # Transpose in the case of MQA/GQA + mqa_swap_seqlen_head = False + if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0: + mqa_swap_seqlen_head = True + assert q.shape[1] == 1 + q = q.transpose(1, 3) + k = k[:, :, :, :1] + v = v[:, :, :, :1] + + if k.dtype == torch.int32: + # Quantized K/V + PACKED_PER_VAL = 8 + Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8 + else: + Lk = k.shape[-1] + PACKED_PER_VAL = 1 + + B, Mk_original, G, H, Kkv = k.shape + B, M, G, H, Kq = q.shape + assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" + # print(f"B = {B}, M = {M}, G = {G}, H = {H}, Kkv = {Kkv}, Kq = {Kq}") + + # Ensure batch_head_index is a torch tensor of dtype torch.int32 + # Adjust batch_head_index to the new shape and flatten it + batch_head_index = batch_head_index.to(q.device) + batch_head_index = batch_head_index.to(torch.int32) + # print(f"batch_head_index= {batch_head_index}") + # print(f"batch_head_index.shape = {batch_head_index.shape}") + B, top_k = batch_head_index.shape + # batch_head_index_flat = batch_head_index.flatten() # Shape (B * top_k,) # Is this CG safe? + N_selected = B * top_k + + if seq_len is not None and isinstance(seq_len, torch.Tensor): + seq_len = seq_len.to(q.device) + + # Determine if seq_len is an integer or tensor + if isinstance(seq_len, int): + Mk = seq_len # Fixed sequence length for all batches + use_seq_len = False + elif seq_len is not None: # seq_len is a tensor + Mk = Mk_original # Use original Mk from k.shape[1] + use_seq_len = True + else: # seq_len is None, use Mk_original for all batches + Mk = Mk_original + use_seq_len = False + + BLOCK_M = cls.BLOCK_M + BLOCK_N = cls.BLOCK_N + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = get_split_k(B, G, H, Mk) + + M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M + o_splitk = torch.zeros( + [B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device + ) + metadata = torch.empty( + [B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device + ) + lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32) + + # remove this if not needed + # o_splitk.zero_() + # metadata.zero_() + # lse.zero_() + + # grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k) + # grid = (triton.cdiv(M, BLOCK_M), N_selected * G, split_k) + grid = (triton.cdiv(M, BLOCK_M), N_selected, split_k) # TODO: fix this for GQA + + num_warps = 1 + split_size = (Mk + split_k - 1) // split_k + + # use_seq_len = seq_len is not None + + # print(f"B = {B}, G = {G}, H = {H}, split_k = {split_k}, M_ceil = {M_ceil}, Kq = {Kq}, num_of_wgs = {G * G * H * split_k}") + + _fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=scale_float, + Out_splitK=o_splitk, + Metadata=metadata, + Seq_len=seq_len, + batch_head_index=batch_head_index, + **_strides(q, "qz", "qm", "qg", "qh", "qk"), + **_strides(k, "kz", "kn", "kg", "kh", "kk"), + **_strides(v, "vz", "vn", "vg", "vh", "vk"), + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + Z=B, + H=H, + G=G, + N_CTX_Q=M, + N_CTX_K=Mk, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len, + USE_SEQ_LEN=use_seq_len, + num_warps=num_warps, + num_stages=1, + PACKED_PER_VAL=PACKED_PER_VAL, + N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1, + top_k=top_k, + ) + + if mqa_swap_seqlen_head: + out = torch.zeros( + (B, H, G, M, Kq), device=q.device, dtype=q.dtype + ).transpose(1, 3) + else: + out = torch.zeros((B, M, G, H, Kq), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if B * G * H * M >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert out.shape[-1] % k_block_num == 0 + k_block_size = out.shape[-1] // k_block_num + # grid = (B * G * H, M, k_block_num) + grid = (N_selected, M, k_block_num) # TODO: fix this for GQA + + _splitK_reduce[grid]( + o_splitk, + metadata, + out, + lse, + batch_head_index=batch_head_index, + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + M_ceil=M_ceil, + BLOCK_SIZE=k_block_size, + G=G, + H=H, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + use_mask=use_mask, + num_warps=4, + top_k=top_k, + ) + + lse = lse.reshape([B, G, H, M]) + if mqa_swap_seqlen_head: + # H/M dimensions have been swapped + out = out.transpose(1, 3) + lse = lse.transpose(2, 3) + if q.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if Mk == 0: + out.zero_() + if mqa_swap_seqlen_head: + out = out.reshape(B, -1, M * G, Kq).transpose(1, 2).contiguous() + else: + out = out.reshape(B, H * G, -1, Kq).contiguous() + + return out + +select_attn = select_attention.apply \ No newline at end of file diff --git a/HybridTensor/triton/select_gqa.py b/HybridTensor/triton/select_gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..56fd308bed7124631651345cde465848cd6ebecc --- /dev/null +++ b/HybridTensor/triton/select_gqa.py @@ -0,0 +1,795 @@ +import torch +import triton +import triton.language as tl +# from .utils import _strides, get_padded_headsize +from typing import Optional, Sequence, Tuple, Union +from HybridTensor.triton.triton_flashattn_decode import maybe_contiguous + +def get_padded_headsize(size): + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) + return padded_d_model + + +def _strides(x: torch.Tensor, *stride_names: str): + if x is None: + return {f"stride_{s}": 0 for i, s in enumerate(stride_names)} + + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + batch_group_index, # new, [B, top_k_groups] + K_new, + V_new, + Cache_seqlens, + Cache_batch_idx, + Alibi_slopes, + stride_qz, stride_qm, stride_qg, stride_qh, stride_qd, + stride_kz, stride_kn, stride_kg, stride_kh, stride_kd, + stride_vz, stride_vn, stride_vg, stride_vh, stride_vd, + stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d, + stride_mzhg, stride_m2, stride_ms, stride_mm, + # add strides for batch_group_index + stride_bgz, stride_bgk, # new + stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d, + stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d, + stride_az, stride_ah, + Z, + N_CTX_Q, + N_CTX_K, + N_CTX_NEW, + BLOCK_N_PER_SPLIT, + H_q: tl.constexpr, # Total Heads + H_kv: tl.constexpr, # Total Heads for K/V + G_q: tl.constexpr, # Number of groups + TOP_K_GROUPS: tl.constexpr, # new, Number of groups to be selected for each batch + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_CACHE_SEQLENs: tl.constexpr, + USE_CACHE_BATCH_IDX: tl.constexpr, + NEW_KV: tl.constexpr, + IS_GQA: tl.constexpr, + IS_CAUSAL: tl.constexpr, + USE_ALIBI: tl.constexpr, +): + # Padding + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + if PADDED_HEAD: + d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL + + start_m = tl.program_id(0).to(tl.int32) + off_zhg = tl.program_id(1).to(tl.int64) + splitk_idx = tl.program_id(2).to(tl.int64) + HEAD_RATIO: tl.constexpr = H_q // H_kv # num of actual heads per group + + # off_z = off_zhg // (H_q * G_q) + # off_h_q = (off_zhg // G_q) % H_q + + # off_z = off_zhg // (H_q * TOP_K_GROUPS) # index of current batch + # off_k = (off_zhg //H_q) % TOP_K_GROUPS # index of selected group + # off_h_q = off_zhg % H_q # index of current head within group (0 to H_q-1) + + off_z = off_zhg // (HEAD_RATIO * TOP_K_GROUPS) # index of current batch + off_k = (off_zhg //HEAD_RATIO) % TOP_K_GROUPS # index of selected group + off_g_q = off_zhg % G_q # G_q always 1, off_g_q = 0 + + # load the batch group index + batch_group_index_ptr = batch_group_index + off_z * stride_bgz + off_k * stride_bgk + group_idx = tl.load(batch_group_index_ptr, mask=True, other=0).to(tl.int64) + + # off_h_q = off_zhg % HEAD_RATIO # index of current head within group (0 to H_q-1) + off_h_q = group_idx * HEAD_RATIO + off_zhg % HEAD_RATIO # index of current head within group (0 to H_q-1) + + # logical_off_zhg = off_z * (G_q * HEAD_RATIO) + off_g_q * HEAD_RATIO + off_h_q + + + + # pick batch index + if USE_CACHE_BATCH_IDX: + cache_batch_idx = tl.load(Cache_batch_idx + off_z) + else: + cache_batch_idx = off_z + + # Load ALiBi slope if enabled + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(Alibi_slopes + a_offset) + else: + alibi_slope = None + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_CACHE_SEQLENs: + cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z) + if NEW_KV: + kv_len = cache_seqlen_last_idx + N_CTX_NEW + else: + kv_len = cache_seqlen_last_idx + else: + kv_len = N_CTX_K + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + + if IS_GQA: + k_head_idx = off_h_q // HEAD_RATIO + v_head_idx = k_head_idx + else: + k_head_idx = off_h_q + v_head_idx = off_h_q + + # calculate base offset + k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg + v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg + + # Copy new Keys and Values into Cache + if NEW_KV: + knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g + + # Determine the starting position for new data in the cache + if USE_CACHE_SEQLENs: + start_idx = tl.load(Cache_seqlens + off_z) + else: + start_idx = N_CTX_K - N_CTX_NEW + + # Copy new Keys + for i in range(0, N_CTX_NEW, BLOCK_N): + # Load from K_new + k_new_block = tl.load( + knew_base + + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d + + (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n, + mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), + other=0 + ) + + # Store to K + tl.store( + k_base + + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd + + (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn, + k_new_block, + mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), + ) + + # Copy new Values + vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g + for i in range(0, N_CTX_NEW, BLOCK_N): + # Load from V_new + v_new_block = tl.load( + vnew_base + + (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n + + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d, + mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), + other=0 + ) + + # Store to V + tl.store( + v_base + + (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn + + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd, + v_new_block, + mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), + ) + + # Q_block_ptr = tl.make_block_ptr( + # base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg, + # shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL), + # strides=(stride_qm, stride_qd), + # offsets=(start_m * BLOCK_M, 0), + # block_shape=(BLOCK_M, BLOCK_DMODEL), + # order=(1, 0), + # ) + + # K_block_ptr = tl.make_block_ptr( + # base=k_base, + # shape=(ACTUAL_BLOCK_DMODEL, hi), + # strides=(stride_kd, stride_kn), + # offsets=(0, lo), + # block_shape=(BLOCK_DMODEL, BLOCK_N), + # order=(0, 1), + # ) + # V_block_ptr = tl.make_block_ptr( + # base=v_base, + # shape=(hi, ACTUAL_BLOCK_DMODEL), + # strides=(stride_vn, stride_vd), + # offsets=(lo, 0), + # block_shape=(BLOCK_N, BLOCK_DMODEL), + # order=(1, 0), + # ) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h_q.to(tl.int64) * stride_qh + off_z.to(tl.int64) * stride_qz + off_g_q.to(tl.int64) * stride_qg, # base calc uses int64 + shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qd), + offsets=((start_m * BLOCK_M).to(tl.int32), 0), # <<< Cast offset to int32 + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + # Ensure shapes/bounds used in K/V pointers are int32 where appropriate + K_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(ACTUAL_BLOCK_DMODEL, hi.to(tl.int32)), # <<< Ensure hi is int32 for shape + strides=(stride_kd, stride_kn), + offsets=(0, lo.to(tl.int32)), # <<< Ensure lo is int32 for offset + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi.to(tl.int32), ACTUAL_BLOCK_DMODEL), # <<< Ensure hi is int32 for shape + strides=(stride_vn, stride_vd), + offsets=(lo.to(tl.int32), 0), # <<< Ensure lo is int32 for offset + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + + + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, )) + q = (q * qk_scale).to(q.dtype) + if PADDED_HEAD: + q = tl.where(d_mask[None, :], q, 0.0) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k, v = load_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + 1, + BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL, + Q.dtype.element_ty, + 0, + ) + if PADDED_HEAD: + k = tl.where(d_mask[:, None], k, 0.0) + v = tl.where(d_mask[None, :], v, 0.0) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # noqa: F821 + + if USE_ALIBI: + row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # Compute relative positions + relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :]) + relative_pos = tl.abs(relative_pos) + + # Compute ALiBi bias + alibi_bias = -1 * alibi_slope * relative_pos + qk += (alibi_bias * 1.44269504) + + # Apply causal mask if IS_CAUSAL is True + if IS_CAUSAL: + row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # create a N_CTX_Q x kv_len causal mask + col_offset = N_CTX_Q - kv_len + causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) + + # Apply the mask + qk = tl.where(causal_mask, qk, float("-inf")) + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + if IS_CAUSAL: + alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf"))) + else: + alpha = tl.math.exp2(m_i - m_i_new) + # cause of nan because subtracting infs + if IS_CAUSAL: + qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf")) + else: + qk = qk - m_i_new[:, None] + + p = tl.math.exp2(qk) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p.to(v.dtype), v) + + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + # write back O + O_block_ptr = tl.make_block_ptr( + # base=Out_splitK + logical_off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, BLOCK_DMODEL), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, + boundary_check=(0, ), + ) + # Write metadata for split-K reduction + # Metadata_ptr = (Metadata + logical_off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + tl.arange(0, BLOCK_M)) + Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + tl.arange(0, BLOCK_M)) + + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + + +@triton.jit +def load_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block + + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()) + + return k, v + + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. + + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + batch_group_index, # [B, top_k_groups] + stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_k, + stride_mzhg, stride_m2, stride_ms, stride_mm, + stride_bgz, stride_bgk, + stride_oz, stride_oh, stride_og, stride_om, stride_ok, + stride_lse_zhg, stride_lse_m, + M_ceil: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, # Number of query heads + G: tl.constexpr, # Number of query groups, fixed to 1 + H_kv: tl.constexpr, # Number of key/value heads + TOP_K_GROUPS: tl.constexpr, + split_k: tl.constexpr, + splitK_pow2: tl.constexpr, + use_mask: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): + off_zhg = tl.program_id(0).to(tl.int64) + off_m = tl.program_id(1).to(tl.int64) + off_k = tl.program_id(2).to(tl.int64) + HEAD_RATIO: tl.constexpr = H // H_kv # num of actual heads per group + + # off_z = off_zhg // (H * G) + # off_h = (off_zhg // G) % H + off_g = off_zhg % G + + off_z = off_zhg // (HEAD_RATIO * TOP_K_GROUPS) # index of current batch + off_bk = (off_zhg //HEAD_RATIO) % TOP_K_GROUPS + + # load the batch group index + batch_group_index_ptr = batch_group_index + off_z * stride_bgz + off_bk * stride_bgk + group_idx = tl.load(batch_group_index_ptr, mask=True, other=0).to(tl.int64) + + # off_h = off_zhg % H # index of current head within group (0 to H-1) + off_h = group_idx * HEAD_RATIO + off_zhg % HEAD_RATIO + + # logical_off_zhg = off_z * (G * H) + off_g * H + off_h + + # read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + # Metadata_ptr = (Metadata + stride_mzhg * logical_off_zhg + spk_idx * stride_ms + off_m * stride_mm) + Metadata_ptr = (Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm) + + # o_ptr = (Out_splitK + logical_off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE + + # stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k) + + o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + + if IS_CAUSAL: + l_m_offset = l_m - g_m + alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0) + else: + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + + if IS_CAUSAL: + # Avoid division by zero + g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0) + acc_out = tl.sum(acc, axis=0) / g_sum_safe + else: + acc_out = tl.sum(acc, axis=0) / g_sum + + # Store output + Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m + + off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) + tl.store(Out_ptr, acc_out) + + # Store lse + # l_ptrs = LSE + logical_off_zhg * stride_lse_zhg + off_m + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + + if IS_CAUSAL: + lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m) + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + # Scale and shift are such that quantization linearly maps + # int4 values range [0..15] to input values range min(k)..max(k) + # individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: torch.Tensor = (max_vals - min_vals) / 15 + + shift_k = torch.min(k, dim=-1, keepdim=True).values + scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1) + k_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + +def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + k_i16 = quant_k.view(torch.int16) + k_ui8 = k_i16.view(torch.uint8) + + ss_size = num_groups * 4 + scale_shift_ui8 = k_ui8[..., 0:ss_size] + scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale = scale_shift_ui8[..., 0:2].view(torch.float16) + shift = scale_shift_ui8[..., 2:4].view(torch.float16) + + kv_ui8 = k_ui8[..., ss_size:] + k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1) + k1_i4 = k_ui8 & 0xF + k2_i4 = (k_ui8 & 0xF0) >> 4 + k_shape = k1_i4.shape + k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + + out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out[..., ::2] = k1_f16 + out[..., 1::2] = k2_f16 + out = out.reshape(*k_shape[:-2], -1) + + return out + + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + +def select_gqa(q, k, v, + sm_scale, causal, + alibi_slopes, layout, + cache_seqlens, cache_batch_idx, + new_kv, k_new, v_new, + batch_group_index): + + # batch_group_index: Index used to select the groups to be activated, shape [batch_size, top_k_groups] + + # kernel config + BLOCK_M = 16 + BLOCK_N = 64 + SPLIT_K = None + NUM_QUANT_GROUPS = 1 + + # kernels expects "bsghd" + original_layout = layout + if layout == "bshd": + q=q.unsqueeze(2) + k=k.unsqueeze(2) + v=v.unsqueeze(2) + if new_kv: + k_new = k_new.unsqueeze(2) + v_new = v_new.unsqueeze(2) + layout = "bsghd" + elif layout == "bhsd": + q=q.permute(0, 2, 1, 3).unsqueeze(2) + k=k.permute(0, 2, 1, 3).unsqueeze(2) + v=v.permute(0, 2, 1, 3).unsqueeze(2) + if new_kv: + k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2) + v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2) + layout = "bsghd" + elif layout == "bsghd": + pass + elif layout is None: + raise ValueError("Layout not given") + assert layout == "bsghd" + + # get dims + batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape + _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape + _, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape + num_actual_groups = heads_per_group_k + heads_per_actual_group = heads_per_group_q // num_actual_groups + + top_k_groups = batch_group_index.shape[1] + + num_actual_groups = heads_per_group_k + heads_per_actual_group = heads_per_group_q // num_actual_groups + + assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}" + + # get padded size + dim_padded = get_padded_headsize(dim_k) + + # Handle MQA/GQA case + if heads_per_group_q > heads_per_group_k: + is_gqa = True + elif heads_per_group_q < heads_per_group_k: + raise ValueError("heads_per_group_q < heads_per_group_k") + else: + is_gqa = False + + assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}" + + if SPLIT_K is not None: + split_k = SPLIT_K + else: + # Use heuristics + split_k = get_split_k(batch_size, num_actual_groups, heads_per_actual_group, seqlen_k) # NOTE: should the split think about seqlens? + + seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M + out_splitk = torch.empty([batch_size * num_actual_groups * heads_per_actual_group, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device) + metadata = torch.empty([batch_size * num_actual_groups * heads_per_actual_group, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device) + + + lse = torch.empty((batch_size * num_actual_groups * heads_per_actual_group, seqlen_q), device=q.device, dtype=torch.float32) + + # grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_q * heads_per_group_q, split_k) + grid_fwd = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * top_k_groups * heads_per_actual_group, split_k) + # print(f"grid_fwd: {grid_fwd}") + + num_warps = 1 + split_size = (seqlen_k + split_k - 1) // split_k + use_cache_seqlens = cache_seqlens is not None + # just before you launch either kernel + assert batch_group_index.is_contiguous() + stride_bgz = batch_group_index.stride(0) # = top_k_groups + stride_bgk = 1 # NOT batch_group_index.stride(1)! + + # TODO: enable quantization + _fwd_kernel_splitK[grid_fwd]( + Q=q, + K=k, + V=v, + sm_scale=sm_scale, + Out_splitK=out_splitk, + Metadata=metadata, + batch_group_index=batch_group_index, + K_new = k_new, + V_new = v_new, + Cache_seqlens=cache_seqlens, + Cache_batch_idx=cache_batch_idx, + Alibi_slopes=alibi_slopes, + **_strides(q, "qz", "qm", "qg", "qh", "qd"), + **_strides(k, "kz", "kn", "kg", "kh", "kd"), + **_strides(v, "vz", "vn", "vg", "vh", "vd"), + **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + # add strides for batch_group_index + # **_strides(batch_group_index, "bgz", "bgk"), # new, bgz: stride for batch, bgk: stride for top_k_groups + stride_bgz=stride_bgz, stride_bgk=stride_bgk, + **_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"), + **_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"), + **_strides(alibi_slopes, "az", "ah"), + Z=batch_size, + H_q=heads_per_group_q, + H_kv=heads_per_group_k, + G_q=n_group_q, + TOP_K_GROUPS = int(top_k_groups), + N_CTX_Q=seqlen_q, + N_CTX_K=seqlen_k, + N_CTX_NEW=k_new.shape[1] if new_kv else None, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=dim_padded, + ACTUAL_BLOCK_DMODEL=dim_k, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, + USE_CACHE_SEQLENs=use_cache_seqlens, + USE_CACHE_BATCH_IDX=cache_batch_idx is not None, + NEW_KV=new_kv, + IS_GQA=is_gqa, + IS_CAUSAL=causal, + USE_ALIBI=False if alibi_slopes is None else True, + num_warps=num_warps, + num_stages=1, + ) + + # out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype) + out = torch.zeros((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if batch_size * num_actual_groups * heads_per_actual_group * seqlen_q >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert dim_padded % k_block_num == 0 + k_block_size = dim_padded // k_block_num + grid_reduce = (batch_size * top_k_groups * heads_per_actual_group, seqlen_q, k_block_num) + + _splitK_reduce[grid_reduce]( + out_splitk, + metadata, + out, + lse, + batch_group_index, + **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + # **_strides(batch_group_index, "bgz", "bgk"), # new, bgz: stride for batch, bgk: stride for top_k_groups + stride_bgz=stride_bgz, stride_bgk=stride_bgk, + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + M_ceil=seqlen_q_ceil, + BLOCK_SIZE=k_block_size, + G= n_group_q, + TOP_K_GROUPS = int(top_k_groups), + H=heads_per_group_q, + H_kv=heads_per_group_k, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + use_mask=use_mask, + IS_CAUSAL=causal, + num_warps=4) + + lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q]) + if q.ndim == 4: + # BMGHK -> BMHK + assert n_group_q == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if seqlen_k == 0: + out.zero_() + out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous() + + # output is batch_size, heads_per_group_q * group_q, seqlen_q, dim_q + if original_layout == "bshd": + # out=out.transpose(1, 2).contiguous() # this screws up heads and data. + # the data is laid out properly. Just need to reshape dims + out = out.reshape(batch_size, seqlen_q, -1, dim_padded) + + return out.narrow(-1, 0, dim_k), lse \ No newline at end of file diff --git a/HybridTensor/triton/triton_flashattn_decode.py b/HybridTensor/triton/triton_flashattn_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..0c46ebe7f4c1829f305c10394d47d629d7e3d0f8 --- /dev/null +++ b/HybridTensor/triton/triton_flashattn_decode.py @@ -0,0 +1,764 @@ +import torch +import triton +import triton.language as tl +from typing import Optional, Sequence, Tuple, Union + +# from flash_attn.utils import _strides, get_padded_headsize + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + +def get_padded_headsize(size): + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) + return padded_d_model + + +def _strides(x: torch.Tensor, *stride_names: str): + if x is None: + return {f"stride_{s}": 0 for i, s in enumerate(stride_names)} + + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + K_new, + V_new, + Cache_seqlens, + Cache_batch_idx, + Alibi_slopes, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qd, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kd, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vd, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_d, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_kn_z, + stride_kn_n, + stride_kn_g, + stride_kn_h, + stride_kn_d, + stride_vn_z, + stride_vn_n, + stride_vn_g, + stride_vn_h, + stride_vn_d, + stride_az, + stride_ah, + Z, + N_CTX_Q, + N_CTX_K, + N_CTX_NEW, + BLOCK_N_PER_SPLIT, + H_q: tl.constexpr, # number of heads in query + H_kv: tl.constexpr, # number of heads in key/value + G_q: tl.constexpr, # number of groups in query, this results in 1 + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_CACHE_SEQLENs: tl.constexpr, + USE_CACHE_BATCH_IDX: tl.constexpr, + NEW_KV: tl.constexpr, + IS_GQA: tl.constexpr, + IS_CAUSAL: tl.constexpr, + USE_ALIBI: tl.constexpr, +): + # Padding + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + if PADDED_HEAD: + d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL + + start_m = tl.program_id(0).to(tl.int32) + off_zhg = tl.program_id(1).to(tl.int64) + off_z = (off_zhg // (H_q * G_q)) + off_h_q = ((off_zhg // G_q) % H_q) + off_g_q = (off_zhg % G_q) # G_q always 1, so off_g_q = 0 + splitk_idx = tl.program_id(2).to(tl.int64) + + # pick batch index + if USE_CACHE_BATCH_IDX: + cache_batch_idx = tl.load(Cache_batch_idx + off_z).to(tl.int64) + else: + cache_batch_idx = off_z + + # Load ALiBi slope if enabled + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(Alibi_slopes + a_offset) + else: + alibi_slope = None + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_CACHE_SEQLENs: + cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z).to(tl.int32) + if NEW_KV: + kv_len = (cache_seqlen_last_idx + N_CTX_NEW).to(tl.int32) + else: + kv_len = cache_seqlen_last_idx + else: + kv_len = (N_CTX_K).to(tl.int32) + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) #.to(tl.int64) + + HEAD_RATIO: tl.constexpr = H_q // H_kv + if IS_GQA: + k_head_idx = off_h_q // HEAD_RATIO + v_head_idx = k_head_idx + else: + k_head_idx = off_h_q + v_head_idx = off_h_q + + # calculate base offset + k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg + v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg + + # Copy new Keys and Values into Cache + if NEW_KV: + knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g + + # Determine the starting position for new data in the cache + if USE_CACHE_SEQLENs: + start_idx = tl.load(Cache_seqlens + off_z) + else: + start_idx = N_CTX_K - N_CTX_NEW + + # Copy new Keys + for i in range(0, N_CTX_NEW, BLOCK_N): + # Load from K_new + k_new_block = tl.load( + knew_base + + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d + + (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n, + mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), + other=0 + ) + + # Store to K + tl.store( + k_base + + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd + + (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn, + k_new_block, + mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), + ) + + # Copy new Values + vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g + for i in range(0, N_CTX_NEW, BLOCK_N): + # Load from V_new + v_new_block = tl.load( + vnew_base + + (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n + + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d, + mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), + other=0 + ) + + # Store to V + tl.store( + v_base + + (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn + + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd, + v_new_block, + mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), + ) + + # Q_block_ptr = tl.make_block_ptr( + # base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg, + # shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL), + # strides=(stride_qm, stride_qd), + # offsets=(start_m * BLOCK_M, 0), + # block_shape=(BLOCK_M, BLOCK_DMODEL), + # order=(1, 0), + # ) + + # K_block_ptr = tl.make_block_ptr( + # base=k_base, + # shape=(ACTUAL_BLOCK_DMODEL, hi), + # strides=(stride_kd, stride_kn), + # offsets=(0, lo), + # block_shape=(BLOCK_DMODEL, BLOCK_N), + # order=(0, 1), + # ) + # V_block_ptr = tl.make_block_ptr( + # base=v_base, + # shape=(hi, ACTUAL_BLOCK_DMODEL), + # strides=(stride_vn, stride_vd), + # offsets=(lo, 0), + # block_shape=(BLOCK_N, BLOCK_DMODEL), + # order=(1, 0), + # ) + + # FIX: make_block_ptr offsets should be int32 + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h_q.to(tl.int64) * stride_qh + off_z.to(tl.int64) * stride_qz + off_g_q.to(tl.int64) * stride_qg, # base calc uses int64 + shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qd), + offsets=((start_m * BLOCK_M).to(tl.int32), 0), # <<< Cast offset to int32 + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + # Ensure shapes/bounds used in K/V pointers are int32 where appropriate + K_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(ACTUAL_BLOCK_DMODEL, hi.to(tl.int32)), # <<< Ensure hi is int32 for shape + strides=(stride_kd, stride_kn), + offsets=(0, lo.to(tl.int32)), # <<< Ensure lo is int32 for offset + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi.to(tl.int32), ACTUAL_BLOCK_DMODEL), # <<< Ensure hi is int32 for shape + strides=(stride_vn, stride_vd), + offsets=(lo.to(tl.int32), 0), # <<< Ensure lo is int32 for offset + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, )) + q = (q * qk_scale).to(q.dtype) + if PADDED_HEAD: + q = tl.where(d_mask[None, :], q, 0.0) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k, v = load_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N or (BLOCK_DMODEL != ACTUAL_BLOCK_DMODEL), # <-- here, + 1, + BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL, + Q.dtype.element_ty, + 0, + ) + if PADDED_HEAD: + k = tl.where(d_mask[:, None], k, 0.0) + v = tl.where(d_mask[None, :], v, 0.0) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # noqa: F821 + + if USE_ALIBI: + row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # Compute relative positions + relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :]) + relative_pos = tl.abs(relative_pos) + + # Compute ALiBi bias + alibi_bias = -1 * alibi_slope * relative_pos + qk += (alibi_bias * 1.44269504) + + # Apply causal mask if IS_CAUSAL is True + if IS_CAUSAL: + row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # create a N_CTX_Q x kv_len causal mask + col_offset = N_CTX_Q - kv_len + causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) + + # Apply the mask + qk = tl.where(causal_mask, qk, float("-inf")) + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + if IS_CAUSAL: + alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf"))) + else: + alpha = tl.math.exp2(m_i - m_i_new) + # cause of nan because subtracting infs + if IS_CAUSAL: + qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf")) + else: + qk = qk - m_i_new[:, None] + + p = tl.math.exp2(qk) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p.to(v.dtype), v) + + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, BLOCK_DMODEL), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, + boundary_check=(0, ), + ) + # Write metadata for split-K reduction + Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + + tl.arange(0, BLOCK_M)) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + + +@triton.jit +def load_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block + + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()) + + return k, v + + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. + + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + M_ceil: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k: tl.constexpr, + splitK_pow2: tl.constexpr, + use_mask: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): + off_zhg = tl.program_id(0).to(tl.int64) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1).to(tl.int64) + off_k = tl.program_id(2).to(tl.int64) + + # read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = (Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm) + + o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + + if IS_CAUSAL: + l_m_offset = l_m - g_m + alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0) + else: + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + + if IS_CAUSAL: + # Avoid division by zero + g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0) + acc_out = tl.sum(acc, axis=0) / g_sum_safe + else: + acc_out = tl.sum(acc, axis=0) / g_sum + + # Store output + Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m + + off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) + tl.store(Out_ptr, acc_out) + + # Store lse + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + if IS_CAUSAL: + lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m) + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + # Scale and shift are such that quantization linearly maps + # int4 values range [0..15] to input values range min(k)..max(k) + # individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: torch.Tensor = (max_vals - min_vals) / 15 + + shift_k = torch.min(k, dim=-1, keepdim=True).values + scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1) + k_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + + +def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + k_i16 = quant_k.view(torch.int16) + k_ui8 = k_i16.view(torch.uint8) + + ss_size = num_groups * 4 + scale_shift_ui8 = k_ui8[..., 0:ss_size] + scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale = scale_shift_ui8[..., 0:2].view(torch.float16) + shift = scale_shift_ui8[..., 2:4].view(torch.float16) + + kv_ui8 = k_ui8[..., ss_size:] + k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1) + k1_i4 = k_ui8 & 0xF + k2_i4 = (k_ui8 & 0xF0) >> 4 + k_shape = k1_i4.shape + k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + + out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out[..., ::2] = k1_f16 + out[..., 1::2] = k2_f16 + out = out.reshape(*k_shape[:-2], -1) + + return out + + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + +def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, + layout, cache_seqlens, cache_batch_idx, + new_kv, k_new, v_new): + # kernel config + BLOCK_M = 16 + BLOCK_N = 64 + SPLIT_K = None + NUM_QUANT_GROUPS = 1 + + # kernels expects "bsghd" + original_layout = layout + if layout == "bshd": + q=q.unsqueeze(2) + k=k.unsqueeze(2) + v=v.unsqueeze(2) + if new_kv: + k_new = k_new.unsqueeze(2) + v_new = v_new.unsqueeze(2) + layout = "bsghd" + elif layout == "bhsd": + q=q.permute(0, 2, 1, 3).unsqueeze(2) + k=k.permute(0, 2, 1, 3).unsqueeze(2) + v=v.permute(0, 2, 1, 3).unsqueeze(2) + if new_kv: + k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2) + v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2) + layout = "bsghd" + elif layout == "bsghd": + pass + elif layout is None: + raise ValueError("Layout not given") + assert layout == "bsghd" + + # get dims + batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape + _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape + _, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape + + num_actual_groups = heads_per_group_k + heads_per_actual_group = heads_per_group_q // num_actual_groups + + assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}" + + # get padded size + dim_padded = get_padded_headsize(dim_k) + + # Handle MQA/GQA case + if heads_per_group_q > heads_per_group_k: + is_gqa = True + elif heads_per_group_q < heads_per_group_k: + raise ValueError("heads_per_group_q < heads_per_group_k") + else: + is_gqa = False + + assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}" + + if SPLIT_K is not None: + split_k = SPLIT_K + else: + # Use heuristics + split_k = get_split_k(batch_size, num_actual_groups, heads_per_actual_group, seqlen_k) # NOTE: should the split think about seqlens? + + seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M + splitK_pow2 = triton.next_power_of_2(split_k) + + # out_splitk = torch.empty([batch_size * num_actual_groups * heads_per_actual_group, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device) + # metadata = torch.empty([batch_size * num_actual_groups * heads_per_actual_group, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device) + + out_splitk = torch.empty([batch_size * num_actual_groups * heads_per_actual_group, splitK_pow2, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device) + metadata = torch.empty([batch_size * num_actual_groups * heads_per_actual_group, 2, splitK_pow2, seqlen_q_ceil], dtype=torch.float32, device=q.device) + + lse = torch.empty((batch_size * num_actual_groups * heads_per_actual_group, seqlen_q), device=q.device, dtype=torch.float32) + grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * num_actual_groups * heads_per_actual_group, split_k) + # grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_k * heads_per_group_k, split_k) + + num_warps = 1 + split_size = (seqlen_k + split_k - 1) // split_k + use_cache_seqlens = cache_seqlens is not None + + # TODO: enable quantization + _fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=sm_scale, + Out_splitK=out_splitk, + Metadata=metadata, + K_new = k_new, + V_new = v_new, + Cache_seqlens=cache_seqlens, + Cache_batch_idx=cache_batch_idx, + Alibi_slopes=alibi_slopes, + **_strides(q, "qz", "qm", "qg", "qh", "qd"), + **_strides(k, "kz", "kn", "kg", "kh", "kd"), + **_strides(v, "vz", "vn", "vg", "vh", "vd"), + **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"), + **_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"), + **_strides(alibi_slopes, "az", "ah"), + Z=batch_size, + H_q=heads_per_group_q, + H_kv=heads_per_group_k, + G_q=n_group_q, + N_CTX_Q=seqlen_q, + N_CTX_K=seqlen_k, + N_CTX_NEW=k_new.shape[1] if new_kv else None, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=dim_padded, + ACTUAL_BLOCK_DMODEL=dim_k, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, + USE_CACHE_SEQLENs=use_cache_seqlens, + USE_CACHE_BATCH_IDX=cache_batch_idx is not None, + NEW_KV=new_kv, + IS_GQA=is_gqa, + IS_CAUSAL=causal, + USE_ALIBI=False if alibi_slopes is None else True, + num_warps=num_warps, + num_stages=1, + ) + + out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if batch_size * num_actual_groups * heads_per_actual_group * seqlen_q >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert dim_padded % k_block_num == 0 + k_block_size = dim_padded // k_block_num + grid = (batch_size * num_actual_groups * heads_per_actual_group, seqlen_q, k_block_num) + + _splitK_reduce[grid]( + out_splitk, + metadata, + out, + lse, + **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + M_ceil=seqlen_q_ceil, + BLOCK_SIZE=k_block_size, + G=n_group_q, + H=heads_per_group_q, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + use_mask=use_mask, + IS_CAUSAL=causal, + num_warps=4) + + lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q]) + if q.ndim == 4: + # BMGHK -> BMHK + assert n_group_q == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if seqlen_k == 0: + out.zero_() + out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous() + + # output is batch_size, heads_per_group_q * group_q, seqlen_q, dim_q + if original_layout == "bshd": + # out=out.transpose(1, 2).contiguous() # this screws up heads and data. + # the data is laid out properly. Just need to reshape dims + out = out.reshape(batch_size, seqlen_q, -1, dim_padded) + + return out.narrow(-1, 0, dim_k), lse \ No newline at end of file diff --git a/HybridTensor/triton/triton_utils.py b/HybridTensor/triton/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..498ccf0eb5ef302e5457946688ca75bdf5780ccb --- /dev/null +++ b/HybridTensor/triton/triton_utils.py @@ -0,0 +1,150 @@ +import torch + +import triton +import triton.language as tl + +def is_cuda(): + # return triton.runtime.driver.active.get_current_target().backend == "cuda" + return True + +def is_hip_mi200(): + # target = triton.runtime.driver.active.get_current_target() + # return target.backend == 'hip' and target.arch == 'gfx90a' + + return False + +def get_cuda_autotune_config(): + return [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), # works decent for smaller matrices + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + + # # Good config for fp8 inputs. + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + + # # New configurations optimized for smaller matrices + + # batch size 32 + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), # works the best for smaller matrices + + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + + # batch size 16 + + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), # works good for batch size 16 + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), # works good for batch size 16 + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), # works good for batch size 16 + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), # works good for batch size 16 + + ] + + +def get_hip_autotune_config(): + return [ + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=4, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, + num_warps=4, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, + num_warps=4, num_stages=0), + ] + + +def get_autotune_config(): + if is_cuda(): + return get_cuda_autotune_config() + else: + return get_hip_autotune_config() + +from functools import partial + +def benchmark_fwd(a, b, bias, function, index_vec = None, iterations = 1000, print_result = False, out = None): + + kernel_name = function.__name__ + if index_vec is not None: + function = partial(function, index_vec=index_vec, out = out) + if bias is not None: + function = partial(function, bias = bias) + # warm up + for _ in range(10): + output = function(a, b) + + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for _ in range(iterations): + output = function(a, b) + end.record() + torch.cuda.synchronize() + kernel_time = start.elapsed_time(end)/iterations + + if print_result: + print("Kernel", kernel_name, "Time: ", kernel_time, "ms") + + return output, kernel_time + + +def torch_sel_gemm_col(a, b, index_vec, bias = None, out = None): + # assumes relu activation + if bias is not None: + return torch.nn.functional.relu(torch.matmul(a, b[:, index_vec]) + bias[index_vec]) + else: + return torch.nn.functional.relu(torch.matmul(a, b[:, index_vec])) + +def torch_sel_gemm_row(a, b, index_vec, bias = None, out = None): + if bias is not None: + return torch.matmul(a, b[index_vec]) + bias # the entire bias vector is added to each row + else: + return torch.matmul(a, b[index_vec]) \ No newline at end of file diff --git a/HybridTensor/utils/__init__.py b/HybridTensor/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/HybridTensor/utils/__pycache__/__init__.cpython-39.pyc b/HybridTensor/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..028800f585e7e48a4e504c797ea394f30293d59b Binary files /dev/null and b/HybridTensor/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/HybridTensor/utils/__pycache__/activations.cpython-39.pyc b/HybridTensor/utils/__pycache__/activations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca6159a4a22d66d0cbd6fc11ecdc8d78c9a00e5f Binary files /dev/null and b/HybridTensor/utils/__pycache__/activations.cpython-39.pyc differ diff --git a/HybridTensor/utils/__pycache__/generation.cpython-39.pyc b/HybridTensor/utils/__pycache__/generation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d08d565373e0bb74aabb61a888411fd3752a208a Binary files /dev/null and b/HybridTensor/utils/__pycache__/generation.cpython-39.pyc differ diff --git a/HybridTensor/utils/__pycache__/profiling.cpython-39.pyc b/HybridTensor/utils/__pycache__/profiling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e775c8d9c0caeee80c0349ac84e2a337e167a4f Binary files /dev/null and b/HybridTensor/utils/__pycache__/profiling.cpython-39.pyc differ diff --git a/HybridTensor/utils/__pycache__/utils.cpython-39.pyc b/HybridTensor/utils/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad121b76c4b94a26441d42baf83c4f5147517448 Binary files /dev/null and b/HybridTensor/utils/__pycache__/utils.cpython-39.pyc differ diff --git a/HybridTensor/utils/activations.py b/HybridTensor/utils/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..f69dd4d428e05b1cf23324513c6afa0b04661492 --- /dev/null +++ b/HybridTensor/utils/activations.py @@ -0,0 +1,369 @@ +import os +import yaml + +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np + +MODELS = [ + "facebook/opt-125m", # 1 + "facebook/opt-350m", # 2 + "facebook/opt-1.3b", # 3 + "facebook/opt-2.7b", # 4 + "facebook/opt-6.7b", # 5 + "facebook/opt-13b", # 6 + "facebook/opt-30b", # 7 + "facebook/opt-66b", # 8 + "SparseLLM/ReluLLaMA-7B", # 9 + "SparseLLM/prosparse-llama-2-7b", # 10 + "meta-llama/Llama-2-7b-hf", # 11 + "meta-llama/Llama-2-13b-hf", # 12 + "-", + "meta-llama/Llama-3.1-8B", # 14 + "meta-llama/Llama-3.1-70B", # 15 + +] + + +CONFIGS = { + 'facebook/opt-175b':{ + 'num_layer': 95, + 'd':12288, + 'h': 96, + 'neurons': 49152 + }, + 'facebook/opt-66b':{ + 'num_layer': 64, + 'd':9216, + 'h': 72, + 'neurons': 36864, + 'layer_imp': "results/attn_importance/opt-66b_attn_importance.json" + }, + 'facebook/opt-30b':{ + 'num_layer': 48, + 'd':7168, + 'h': 56, + 'neurons': 28672, + 'layer_imp': "results/attn_importance/opt-30b_attn_importance.json" + }, + 'facebook/opt-13b':{ + 'num_layer': 40, + 'd':5120, + 'h': 40, + 'neurons': 20480, + 'layer_imp': "results/attn_importance/opt-13b_attn_importance.json" + }, + 'facebook/opt-6.7b':{ + 'num_layer': 32, + 'd':4096, + 'h': 32, + 'neurons': 16384, + 'layer_imp': "results/attn_importance/opt-6.7b_attn_importance.json" + }, + 'facebook/opt-2.7b':{ + 'num_layer': 32, + 'd':2560, + 'h': 32, + 'neurons': 10240 + }, + 'facebook/opt-1.3b':{ + 'num_layer': 24, + 'd':2048, + 'h': 32, + 'neurons': 8192 + }, + 'facebook/opt-350m':{ + 'num_layer': 24, + 'd':1024, + 'h': 16, + 'neurons': 4096 + }, + 'facebook/opt-125m':{ + 'num_layer': 12, + 'd':768, + 'h': 12, + 'neurons': 3072 + }, + 'SparseLLM/ReluLLaMA-7B':{ + 'num_layer': 32, + 'd': 4096, + 'h': 32, + 'neurons': 11008 + }, + 'SparseLLM/prosparse-llama-2-7b':{ + 'num_layer': 32, + 'd': 4096, + 'h': 32, + 'neurons': 11008 + }, + 'meta-llama/Llama-2-7b-hf':{ + 'num_layer': 32, + 'd': 4096, + 'h': 32, + 'neurons': 11008, + 'layer_imp': "results/attn_importance/Llama-2-7b-hf_attn_importance.json" + }, + 'meta-llama/Llama-2-13b-hf':{ + 'num_layer': 40, + 'd': 5120, + 'h': 40, + 'neurons': 13824, + 'layer_imp': "results/attn_importance/Llama-2-13b-hf_attn_importance.json" + }, + 'meta-llama/Llama-3.1-8B':{ + 'num_layer': 32, + 'd': 4096, + 'h': 32, + 'neurons': 14336, + 'layer_imp': "results/attn_importance/Llama-3.1-8B_attn_importance.json" + }, + 'meta-llama/Llama-3.1-70B':{ + 'num_layer': 80, + 'd': 8192, + 'h': 64, + 'neurons': 28672, + 'layer_imp': "results/attn_importance/Llama-3.1-70B_attn_importance.json" + } + +} + + +# This class is used to store the activation thresholds for each layer of the model. Used for evaluation purposes. +class ActivationThresholds: + def __init__(self, num_layers, attn_th=0.0, mlp_th = 0.0): + self.activation_threshold = {} + self.mlp_threshold = {} + for i in range(num_layers): + self.activation_threshold[i] = attn_th + self.mlp_threshold[i] = mlp_th + + def set_threshold(self, layer_idx, threshold): + self.activation_threshold[layer_idx] = threshold + + def get_threshold(self, layer_idx): + return self.activation_threshold[layer_idx] + + def save_thresholds(self, file_path): + with open(file_path, 'w') as file: + documents = yaml.dump(self.activation_threshold, file) + + # def load_thresholds(self, file_path): + # with open(file_path, 'r') as file: + # self.activation_threshold = yaml.load(file, Loader=yaml.FullLoader) + + def load_thresholds(self, sparsity_map): + """ + Load the activation thresholds from a given sparsity map. + + Parameters: + sparsity_map (dict): A dictionary containing layer indices and their corresponding activation thresholds. + """ + for layer_idx, threshold in sparsity_map.items(): + self.activation_threshold[layer_idx] = threshold + + @classmethod + def from_file(cls, file_path): + """Class method to create an instance from a YAML file.""" + with open(file_path, 'r') as file: + activation_threshold = yaml.load(file, Loader=yaml.FullLoader) + + # Create an instance and set its activation_threshold + instance = cls() + instance.activation_threshold = activation_threshold + return instance + + +def build_mlp_topk_lookup(data_path: str, batch_size: int, delta: int = 128) -> dict: + """ + Creates a lookup table from the CSV file 'mlp_act_batch_{batch_size}_stats.csv' in the given directory. + + The lookup maps each layer to a top-k value calculated as: + top_k = ceil((average_activation + std_activation + delta)) + + Parameters: + data_path (str): The path to the directory containing the CSV files. + batch_size (int): The batch size to use in the file name and for filtering. + delta (float): The additional value added to the activations before rounding. + + Returns: + dict: A mapping from layer id to computed top-k value. + + Raises: + FileNotFoundError: If the CSV file does not exist in the provided directory. + """ + file_name = f"mlp_act_batch_{batch_size}_stats.csv" + full_path = os.path.join(data_path, file_name) + + if not os.path.exists(full_path): + raise FileNotFoundError(f"File not found: {full_path}") + + # Read the CSV file into a DataFrame + df = pd.read_csv(full_path) + + def calc_top_k(avg, std, delta): + raw_value = avg + std + delta + # return int(np.ceil(raw_value /128) * 128) # Round up to the nearest multiple of 128 + return int(np.ceil(raw_value)) + + mlp_lookup = { + row["layer"]: calc_top_k(row["average_activation"], row["std_activation"], delta) + for _, row in df.iterrows() if row["batch_size"] == batch_size + } + + return mlp_lookup + + +def _update_hf_mlp_topk(model, mlp_lookup): + """ + Updates the top-k values in the model's MLP layers using the provided lookup table. + + Parameters: + model (HybridModel): The model to update. + mlp_lookup (dict): The lookup table mapping layer id to top-k value. + delta (int): The additional value added to the activations before rounding. + """ + for layer_idx, top_k in mlp_lookup.items(): + model.model.decoder.layers[layer_idx].mlp_act= int(top_k) + + +def identify_model_type(model_name): + """ + Identifies if the given model name is an OPT model or a Llama model. + + Args: + model_name (str): The name of the model. + + Returns: + str: "OPT" if the model is an OPT model, "Llama" if it's a Llama model, or "Unknown". + """ + if "opt" in model_name.lower(): + return "OPT" + elif "llama" in model_name.lower(): + return "Llama" + else: + return "Unknown" + + +OPT_MODELS = [ + "facebook/opt-125m", # 1 + "facebook/opt-350m", # 2 + "facebook/opt-1.3b", # 3 + "facebook/opt-2.7b", # 4 + "facebook/opt-6.7b", # 5 + "facebook/opt-13b", # 6 + "facebook/opt-30b", # 7 + "facebook/opt-66b" # 8 +] + + +OPT_CONFIGS = { + 'facebook/opt-175b':{ + 'num_layer': 95, + 'sp_config': None, + 'd':12288, + 'h': 96, + }, + 'facebook/opt-66b':{ + 'num_layer': 64, + 'd':9216, + 'h': 72, + }, + 'facebook/opt-30b':{ + 'num_layer': 48, + 'd':7168, + 'h': 56, + }, + 'facebook/opt-13b':{ + 'num_layer': 40, + 'd':5120, + 'h': 40, + }, + 'facebook/opt-6.7b':{ + 'num_layer': 32, + 'd':4096, + 'h': 32, + }, + 'facebook/opt-2.7b':{ + 'num_layer': 32, + 'd':2560, + 'h': 32, + }, + 'facebook/opt-1.3b':{ + 'num_layer': 24, + 'd':2048, + 'h': 32, + }, + 'facebook/opt-350m':{ + 'num_layer': 24, + 'd':1024, + 'h': 16, + }, + 'facebook/opt-125m':{ + 'num_layer': 12, + 'd':768, + 'h': 12, + }, +} + +''' +import seaborn as sns +def plot_average_activation(directory_path, model_name): + + + # Read all the .csv files in the specified directory and concatenate them into a single DataFrame + files = [os.path.join(directory_path, f) for f in os.listdir(directory_path) if f.endswith('.csv')] + df = pd.concat([pd.read_csv(f) for f in files]) + + # Set the seaborn style for better aesthetics + sns.set(style="whitegrid") + + total_neurons = OPT_CONFIGS[model_name]["d"] * 4 + # Compute the average activation percentage and standard deviation percentage + df['average_activation_percentage'] = (df['average_activation'] / total_neurons) * 100 + df['std_activation_percentage'] = (df['std_activation'] / total_neurons) * 100 + + # Create a color palette with a different color for each batch size + palette = sns.color_palette("husl", df['batch_size'].nunique()) + + # Initialize the matplotlib figure + plt.figure(figsize=(12, 8)) + + # Loop over each batch size and plot the average_activation_percentage with error bars + for i, (batch_size, group) in enumerate(df.groupby('batch_size')): + plt.errorbar( + group['layer'], + group['average_activation_percentage'], + yerr=group['std_activation_percentage'], + label=f'Batch Size {batch_size}', + capsize=3, + marker='o', + linestyle='-', + color=palette[i] + ) + + # Set y-axis ticks at every 10% increment + plt.yticks(range(0, 101, 10)) # Y-ticks from 0% to 100% in steps of 10% + + # Shaded region + plt.axhspan(80, plt.ylim()[1], facecolor='gray', alpha=0.1) + + # Set the labels and title of the plot + plt.xlabel('Layer', fontsize=14) + plt.ylabel('Average Activation Percentage (%)', fontsize=14) + plt.title(f'Model: {model_name} Average Activation Percentage vs Layer for Different Batch Sizes', fontsize=16) + + # Show legend + plt.legend(title='Batch Size', fontsize=12, title_fontsize=12) + + # Tight layout for better spacing + plt.tight_layout() + + # Save the image + plt.savefig('average_activation_analysis.png') + + # Display the plot + plt.show() + + + +''' diff --git a/HybridTensor/utils/configs/facebook/opt-1.3b_attn_config.yaml b/HybridTensor/utils/configs/facebook/opt-1.3b_attn_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9728d021b4ac04cd2e8b380565796c382b1e28ed --- /dev/null +++ b/HybridTensor/utils/configs/facebook/opt-1.3b_attn_config.yaml @@ -0,0 +1,24 @@ +0: 0 +1: 0 +2: 0 +3: 0 +4: 0 +5: 0 +6: 0 +7: 0 +8: 0 +9: 0 +10: 0 +11: 0 +12: 0 +13: 0 +14: 0 +15: 0 +16: 0 +17: 0 +18: 0 +19: 0 +20: 0 +21: 0 +22: 0 +23: 0 diff --git a/HybridTensor/utils/configs/facebook/opt-125m_attn_config.yaml b/HybridTensor/utils/configs/facebook/opt-125m_attn_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..887752123e1ce0d7c6a519b8a8e32d69e9bf7909 --- /dev/null +++ b/HybridTensor/utils/configs/facebook/opt-125m_attn_config.yaml @@ -0,0 +1,12 @@ +0: 0 +1: 0 +2: 0 +3: 0 +4: 0 +5: 0 +6: 0 +7: 0 +8: 0 +9: 0 +10: 0 +11: 0 diff --git a/HybridTensor/utils/configs/facebook/opt-13b_attn_config.yaml b/HybridTensor/utils/configs/facebook/opt-13b_attn_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6609dae8e9ae55f909bcc2b483c107d176c8d073 --- /dev/null +++ b/HybridTensor/utils/configs/facebook/opt-13b_attn_config.yaml @@ -0,0 +1,40 @@ +0: 0 +1: 0 +2: 0 +3: 0 +4: 0 +5: 0 +6: 0 +7: 0 +8: 0 +9: 0 +10: 0 +11: 0 +12: 0 +13: 0 +14: 0 +15: 0 +16: 0 +17: 0 +18: 0 +19: 0 +20: 0 +21: 0 +22: 0 +23: 0 +24: 0 +25: 0 +26: 0 +27: 0 +28: 0 +29: 0 +30: 0 +31: 0 +32: 0 +33: 0 +34: 0 +35: 0 +36: 0 +37: 0 +38: 0 +39: 0 diff --git a/HybridTensor/utils/configs/facebook/opt-175b_attn_config.yaml b/HybridTensor/utils/configs/facebook/opt-175b_attn_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..00be6af5b34837e8a65c596cd36302f19dc5edd7 --- /dev/null +++ b/HybridTensor/utils/configs/facebook/opt-175b_attn_config.yaml @@ -0,0 +1,95 @@ +0: 0 +1: 0 +2: 0 +3: 0 +4: 0 +5: 0 +6: 0 +7: 0 +8: 0 +9: 0 +10: 0 +11: 0 +12: 0 +13: 0 +14: 0 +15: 0 +16: 0 +17: 0 +18: 0 +19: 0 +20: 0 +21: 0 +22: 0 +23: 0 +24: 0 +25: 0 +26: 0 +27: 0 +28: 0 +29: 0 +30: 0 +31: 0 +32: 0 +33: 0 +34: 0 +35: 0 +36: 0 +37: 0 +38: 0 +39: 0 +40: 0 +41: 0 +42: 0 +43: 0 +44: 0 +45: 0 +46: 0 +47: 0 +48: 0 +49: 0 +50: 0 +51: 0 +52: 0 +53: 0 +54: 0 +55: 0 +56: 0 +57: 0 +58: 0 +59: 0 +60: 0 +61: 0 +62: 0 +63: 0 +64: 0 +65: 0 +66: 0 +67: 0 +68: 0 +69: 0 +70: 0 +71: 0 +72: 0 +73: 0 +74: 0 +75: 0 +76: 0 +77: 0 +78: 0 +79: 0 +80: 0 +81: 0 +82: 0 +83: 0 +84: 0 +85: 0 +86: 0 +87: 0 +88: 0 +89: 0 +90: 0 +91: 0 +92: 0 +93: 0 +94: 0 diff --git a/HybridTensor/utils/configs/facebook/opt-2.7b_attn_config.yaml b/HybridTensor/utils/configs/facebook/opt-2.7b_attn_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5389c672b84d0dcb310b4c30d11957dbc40077a9 --- /dev/null +++ b/HybridTensor/utils/configs/facebook/opt-2.7b_attn_config.yaml @@ -0,0 +1,32 @@ +0: 0 +1: 0 +2: 0 +3: 0 +4: 0 +5: 0 +6: 0 +7: 0 +8: 0 +9: 0 +10: 0 +11: 0 +12: 0 +13: 0 +14: 0 +15: 0 +16: 0 +17: 0 +18: 0 +19: 0 +20: 0 +21: 0 +22: 0 +23: 0 +24: 0 +25: 0 +26: 0 +27: 0 +28: 0 +29: 0 +30: 0 +31: 0 diff --git a/HybridTensor/utils/configs/facebook/opt-30b_attn_config.yaml b/HybridTensor/utils/configs/facebook/opt-30b_attn_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89e3c3be31787fa4842b0046ae1bf510cada5034 --- /dev/null +++ b/HybridTensor/utils/configs/facebook/opt-30b_attn_config.yaml @@ -0,0 +1,48 @@ +0: 0 +1: 0 +2: 0 +3: 0 +4: 0 +5: 0 +6: 0 +7: 0 +8: 0 +9: 0 +10: 0 +11: 0 +12: 0 +13: 0 +14: 0 +15: 0 +16: 0 +17: 0 +18: 0 +19: 0 +20: 0 +21: 0 +22: 0 +23: 0 +24: 0 +25: 0 +26: 0 +27: 0 +28: 0 +29: 0 +30: 0 +31: 0 +32: 0 +33: 0 +34: 0 +35: 0 +36: 0 +37: 0 +38: 0 +39: 0 +40: 0 +41: 0 +42: 0 +43: 0 +44: 0 +45: 0 +46: 0 +47: 0 diff --git a/HybridTensor/utils/configs/facebook/opt-350m_attn_config.yaml b/HybridTensor/utils/configs/facebook/opt-350m_attn_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9728d021b4ac04cd2e8b380565796c382b1e28ed --- /dev/null +++ b/HybridTensor/utils/configs/facebook/opt-350m_attn_config.yaml @@ -0,0 +1,24 @@ +0: 0 +1: 0 +2: 0 +3: 0 +4: 0 +5: 0 +6: 0 +7: 0 +8: 0 +9: 0 +10: 0 +11: 0 +12: 0 +13: 0 +14: 0 +15: 0 +16: 0 +17: 0 +18: 0 +19: 0 +20: 0 +21: 0 +22: 0 +23: 0 diff --git a/HybridTensor/utils/configs/facebook/opt-6.7b_attn_config.yaml b/HybridTensor/utils/configs/facebook/opt-6.7b_attn_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3445958965867448fe427c2f6375e929e2fe3563 --- /dev/null +++ b/HybridTensor/utils/configs/facebook/opt-6.7b_attn_config.yaml @@ -0,0 +1,32 @@ +0: 0.8 +1: 0.8 +2: 0.8 +3: 0.8 +4: 0.8 +5: 0.8 +6: 0.8 +7: 0.8 +8: 0.8 +9: 0.8 +10: 0.8 +11: 0.8 +12: 0.8 +13: 0.8 +14: 0.8 +15: 0.8 +16: 0.8 +17: 0.8 +18: 0.8 +19: 0.8 +20: 0.8 +21: 0.8 +22: 0.8 +23: 0.8 +24: 0.8 +25: 0.8 +26: 0.8 +27: 0.8 +28: 0.8 +29: 0.8 +30: 0.8 +31: 0.8 diff --git a/HybridTensor/utils/configs/facebook/opt-66b_attn_config.yaml b/HybridTensor/utils/configs/facebook/opt-66b_attn_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e81b0def34ab69a99c0e97f33b47b16c08a91c6b --- /dev/null +++ b/HybridTensor/utils/configs/facebook/opt-66b_attn_config.yaml @@ -0,0 +1,64 @@ +0: 0 +1: 0 +2: 0 +3: 0 +4: 0 +5: 0 +6: 0 +7: 0 +8: 0 +9: 0 +10: 0 +11: 0 +12: 0 +13: 0 +14: 0 +15: 0 +16: 0 +17: 0 +18: 0 +19: 0 +20: 0 +21: 0 +22: 0 +23: 0 +24: 0 +25: 0 +26: 0 +27: 0 +28: 0 +29: 0 +30: 0 +31: 0 +32: 0 +33: 0 +34: 0 +35: 0 +36: 0 +37: 0 +38: 0 +39: 0 +40: 0 +41: 0 +42: 0 +43: 0 +44: 0 +45: 0 +46: 0 +47: 0 +48: 0 +49: 0 +50: 0 +51: 0 +52: 0 +53: 0 +54: 0 +55: 0 +56: 0 +57: 0 +58: 0 +59: 0 +60: 0 +61: 0 +62: 0 +63: 0 diff --git a/HybridTensor/utils/generation.py b/HybridTensor/utils/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..c609887fa689999710e5de84c4427e243bce485d --- /dev/null +++ b/HybridTensor/utils/generation.py @@ -0,0 +1,950 @@ +# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31 + +import gc +import time +from collections import namedtuple +from dataclasses import dataclass, field +from functools import partial +from typing import Callable, Optional, Sequence, Union + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import Tensor +from torch.profiler import ProfilerActivity, profile, record_function +from HybridTensor.utils.profiling import cuda_profiler + +# try: +# from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput +# except ImportError: +# GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", ["sequences", "scores"]) +# SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"]) + +GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", ["sequences", "scores", "prefill_time", "decoding_time"]) +SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores", "prefill_time", "decoding_time"]) + +@dataclass +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + max_seqlen: int + max_batch_size: int + seqlen_offset: int = 0 + batch_size_offset: int = 0 + key_value_memory_dict: dict = field(default_factory=dict) + lengths_per_sample: Optional[Tensor] = None + + def reset(self, max_seqlen, max_batch_size): + self.max_seqlen = max_seqlen + self.max_batch_size = max_batch_size + self.seqlen_offset = 0 + if self.lengths_per_sample is not None: + self.lengths_per_sample.zero_() + + +# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py +# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 +def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf. Done in-place.""" + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(indices_to_remove, float("-Inf")) + + +# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py +# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 +def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf. Done in-place.""" + if top_p <= 0.0 or top_p >= 1.0: + return + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=False) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs <= (1 - top_p) + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits.masked_fill_(indices_to_remove, float("-inf")) + + +def sample(logits, top_k=1, top_p=0.0, temperature=1.0): + """Sample from top-k logits. + Arguments: + logits: Tensor of shape (batch_size, vocab_size) + """ + if top_k == 1: # Short-circuit for greedy decoding + return logits.argmax(dim=-1) + else: + if top_p > 0.0: + assert top_p <= 1.0, "top-p should be in (0, 1]." + if top_k > 0: + top_k = min(top_k, logits.size(-1)) # Safety check + logits_top, indices = torch.topk(logits, top_k, dim=-1) + if temperature != 1.0: + logits_top /= temperature + modify_logits_for_top_p_filtering(logits_top, top_p) + return indices[ + torch.arange(indices.shape[0], device=indices.device), + torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1), + ] + else: + # Clone so that when we modify for top_p we don't change the original logits + logits_top = logits / temperature if temperature != 1.0 else logits.clone() + modify_logits_for_top_p_filtering(logits_top, top_p) + return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze( + dim=-1 + ) + + +@torch.inference_mode() +def decode( + input_ids, + model, + max_length, + top_k=1, + top_p=0.0, + temperature=1.0, + eos_token_id=None, + teacher_outputs=None, + vocab_size=None, + tensor_parallel=1, + cg=False, + enable_timing=False, + inference_params=None, +): + """Decoding, either greedy or with top-k or top-p sampling. + If top-k = 0, don't limit the number of candidates (pure sampling). + Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, + then top-p. + We assume that all sequences in the same batch have the same length. + + Arguments: + input_ids: (batch, seq_len) + max_length: int + teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the + logits, the next token is taken from the teacher_outputs. Useful for testing. + Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: + sequences: (batch, max_length) + scores: tuples of (batch, vocab_size) + """ + batch_size, seqlen_og = input_ids.shape + teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 + + if cg: + if not hasattr(model, "_decoding_cache"): + model._decoding_cache = None + model._decoding_cache = update_graph_cache( + model, + model._decoding_cache, + batch_size, + seqlen_og, + max_length, + tensor_parallel=tensor_parallel, + ) + inference_params = model._decoding_cache.inference_params + inference_params.reset(max_length, batch_size) + else: + inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + + def get_logits(input_ids, inference_params): + decoding = inference_params.seqlen_offset > 0 + if decoding: + position_ids = torch.full( + (batch_size, 1), + inference_params.seqlen_offset, + dtype=torch.long, + device=input_ids.device, + ) + else: + position_ids = None + if not cg or not decoding: + # prefill stage + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=1, + ).logits.squeeze(dim=1) + else: + logits = model._decoding_cache.run( + input_ids, position_ids, inference_params.seqlen_offset + ).squeeze(dim=1) + return logits[..., :vocab_size] if vocab_size is not None else logits + + def sample_tokens(logits, inference_params): + if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: + token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) + else: + token = teacher_outputs[:, inference_params.seqlen_offset] + # return rearrange(token, "b -> b 1") + return token.unsqueeze(1) + + def should_stop(current_token, inference_params): + if inference_params.seqlen_offset == 0: + return False + if eos_token_id is not None and (current_token == eos_token_id).all(): + return True + if inference_params.seqlen_offset >= max_length - 1: + return True + return False + + start_prefill = torch.cuda.Event(enable_timing=enable_timing) + end_prefill = torch.cuda.Event(enable_timing=enable_timing) + start_decode = torch.cuda.Event(enable_timing=enable_timing) + end_decode = torch.cuda.Event(enable_timing=enable_timing) + scores, sequences = [], [input_ids] + + # prefill stage + if enable_timing: + if tensor_parallel > 1: + torch.distributed.barrier() + start_prefill.record() + + scores.append(get_logits(sequences[-1], inference_params)) + inference_params.seqlen_offset += sequences[-1].shape[1] + sequences.append(sample_tokens(scores[-1], inference_params)) + + if enable_timing: + end_prefill.record() + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + + start_decode.record() + + # decode stage + while not should_stop(sequences[-1], inference_params): + scores.append(get_logits(sequences[-1], inference_params)) + inference_params.seqlen_offset += sequences[-1].shape[1] + sequences.append(sample_tokens(scores[-1], inference_params)) + + output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput + + if enable_timing: + end_decode.record() + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + prefill_time = start_prefill.elapsed_time(end_prefill) + decoding_time = start_decode.elapsed_time(end_decode) + # print(f"Prefill time: {prefill_time:.0f} ms") + # print(f"Decoding time: {decoding_time:.0f} ms") + + return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores), + prefill_time=prefill_time, decoding_time=decoding_time) + + return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores), + prefill_time=0, decoding_time=0) + + +@torch.inference_mode() +def decode_only( + input_ids, + model, + max_length, + top_k=1, + top_p=0.0, + temperature=1.0, + eos_token_id=None, + teacher_outputs=None, + vocab_size=None, + tensor_parallel=1, + cg=False, + enable_timing=False, + inference_params=None, +): + """Decoding, either greedy or with top-k or top-p sampling. + If top-k = 0, don't limit the number of candidates (pure sampling). + Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, + then top-p. + We assume that all sequences in the same batch have the same length. + The prefill stage is simulated using random values for efficient decode testing. + + Arguments: + input_ids: (batch, seq_len) + max_length: int + teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the + logits, the next token is taken from the teacher_outputs. Useful for testing. + Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: + sequences: (batch, max_length) + scores: tuples of (batch, vocab_size) + """ + batch_size, seqlen_og = input_ids.shape + teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 + + if cg: + if not hasattr(model, "_decoding_cache"): + model._decoding_cache = None + model._decoding_cache = update_graph_cache( + model, + model._decoding_cache, + batch_size, + seqlen_og, + max_length, + tensor_parallel=tensor_parallel, + ) + inference_params = model._decoding_cache.inference_params + inference_params.reset(max_length, batch_size) + else: + inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + + def get_logits(input_ids, inference_params): + decoding = inference_params.seqlen_offset > 0 + if decoding: + position_ids = torch.full( + (batch_size, 1), + inference_params.seqlen_offset, + dtype=torch.long, + device=input_ids.device, + ) + else: + position_ids = None + if not cg or not decoding: + # prefill stage + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=1, + ).logits.squeeze(dim=1) + else: + logits = model._decoding_cache.run( + input_ids, position_ids, inference_params.seqlen_offset + ).squeeze(dim=1) + return logits[..., :vocab_size] if vocab_size is not None else logits + + def sample_tokens(logits, inference_params): + if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: + token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) + else: + token = teacher_outputs[:, inference_params.seqlen_offset] + # return rearrange(token, "b -> b 1") + return token.unsqueeze(1) + + def should_stop(current_token, inference_params): + if inference_params.seqlen_offset == 0: + return False + if eos_token_id is not None and (current_token == eos_token_id).all(): + return True + if inference_params.seqlen_offset >= max_length - 1: + return True + return False + + def decode_loop(sequences, inference_params, og_seqlen): + inference_params.seqlen_offset = og_seqlen + while not should_stop(sequences[-1], inference_params): + scores.append(get_logits(sequences[-1], inference_params)) + inference_params.seqlen_offset += sequences[-1].shape[1] + sequences.append(sample_tokens(scores[-1], inference_params)) + + start_prefill = torch.cuda.Event(enable_timing=enable_timing) + end_prefill = torch.cuda.Event(enable_timing=enable_timing) + start_decode = torch.cuda.Event(enable_timing=enable_timing) + end_decode = torch.cuda.Event(enable_timing=enable_timing) + scores, sequences = [], [input_ids] + + # prefill stage + if enable_timing: + if tensor_parallel > 1: + torch.distributed.barrier() + start_prefill.record() + + # scores.append(get_logits(sequences[-1], inference_params)) + + # prefill kv cache with random values to simulate the actual inference + # TODO: this speeds up simulation as we only care about decode stage + max_batch_size, max_seq_len, _, num_heads_kv, head_dim = inference_params.key_value_memory_dict[0].shape + kv = torch.rand(batch_size, seqlen_og, 2, num_heads_kv, head_dim, device=input_ids.device, dtype=torch.float16) + kv_cache = model.transformer.layers.mixer._update_kv_cache(kv, inference_params) + inference_params.seqlen_offset += sequences[-1].shape[1] + og_seqlen = inference_params.seqlen_offset + + # sequences.append(sample_tokens(scores[-1], inference_params)) + tokens = torch.ones((batch_size, 1), device=input_ids.device, dtype=torch.long) + sequences.append(tokens) + + + # use cuda profiler to measure decode time + if enable_timing: + end_prefill.record() + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + _, decode_time = cuda_profiler(get_logits, sequences[-1], inference_params, warmup_runs=2, timed_runs=10) + # _, decode_time = cuda_profiler(decode_loop, sequences, inference_params, og_seqlen, warmup_runs=10, timed_runs=10) + + # if enable_timing: + # end_prefill.record() + # if tensor_parallel > 1: + # torch.distributed.barrier() + # torch.cuda.synchronize() + + # start_decode.record() + + # decode stage + # while not should_stop(sequences[-1], inference_params): + # scores.append(get_logits(sequences[-1], inference_params)) + # inference_params.seqlen_offset += sequences[-1].shape[1] + # sequences.append(sample_tokens(scores[-1], inference_params)) + + output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput + + if enable_timing: + end_decode.record() + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + prefill_time = start_prefill.elapsed_time(end_prefill) + # decoding_time = start_decode.elapsed_time(end_decode) + decoding_time = decode_time + + # print(f"Prefill time: {prefill_time:.0f} ms") + # print(f"Decoding time: {decoding_time:.0f} ms") + + return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores), + prefill_time=prefill_time, decoding_time=decoding_time) + + return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores), + prefill_time=0, decoding_time=0) + + +def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0): + """Algorithm 1 from [1] + [1] Fast Inference from Transformers via Speculative Decoding + Yaniv Leviathan, Matan Kalman, Yossi Matias + https://arxiv.org/abs/2211.17192 + + Arguments: + logits: Tensor of shape (batch_size, seqlen + 1, vocab_size) + logits_draft: Tensor of shape (batch_size, seqlen, vocab_size) + tokens_draft: Tensor of shape (batch_size, seqlen) + Return: + tokens: Tensor of shape (batch_size, seqlen + 1) + num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1]. + For each sequence in the batch, the number of valid tokens that were sampled by + speculative sampling. + """ + batch, seqlen_p_1, vocab_size = logits.shape + seqlen = seqlen_p_1 - 1 + assert logits_draft.shape == (batch, seqlen, vocab_size) + assert tokens_draft.shape == (batch, seqlen) + assert tokens_draft.dtype in [torch.int64, torch.int32] + # TODO: if top_k = 1 we can simplify things and only work with indices + if top_p > 0.0: + assert top_p <= 1.0, "top-p should be in (0, 1]." + # Clone so that when we modify for top_p we don't change the original logits + logits = logits / temperature if temperature != 1.0 else logits.clone() + logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone() + if top_k > 0: + top_k = min(top_k, logits.size(-1)) # Safety check + modify_logits_for_top_k_filtering(logits, top_k) + modify_logits_for_top_k_filtering(logits_draft, top_k) + modify_logits_for_top_p_filtering(logits, top_p) + modify_logits_for_top_p_filtering(logits_draft, top_p) + probs = torch.softmax(logits, dim=-1) + probs_draft = torch.softmax(logits_draft, dim=-1) + gather = lambda probs, tokens: rearrange( + probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..." + ) + # (batch, seqlen) + accepted = torch.rand(batch, seqlen, device=probs.device) * gather( + probs_draft, tokens_draft + ) <= gather(probs[:, :-1], tokens_draft) + accepted_all = accepted.all(dim=-1) + # (batch,) + first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1)) + probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0) + # torch.multinomial can deal with unnormalized probabilities + # probs_diff /= probs_diff.sum(dim=-1, keepdim=True) + resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1) + resample_probs = rearrange( + resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)), + "b 1 d -> b d", + ) + resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1) # (batch,) + tokens = F.pad(tokens_draft, (0, 1)) + tokens[:, first_rejected_idx] = resample + return tokens, first_rejected_idx + 1 + + +@torch.inference_mode() +def decode_speculative( + input_ids, + model, + model_draft, + max_length, + speculative_lookahead=3, + top_k=1, + top_p=0.0, + temperature=1.0, + eos_token_id=None, + vocab_size=None, + tensor_parallel=1, + cg=False, + enable_timing=False, + debug=False, +): + """ + TD: WIP, for my own understanding, lightly tested. Only support batch_size == 1 for now. + + Speculative decoding, either greedy or with top-k or top-p sampling. + If top-k = 0, don't limit the number of candidates (pure sampling). + Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, + then top-p. + We assume that all sequences in the same batch have the same length. + + Arguments: + input_ids: (batch, seq_len) + max_length: int + Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: + sequences: (batch, max_length) + scores: tuples of (batch, vocab_size) + """ + batch_size, seqlen_og = input_ids.shape + assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1" + assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id" + if cg: + if not hasattr(model_draft, "_decoding_cache"): + model_draft._decoding_cache = None + model_draft._decoding_cache = update_graph_cache( + model_draft, + model_draft._decoding_cache, + batch_size, + seqlen_og, + max_length, + # draft model needs to process either 1 or 2 tokens at a time + decoding_seqlens=(1, 2), + tensor_parallel=tensor_parallel, + ) + inference_params_draft = model_draft._decoding_cache.inference_params + inference_params_draft.reset(max_length, batch_size) + if not hasattr(model, "_decoding_cache"): + model._decoding_cache = None + model._decoding_cache = update_graph_cache( + model, + model._decoding_cache, + batch_size, + seqlen_og, + max_length, + decoding_seqlens=range(1, speculative_lookahead + 2), + tensor_parallel=tensor_parallel, + ) + inference_params = model._decoding_cache.inference_params + inference_params.reset(max_length, batch_size) + else: + inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + + def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False): + decoding = inference_params.seqlen_offset > 0 + if decoding: + seqlen = input_ids.shape[1] + # if inference_params.lengths_per_sample is None: + # TODO: in the case of batched decoding where each sequence has a different length, + # we need to compute the position_ids for each sequence using lengths_per_sample + if True: + cache_seqlens = torch.full( + (input_ids.shape[0],), + inference_params.seqlen_offset, + dtype=torch.int32, + device=input_ids.device, + ) + else: + cache_seqlens = inference_params.lengths_per_sample + position_ids = cache_seqlens[:, None] + torch.arange( + seqlen, dtype=torch.long, device=input_ids.device + ) + else: + position_ids = None + if not cg or not decoding: + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=num_last_tokens, + ).logits + else: + # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1]. + # This might not be compatible the num_last_tokens used here. + assert num_last_tokens <= input_ids.shape[1] + logits = model._decoding_cache.run( + input_ids, position_ids, inference_params.seqlen_offset + )[:, -num_last_tokens:] + return logits[..., :vocab_size] if vocab_size is not None else logits + + def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1): + """Sample `num_tokens` tokens from the model, given the previous logits. + Also return the logits of the sampled tokens. + Arguments: + input_ids: (batch, seqlen) + Return: + tokens: (batch, num_tokens) + scores: (batch, num_tokens), which contains @previous_logits and the logits of the next + (num_tokens - 1) tokens. The logits of the last token isn't computed. + """ + assert num_tokens >= 1 + sequences, scores = [input_ids], [] + for i in range(num_tokens): + scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1]) + inference_params.seqlen_offset += sequences[-1].shape[1] + sequences.append(sample_fn(scores[-1]).unsqueeze(1)) + return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1) + + sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature) + sample_fn = partial(sample, **sampling_kwargs) + get_logits_main = partial(get_logits, model=model, cg=cg) + get_logits_draft = partial(get_logits, model=model_draft, cg=cg) + sample_tokens_main = partial( + sample_tokens, + get_logits_fn=get_logits_main, + sample_fn=sample_fn, + inference_params=inference_params, + ) + sample_tokens_draft = partial( + sample_tokens, + get_logits_fn=get_logits_draft, + sample_fn=sample_fn, + inference_params=inference_params_draft, + ) + + if debug: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + if enable_timing: + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + start = time.time() + + sequences, scores = [input_ids], [] + num_main_model_calls = 0 + num_draft_tokens = 0 + num_accepted_tokens_history = [] + if seqlen_og >= max_length - 1: + # Don't do speculative sampling, just sample 1 token from the model + tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1) + sequences.append(tokens) + scores.append(scores_new) + else: + # Sample from draft model, which produces @n_spec_tokens, and @model + # will then use to produce between 1 and 1 + @n_spec_tokens tokens. + # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length. + n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1) + tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens) + num_draft_tokens += n_spec_tokens + if debug: + scores_draft_ref = model_draft( + torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 + ).logits + print((scores_draft - scores_draft_ref[:, :-1]).abs().max()) + + # Evaluate the draft tokens with the model + logits = get_logits_main( + torch.cat([input_ids, tokens_draft], dim=1), + inference_params, + num_last_tokens=n_spec_tokens + 1, + ) + num_main_model_calls += 1 + if debug: + logits_ref = model( + torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 + ).logits + print((logits - logits_ref).abs().max()) + # breakpoint() + tokens, num_generated_tokens = sample_speculative( + logits, scores_draft, tokens_draft, **sampling_kwargs + ) + num_accepted_tokens_history.append(num_generated_tokens - 1) + if debug: + print(tokens) + print(num_generated_tokens) + # breakpoint() + # TODO: we're using the fact that batch_size == 1 + # TODO: check eos_token_id + sequences.append(tokens[:1, : num_generated_tokens[0]]) + scores.append(logits[:1, : num_generated_tokens[0]]) + # Note that @model has not evaluated the last sampled token yet, so we'll need to pass + # that in the next time we call @model. + num_generated = num_generated_tokens[0].item() + inference_params.seqlen_offset = seqlen_og + num_generated - 1 + inference_params_draft.seqlen_offset = ( + inference_params.seqlen_offset - 1 + if num_generated > 1 + else inference_params.seqlen_offset + ) + if debug: + cur_ids = torch.cat([input_ids, sequences[-1]], dim=1) + scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits + print((scores[-1] - scores_ref[:, :-1]).abs().max()) + # breakpoint() + + while True: + # seqlen_offset is total length generated - 1 + if inference_params.seqlen_offset >= max_length - 1: + break + if inference_params.seqlen_offset >= max_length - 2: + # Don't do speculative sampling, just sample 1 token from the model + tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1) + sequences.append(tokens) + scores.append(scores_new) + break + # Sample from draft model + n_spec_tokens = min( + speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2 + ) + # If the main model accepts all the draft tokens, plus it samples one new token, + # then at the next iteration the draft model need to evaluate the logits of the last draft + # token and the logits of the newly sampled token. So here we pass in the last 2 tokens + # of sequences[-1]. + # This exception is when the main model rejects all the draft tokens, in which case we + # will only have 1 token to pass in. + tokens_draft, scores_draft = sample_tokens_draft( + sequences[-1][:, -2:], num_tokens=n_spec_tokens + ) + num_draft_tokens += n_spec_tokens + if debug: + scores_draft_ref = model_draft( + torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 + ).logits + print((scores_draft - scores_draft_ref[:, :-1]).abs().max()) + # breakpoint() + # Evaluate the draft tokens with the model + logits = get_logits_main( + torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1), + inference_params, + num_last_tokens=n_spec_tokens + 1, + ) # (batch, n_spec_tokens + 1, vocab_size) + num_main_model_calls += 1 + if debug: + logits_ref = model( + torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 + ).logits + print((logits - logits_ref).abs().max()) + # breakpoint() + tokens, num_generated_tokens = sample_speculative( + logits, scores_draft, tokens_draft, **sampling_kwargs + ) + num_accepted_tokens_history.append(num_generated_tokens - 1) + if debug: + print(tokens) + print(num_generated_tokens) + # breakpoint() + sequences.append(tokens[:1, : num_generated_tokens[0]]) + scores.append(logits[:1, : num_generated_tokens[0]]) + # We've evaluated 1 token from sequences[-1][:, -1:] above, plus + # num_generated_tokens[0].item() - 1 tokens from the draft model. + num_generated = num_generated_tokens[0].item() + inference_params.seqlen_offset += num_generated + inference_params_draft.seqlen_offset = ( + inference_params.seqlen_offset - 1 + if num_generated > 1 + else inference_params.seqlen_offset + ) + if debug: + cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1) + scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits + print((scores[-1] - scores_ref[:, :-1]).abs().max()) + # breakpoint() + + if enable_timing: + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + print(f"Number of calls to main model: {num_main_model_calls}") + print( + f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%" + ) + sequences = torch.cat(sequences, dim=1) + scores = torch.cat(scores, dim=1) + if debug: + scores_ref = model(sequences).logits + print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max()) + output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput + return output_cls(sequences=sequences, scores=scores) + + +class GenerationMixin: + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + raise NotImplementedError + + def generate( + self, + input_ids, + max_length, + top_k=1, + top_p=0.0, + temperature=1.0, + return_dict_in_generate=False, + output_scores=False, + use_decode_only=False, + **kwargs, + ): + if use_decode_only: + output = decode_only( + input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs + ) + else: + output = decode( + input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs + ) + if not output_scores: + output.scores = None + return output if return_dict_in_generate else output.sequences + + +def allocate_inference_cache( + max_batch_size, + max_seqlen, + nheads, + headdim, + layers: Union[int, Sequence], + device, + dtype=torch.float16, +): + assert dtype in [torch.float16, torch.bfloat16, torch.float32] + kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim) + if isinstance(layers, int): + layers = range(layers) + return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers} + + +@dataclass +class DecodingCGCache: + max_batch_size: int = 0 + max_seqlen: int = 0 + device = None + dtype = None + callables: dict = field(default_factory=dict) + mempool = None + inference_params: Optional[InferenceParams] = None + run: Optional[Callable] = None + + +@torch.inference_mode() +def update_graph_cache( + model, + cache, + batch_size, + seqlen_og, + max_seqlen, + decoding_seqlens=(1,), + tensor_parallel=1, + dtype=None, + n_warmups=2, +): + if cache is None: + cache = DecodingCGCache() + param_example = next(iter(model.parameters())) + device = param_example.device + if dtype is None: + dtype = param_example.dtype + if ( + (device, dtype) != (cache.device, cache.dtype) + or batch_size > cache.max_batch_size + or max_seqlen > cache.max_seqlen + ): # Invalidate the cache + cache.callables = {} + cache.mempool = None + cache.inference_params = None + gc.collect() + cache.device, cache.dtype = device, dtype + cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen + if hasattr(model, "allocate_inference_cache"): + inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) + else: + headdim = getattr( + model.config, + "head_dim", + model.config.hidden_size // model.config.num_attention_heads, + ) + inf_cache = allocate_inference_cache( + batch_size, + max_seqlen, + model.config.num_attention_heads // tensor_parallel, + headdim, + model.config.num_hidden_layers, + device, + dtype, + ) + lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) + cache.inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=seqlen_og, + key_value_memory_dict=inf_cache, + lengths_per_sample=lengths_per_sample, + ) + cache.mempool = torch.cuda.graphs.graph_pool_handle() + for decoding_seqlen in decoding_seqlens: + if (batch_size, decoding_seqlen) not in cache.callables: + cache.callables[batch_size, decoding_seqlen] = capture_graph( + model, + cache.inference_params, + batch_size, + max_seqlen, + decoding_seqlen=decoding_seqlen, + mempool=cache.mempool, + n_warmups=n_warmups, + ) + + def dispatch(input_ids, position_ids, seqlen): + batch_size, decoding_seqlen = input_ids.shape[:2] + return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) + + cache.run = dispatch + cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing + return cache + + +def capture_graph( + model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 +): + device = next(iter(model.parameters())).device + input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) + position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) + seqlen_offset_og = inference_params.seqlen_offset + inference_params.seqlen_offset = max_seqlen - decoding_seqlen + inference_params.lengths_per_sample[:] = inference_params.seqlen_offset + + # Warmup before capture + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(n_warmups): + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=decoding_seqlen, + ).logits + s.synchronize() + # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, + # which requires that graph launch and non-captured launch to not overlap (I think, + # that's how I interpret the documentation). I'm not sure if this is required. + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.cuda.current_stream().wait_stream(s) + # Captures the graph + # To allow capture, automatically sets a side stream as the current stream in the context + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=mempool): + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=decoding_seqlen, + ).logits + + def run(new_input_ids, new_position_ids, seqlen): + inference_params.lengths_per_sample[:] = seqlen + input_ids.copy_(new_input_ids) + position_ids.copy_(new_position_ids) + graph.replay() + return logits.clone() + + inference_params.seqlen_offset = seqlen_offset_og + return run diff --git a/HybridTensor/utils/profiling.py b/HybridTensor/utils/profiling.py new file mode 100644 index 0000000000000000000000000000000000000000..68f8b8ac07cc89bdc28a29d35624bac622615537 --- /dev/null +++ b/HybridTensor/utils/profiling.py @@ -0,0 +1,270 @@ +import math +import torch +from functools import partial +import pandas as pd +import matplotlib.pyplot as plt + +from HybridTensor.utils.utils import get_gpu_name, create_results_directory + +from tqdm import tqdm # For progress bars + +# from HybridTensor.modules.MLP import StandardMLPBlock, SelectiveMLP, SelectiveMLPTriton +# from HybridTensor.utils.utils import sparse_index + +def benchmark_mlp_fwd(x, model, index_vec = None, iterations = 100, print_result = False): + + class_name_ = model.__class__.__name__ + if index_vec is not None: + model = partial(model, index_vec=index_vec) + # warm up, this also compiles the triton kernel before measuring the execution time + for _ in range(10): + out = model(x) + + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iterations): + out = model(x) + torch.cuda.synchronize() + end = torch.cuda.Event(enable_timing=True) + end.record() + torch.cuda.synchronize() + + elapsed_time = start.elapsed_time(end)/iterations + + if print_result: + print(f"{class_name_} Execution time: {elapsed_time} ms") + + return out, elapsed_time + +def generate_index_sizes(hidden_features): + index_sizes = [] + idx = 0 + while idx < hidden_features: + idx += 512 + index_sizes.append(min(idx, hidden_features)) + return index_sizes + + +def save_results_to_csv(df, filename_prefix='mlp_profiling_results', results_dir = create_results_directory("results")): + """ + Saves the profiling results DataFrame to a CSV file within the specified results directory. + + Parameters: + - df (pd.DataFrame): The DataFrame containing profiling results. + - filename_prefix (str): The prefix for the CSV filename. + - results_dir (Path): The Path object for the results directory. + """ + + # Retrieve the GPU name + gpu_name = get_gpu_name() + + # Define the filename with GPU name + filename = f"{filename_prefix}_{gpu_name}.csv" + + # Define the full path for the CSV file + csv_path = results_dir / filename + + # Save the DataFrame to the CSV file + df.to_csv(csv_path, index=False) + print(f"Results saved to {csv_path}") + +def plot_results(df, output_prefix='mlp_profiling', results_dir=create_results_directory("results")): + """ + Plots the profiling results and saves the plot image within the specified results directory. + + Parameters: + - df (pd.DataFrame): The DataFrame containing profiling results. + - output_prefix (str): The prefix for the plot filename. + - results_dir (Path): The Path object for the results directory. + """ + plt.figure(figsize=(12, 6)) + + # Plot Execution Time + plt.subplot(1, 2, 1) + plt.plot(df['index_size'], df['standard_time'], label='Standard MLP', marker='o') + plt.plot(df['index_size'], df['selective_cutlass_time'], label='Selective MLP Cutlass', marker='o') + plt.plot(df['index_size'], df['selective_triton_time'], label='Selective MLP Triton', marker='o') + plt.xlabel('Index Size') + plt.ylabel('Execution Time (ms)') + plt.title('Execution Time vs. Index Size') + plt.legend() + plt.grid(True) + + # Plot Speedup + plt.subplot(1, 2, 2) + plt.plot(df['index_size'], df['cutlass_speedup'], label='Cutlass Speedup', marker='o') + plt.plot(df['index_size'], df['triton_speedup'], label='Triton Speedup', marker='o') + plt.xlabel('Index Size') + plt.ylabel('Speedup') + plt.title('Speedup vs. Index Size') + plt.legend() + plt.grid(True) + + plt.tight_layout() + + # Retrieve the GPU name + gpu_name = get_gpu_name() + + # Define the filename with GPU name + plot_filename = f"{output_prefix}_{gpu_name}.png" + + # Define the full path for the plot image + plot_path = results_dir / plot_filename + + # Save the plot + plt.savefig(plot_path) + plt.show() + print(f"Plots saved as {plot_path}") + +def cuda_profiler(func, *args, warmup_runs=10, timed_runs=1000, **kwargs): + """ + Generic profiler function to measure execution time of a given function. + + Parameters: + - func: The function to be profiled. + - *args: Positional arguments to be passed to the function. + - warmup_runs: Number of warm-up runs (default: 2). + - timed_runs: Number of timed iterations (default: 10). + - **kwargs: Keyword arguments to be passed to the function. + + Returns: + - float: The average execution time of the function in milliseconds. + """ + # Warm-up phase + for _ in range(warmup_runs): + out = func(*args, **kwargs) + + # Synchronize before starting the timer to ensure accurate measurements + torch.cuda.synchronize() + + # Create CUDA events for measuring execution time + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Record execution times for the given number of runs + start_event.record() + for _ in range(timed_runs): + # Execute the function + out = func(*args, **kwargs) + + # Wait for the events to be completed + # torch.cuda.synchronize() + end_event.record() + torch.cuda.synchronize() + # Calculate elapsed time for this iteration + elapsed_time = start_event.elapsed_time(end_event) + + # Calculate average time per run + avg_time = elapsed_time / timed_runs + + return out, avg_time + + +from HybridTensor.triton.select_attn_v1 import select_attn +from HybridTensor.utils.utils import generate_BH_index +from HybridTensor.triton.references.attention_proj_sparse import qkv_proj_sparse, out_proj_sparse + +def _sim_cache_update(k, v, qkv, seq_len): + k[:, -1, ...] = qkv[:, :, 1] + v[:, -1, ...] = qkv[:, :, 2] + +def mha_inference_simulation(B, in_features, seq_len, head_density, active_density): + ''' + Simulates the execultion time of a standard MHA layer and a selective MHA layer with sparse projection and select_attn. + + Parameters: + - B: batch size + - in_features: number of features + - seq_len: sequence length + - head_density: the percentage of heads that are active per batch + - active_density: the percentage of active heads per layer (aggregate active heads in all batches) + + ''' + # Test parameters + H = in_features // 128 # Number of heads + G = 1 # Group size + M = 1 # Sequence length for queries + Mk = seq_len # Sequence length for keys/values + Kq = 128 # Embedding size for queries + Kkv = 128 # Embedding size for keys/values + + dtype = torch.float16 + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + device = 'cuda:0' + dtype = torch.float16 + + x = torch.rand(B, in_features, dtype=dtype).to(device) + proj_dim = 3 * in_features + + # Define the Linear layer + qkv_project = torch.nn.Linear(in_features, proj_dim, dtype=dtype).to(device) + out_project = torch.nn.Linear(in_features, in_features, dtype=dtype).to(device) + + weight = torch.randn(3, H, Kkv, in_features, device=device, dtype=dtype) + bias = torch.randn(3, H, Kkv, device=device, dtype=dtype) + + n_active_heads = math.ceil(H * active_density) + head_idx = torch.randperm(H, device=device, dtype=torch.int32)[:n_active_heads] + + batch_idx = torch.stack([ + torch.arange(B, dtype=torch.int32, device=device) + for _ in range(n_active_heads) + ]) + + print(f"Batch size: {B}, Total heads: {H}, Features: {in_features}, Seq len: {seq_len}") + print(f"Total active heads: {n_active_heads}") + print(f"Head density in SelectAttn: {head_density}") + + print("====================================") + + # Inference simulation + qkv, qkv_project_time = cuda_profiler(qkv_project, x) + print(f"qkv projection time: {qkv_project_time:.3f} ms") + + qkv_sel, qkv_sel_proj_time = cuda_profiler(qkv_proj_sparse, x, weight, head_idx, batch_idx, bias) + print(f"qkv projection time: {qkv_sel_proj_time:.3f} ms") + + # Generate random tensors for q, k, v + q = torch.randn(B, M, G, H, Kq, dtype=dtype, device=device) + k = torch.randn(B, Mk, G, H, Kkv, dtype=dtype, device=device) + v = torch.randn(B, Mk, G, H, Kkv, dtype=dtype, device=device) + + # need to update kv cache with the new k, v + _sim_cache_update(k, v, qkv, seq_len) + _, kv_cache_update_time = cuda_profiler(_sim_cache_update, k, v, qkv, seq_len) + print(f"KV cache update time: {kv_cache_update_time:.3f} ms") + + scale = 1 / (Kq ** 0.5) + batch_head_index_1 = generate_BH_index(B, H, math.ceil(H * 1)) + + triton_sel_output, attn_time = cuda_profiler(select_attn, q, k, v, scale, batch_head_index_1) + print(f"Attention time: {attn_time:.3f} ms") + + batch_head_index_2 = generate_BH_index(B, H, math.ceil(H * head_density)) + triton_sel_output_2, select_attn_time = cuda_profiler(select_attn, q, k, v, scale, batch_head_index_2) + print(f"SelectAttn time: {select_attn_time:.3f} ms") + + triton_sel_output_2, view_time = cuda_profiler(triton_sel_output_2.view, B, in_features) + + # Out projection + out, out_project_time = cuda_profiler(out_project, triton_sel_output_2) + print(f"out projection time: {out_project_time:.3f} ms") + + standard_time = qkv_project_time + attn_time + out_project_time + select_time = qkv_project_time + select_attn_time + out_project_time + select_time_sparse_project = qkv_sel_proj_time + select_attn_time + out_project_time + + print("====================================") + print(f"Standard time: {standard_time:.3f} ms") + print(f"Select time: {select_time:.3f} ms") + print(f"Select time with sparse project: {select_time_sparse_project:.3f} ms") + print("====================================") + print(f"Selective Speedup: {standard_time / select_time:.3f}") + print(f"Selective Speedup with sparse project: {standard_time / select_time_sparse_project:.3f}") + + # free cuda memory + del qkv, qkv_sel, q, k, v, triton_sel_output, triton_sel_output_2, out + torch.cuda.empty_cache() + diff --git a/HybridTensor/utils/utils.py b/HybridTensor/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..14d78b63566f0793e812f31dbb791d27a3606110 --- /dev/null +++ b/HybridTensor/utils/utils.py @@ -0,0 +1,324 @@ +import math +import numpy as np +import torch +import argparse +import os +from functools import partial +import pandas as pd +import matplotlib.pyplot as plt +from pathlib import Path +from tqdm import tqdm # For progress bars + +def arg_parser(): + parser = argparse.ArgumentParser(description='Inference benchmarking') + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--hidden_features', type=int, default=32768) + parser.add_argument('--in_features', type=int, default=8192) + parser.add_argument('--model_index', type=int, default=5) + parser.add_argument('--seq_len', type=int, default=512) + parser.add_argument('--index_size', type=int, default=8192) + parser.add_argument('--head_density', type=float, default=0.25) + parser.add_argument('--attn_topk', type=float, default=0.5) + parser.add_argument('--print_results', type=bool, default=True) + parser.add_argument('--iterations', type=int, default=10) + parser.add_argument('--check_results', type=bool, default=False) + parser.add_argument('--results_dir', type=str, default='results') + parser.add_argument('--max_batch_size', type=int, default=32) + parser.add_argument('--max_seqlen', type=int, default=2048) + parser.add_argument('--bias', type=int, default=0) + parser.add_argument('--device', type=int, default=0) + parser.add_argument('--mode', type=str, default='row', choices=['row', 'col', 'auto']) + + return parser.parse_args() + +def initialize_distributed_environment(): + # Set environment variables for NCCL + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" + os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0" + + # Initialize the distributed process group + torch.distributed.init_process_group(backend="nccl", init_method="env://") + + # Set the device based on the rank of the current process + device = f"cuda:{torch.distributed.get_rank()}" + world_size = torch.distributed.get_world_size() + + # Set the current CUDA device to avoid operations being executed on the wrong GPU + torch.cuda.set_device(device) + + # You can return device, world_size, and any other relevant information + return device, world_size + +def get_gpu_name(): + if torch.cuda.is_available(): + gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()) + # Clean the GPU name to make it filename-friendly + gpu_name_clean = gpu_name.replace(" ", "_").replace("/", "_").replace("\\", "_") + return gpu_name_clean + else: + return "CPU" + +def _get_device(device_id): + if torch.cuda.is_available(): + device = torch.device(f"cuda:{device_id}") + else: + device = torch.device("cpu") + + return device + +def extract_model_name(model_path: str) -> str: + return model_path.split("/")[-1] + +def create_results_directory(results_dir): + """ + Creates the results directory if it does not exist. + + Parameters: + - results_dir (str or Path): The path to the results directory. + + Returns: + - Path: The Path object representing the results directory. + """ + path = Path(results_dir).resolve() + + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + print(f"Created results directory at: {path}") + else: + # print(f"Results directory already exists at: {path}") + pass + + return path + +def ZeroIndex(index_vec, max_value): + # Create a set of all integers from 0 to max_value - 1 + all_integers = set(range(max_value)) + + # Convert index_vec to a set + index_set = set(index_vec.cpu().numpy()) + + # Subtract index_set from all_integers + remaining_integers = all_integers - index_set + + # Convert the result back to a tensor + zero_index = torch.tensor(list(remaining_integers), dtype=torch.int32, device='cuda') + + return zero_index + +def sparse_index(index_size, max_value, return_zero_index = False): + index_vec = torch.randperm(max_value, dtype=torch.int32, device='cuda')[:index_size] + index_vec, _ = torch.sort(index_vec) + if return_zero_index: + zero_index = ZeroIndex(index_vec, max_value) + else: + zero_index = None + return index_vec, zero_index + +# utility function to study activations + +def create_random_batches(labels, batch_size=32): + """ + Shuffles the labels and splits them into random batches. + + Parameters: + - labels (np.ndarray): The labels matrix of shape (212646, 16384). + - batch_size (int): The number of samples per batch. + + Returns: + - List[np.ndarray]: A list of batches, each containing `batch_size` rows. + """ + num_samples = labels.shape[0] + + # Generate a permutation of indices + shuffled_indices = np.random.permutation(num_samples) + + # Shuffle the labels matrix + shuffled_labels = labels[shuffled_indices] + + # Calculate the number of complete batches + num_batches = num_samples // batch_size + + # Split the shuffled labels into batches + batches = np.split(shuffled_labels[:num_batches * batch_size], num_batches) + + return batches + +def generate_BH_index(batch_size: int, heads: int, selected_heads: int, device = 'cuda'): + ''' + Generates a random list of selected heads for each batch. + + Args: + - batch_size (int): Number of batches. + - heads (int): Total number of heads. + - selected_heads (int): Number of heads to select for each batch. + + Returns: + - bh_index (torch.Tensor): Tensor of shape (batch_size * selected_heads, 2) where each row is (batch_idx, head_idx). + ''' + N_selected = batch_size * selected_heads + bh_index = torch.zeros((N_selected, 2), dtype=torch.int32, device=device) + + for batch_idx in range(batch_size): + selected_head_indices = torch.randperm(heads)[:selected_heads] + sorted_head_indices = torch.sort(selected_head_indices).values + for i, head_idx in enumerate(sorted_head_indices): + bh_index[batch_idx * selected_heads + i] = torch.tensor([batch_idx, head_idx], dtype=torch.int32) + + return bh_index + +def generate_random_BH_index(batch_size: int, heads: int, selected_heads: int, device = 'cuda'): + ''' + Generates a random list of selected heads for each batch. + + Args: + - batch_size (int): Number of batches. + - heads (int): Total number of heads. + - selected_heads (int): Number of heads to select for each batch. + + Returns: + - bh_index (torch.Tensor): Tensor of shape (batch_size, selected_heads) + ''' + bh_index = torch.zeros((batch_size, selected_heads), dtype=torch.int32, device=device) + + for batch_idx in range(batch_size): + selected_head_indices = torch.randperm(heads)[:selected_heads] + # sort the selected head indices + sorted_head_indices = torch.sort(selected_head_indices).values + bh_index[batch_idx] = sorted_head_indices + + return bh_index + +def generate_random_BG_index(batch_size: int, groups: int, selected_groups: int, device = 'cuda'): + ''' + Generates a random list of selected heads for each batch. + + Args: + - batch_size (int): Number of batches. + - heads (int): Total number of heads. + - selected_heads (int): Number of heads to select for each batch. + + Returns: + - bh_index (torch.Tensor): Tensor of shape (batch_size, selected_heads) + ''' + bg_index = torch.zeros((batch_size, selected_groups), dtype=torch.int32, device=device) + + for batch_idx in range(batch_size): + selected_group_indices = torch.randperm(groups)[:selected_groups] + # sort the selected head indices + sorted_group_indices = torch.sort(selected_group_indices).values + bg_index[batch_idx] = sorted_group_indices + + return bg_index + + + +def activation_stats_layer(test_batches, total_neurons, device): + """ + Calculates the average and standard deviation of activations across batches. + + Parameters: + - test_batches (List[np.ndarray] or List[torch.Tensor]): List of batches containing label data. + - total_neurons (int): Total number of neurons. + - device (torch.device): The device to perform computations on (e.g., 'cpu' or 'cuda'). + + Returns: + - avg_act (float): The average number of activations per batch. + - std_dev (float): The standard deviation of activations across batches. + """ + sum_activation = 0.0 # To accumulate the total activations + sum_activation_sq = 0.0 # To accumulate the squared activations + num_batches = len(test_batches) + + for i, batch in enumerate(test_batches): + # Convert batch to a PyTorch tensor if it's not already + if not isinstance(batch, torch.Tensor): + torch_labels = torch.tensor(batch, dtype=torch.float32, device=device) + else: + torch_labels = batch.to(device=device, dtype=torch.float32) + + # Binarize the labels: 1 if activation > 0, else 0 + binary_labels = (torch_labels > 0).int() + + # Sum activations per neuron across the batch + activation_counts = binary_labels.sum(dim=0) + + # Convert counts to binary: 1 if neuron is activated in the batch, else 0 + activated_neurons = (activation_counts > 0).int() + + # Total number of activated neurons in this batch + total_activations = activated_neurons.sum().item() + + # Accumulate sum and sum of squares + sum_activation += total_activations + sum_activation_sq += total_activations ** 2 + + # Optional: Print progress every 1000 batches + # if (i + 1) % 1000 == 0 or (i + 1) == num_batches: + # print(f"Processed {i + 1}/{num_batches} batches") + + # Calculate average activation + avg_act = sum_activation / num_batches + + # Calculate variance and standard deviation + variance = (sum_activation_sq / num_batches) - (avg_act ** 2) + std_dev = variance ** 0.5 + + # Display results + print(f"\nAverage activation: {avg_act:.2f} " + f"({(avg_act / total_neurons) * 100:.2f}% of total neurons)") + print(f"Standard deviation of activation: {std_dev:.2f}") + + return avg_act, std_dev + +def calculate_index_sizes(in_features): + """ + Calculate index sizes based on the given in_features. + The sizes are rounded up to the nearest multiple of 1024 + and generated in 5% increments up to 100% of total neurons. + + Args: + in_features (int): The number of input features. + + Returns: + List[int]: A list of index sizes rounded up to the nearest multiple of 1024. + """ + index_sizes = [] + total_neurons = in_features * 4 + + # Generate 20 percentages from 5% to 100% in increments of 5% + percentages = [i for i in range(5, 105, 5)] + + # Calculate index sizes and round up to the nearest multiple of 1024 + for p in percentages: + index_size = int((p / 100) * total_neurons) + index_size = math.ceil(index_size / 1024) * 1024 + index_sizes.append(index_size) + + return index_sizes + + +def compute_perplexity(model, dataloader, device): + total_loss = 0.0 + total_tokens = 0 + + with torch.no_grad(): + for batch in tqdm(dataloader, desc="Calculating Perplexity"): + input_ids = batch['input_ids'].to(device) + attention_mask = batch['attention_mask'].to(device) + + # Shift input_ids and labels for causal language modeling + labels = input_ids.clone() + # Replace padding tokens in labels by -100 so they are ignored in loss computation + labels[attention_mask == 0] = -100 + + outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + loss = outputs.loss + # Multiply loss by number of tokens in the batch + # The loss is averaged over the number of non-masked tokens + # To get the total loss, multiply by the number of non-masked tokens + total_loss += loss.item() * torch.sum(labels != -100).item() + total_tokens += torch.sum(labels != -100).item() + + # Calculate perplexity + perplexity = math.exp(total_loss / total_tokens) + return perplexity \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..8a63deb98bbf0594169fad09cbe09b85ba01a32b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Susav Shrestha + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 7be5fc7f47d5db027d120b8024982df93db95b74..04a225f5700b4073c925e1cabd9dd033b96fc67a 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,280 @@ ---- -license: mit ---- +# Polar Sparsity: High Throughput Batched LLM Inferencing with Scalable Contextual Sparsity + +Polar Sparsity is a framework for efficient sparse inferencing in large language models (LLMs), leveraging custom Triton kernels and learned routers for selective activation of MLP neurons and attention heads. This repository provides tools for data collection, router training, benchmarking, and end-to-end sparse generation. + +--- + +## ⚠️ Requirements + +- Python 3.8+ +- [PyTorch](https://pytorch.org/) (tested on >=1.13) +- [Transformers](https://github.com/huggingface/transformers) (tested on >=4.30) +- See [`environment.yml`](environment.yml) for all dependencies. + +> **Note:** Some scripts may require additional dependencies (e.g., `matplotlib`, `pandas`). + +--- + +## 🗂️ Model Indices + +The following table lists common model indices used in `--model_index` (see also `HybridTensor/utils/activations.py`): + +| Index | Model Name | +|-------|-----------------------------------------| +| 5 | facebook/opt-6.7b | +| 8 | facebook/opt-66b | +| 11 | meta-llama/Llama-2-7b-hf | +| 15 | meta-llama/Llama-3.1-70B | + +--- + +## 📦 Repository Structure + +- **Router Data Collection & Training** + - Data Collection: [`HybridTensor/routers/datacollection/data_collection.py`](HybridTensor/routers/datacollection/data_collection.py) + - MLP Router Training: [`HybridTensor/routers/mlp/main_mlp.py`](HybridTensor/routers/mlp/main_mlp.py) + - MHA Router Training: [`HybridTensor/routers/mha/main_att.py`](HybridTensor/routers/mha/main_att.py) +- **Benchmarks** + - Evaluation: [`HybridTensor/benchmarks/model_eval.py`](HybridTensor/benchmarks/model_eval.py) +- **Kernel Implementations** + - Triton Kernels: [`HybridTensor/triton/`](HybridTensor/triton/) + - Example Runners: [`run_sparse_mlp.py`](run_sparse_mlp.py), [`run_sparse_attn.py`](run_sparse_attn.py), [`run_sparse_transformer_block.py`](run_sparse_transformer_block.py) +- **Sparse Generation** + - End-to-End Sparse Generation: [`model_sparse_generation.py`](model_sparse_generation.py) + +--- + +## 🚀 Getting Started + +### 1. Environment Setup + +- Install dependencies (see [`environment.yml`](environment.yml) for details). + +```bash +conda env create -f environment.yml +``` + +- For Triton kernels, install the latest nightly build: + ```bash + pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly + ``` + +--- + +### 2. Router Data Collection + +To collect router data for a specific model, you can use: + +```bash +python -m HybridTensor.routers.datacollection.data_collection \ + --model_index 5 \ + --batch_size 8 \ + --device_map auto \ + --data_dir \ + --max_samples 400000 \ + --model_family \ + --mlp_activation True \ + --attn_norm True +``` + +**Argument explanations:** + +- `--model_index`: Index of the model to use (see `HybridTensor/utils/activations.py` for available indices). +- `--batch_size`: Number of samples per batch during data collection, adjust to configure GPU memory usage. +- `--data_dir`: Directory to save the collected activation data. +- `--model_family`: Model family (e.g., `opt`, `llama`). +- `--mlp_activation`: Set to `True` to collect MLP activation data. Only for sparse MLP models. +- `--attn_norm`: Set to `True` to collect attention norm data. + +--- + +### 3. Router Training and Optimizations + +**MLP Router:** + +To run the MLP router training use the following scripts + +For a single layer: + +```bash +python -m HybridTensor.routers.mlp.main_mlp \ + --model_index \ + --L \ + --data_dir \ + --ckpt_dir \ + --gpu +``` + +For all layers, edit the [`HybridTensor/routers/mlp/train_mlp_routers.sh'](HybridTensor/routers/mlp/train_mlp_routers.sh) file with the number of GPUs available, model index, total number of layers, data_dir and ckpt_dir. + +```bash +./HybridTensor/routers/mlp/train_mlp_routers.sh +``` + +**MHA Router:** + +To run the attention router training use the following scripts + +For a single layer: + +```bash +python -m HybridTensor.routers.mha.main_att \ + --model_index \ + --L \ + --k \ + --data_dir \ + --ckpt_dir +``` + +For all layers, edit the [`HybridTensor/routers/mha/train_mha_routers_topk.sh'](HybridTensor/routers/mha/train_mha_routers_topk.sh) file with the number of GPUs available, model index, total number of layers, data_dir and ckpt_dir. + +```bash +./HybridTensor/routers/mha/train_mha_routers_topk.sh +``` +To optimize the MLP layers for ReLU model with our dynamic layer wise top-k algorithm, you can use: + + +```bash +python -m HybridTensor.routers.mlp.mlp_router_optim_fast --model_index --batch_size --mlp_ckpt_dir --act_data_dir + +``` +- `--batch_size`: batch size to optimize for inference + +--- + +### 4. Model Evaluation + +You can evaluate your models on various benchmarks using the [`HybridTensor/benchmarks/model_eval.py`](/HybridTensor/benchmarks/model_eval.py) script. Below are example commands and explanations for the main arguments. These scripts use huggingface implementations with masking for easy benchmarking. These do not use the optimized kernels for efficient inference. + +**Example usage:** + +```bash +python -m HybridTensor.benchmarks.model_eval \ + --model_index \ + --batch_size \ + --mode \ + --benchmark \ + --attn_topk \ + --attn_ckpt_dir \ + --mlp_ckpt_dir \ + --data_collection \ + --device auto \ + --note +``` + +**Additional argument explanations:** + +- `--batch_size`: Batch size to use for evaluation. +- `--mode`: Evaluation mode. Options are `dense` (standard), `sparse` (sparse MLP and/or attention using trained routers), or `sparse_attn` (sparse attention only using ground truth activations ,doesn't require routers). +- `--benchmark`: Which benchmark(s) to run. Use `all` for the full suite or specify a single benchmark (e.g., `mmlu`). +- `--attn_topk`: Top-k value for attention sparsity (e.g., 0.5 for 50% sparsity). +- `--attn_ckpt_dir`: Directory containing attention router checkpoints. +- `--mlp_ckpt_dir`: Directory containing MLP router checkpoints. +- `--data_collection`: Set to `True` to enable data collection mode for threshold sweeps. +- `--device`: Device ID to use (e.g., `0` for `cuda:0`). +- `--note`: Optional note to append to the results filename. + + +Adjust the arguments as needed for your experiment or hardware setup. + +--- + +### 5. Kernel Implementations + +**Triton Kernels:** Custom kernels for selective MLP and attention are in [`HybridTensor/triton/`](HybridTensor/triton/). + +Benchmark the speedup of the selective GEMM kernel (used for sparse MLPs): + +```bash +python -m HybridTensor.triton.gather_gemm_col \ + --batch_size \ + --in_features \ + --index_size +``` + +- `--in_features`: Model embedding dimension (e.g., 8192). +- `--index_size`: Total number of active neurons selected by the router. Needs to be less than or equal to total neurons. + +--- + +Benchmark the speedup for a sparse MLP layer: + +```bash +python run_sparse_mlp.py \ + --in_features \ + --batch_size \ + --index_size +``` + +Benchmark the speedup for a sparse Multi-Head Attention (MHA) layer: + +--- + + +```bash +python run_sparse_attn.py \ + --in_features \ + --batch_size \ + --seq_len \ + --attn_topk +``` + +- `--attn_topk`: Fraction of attention heads to keep active (e.g., 0.5 for 50%). + +--- + +Use the following script before running autotune_configs.py + +``` bash +export TRITON_PRINT_AUTOTUNING="1" +``` + +For models with sparse MLP, use the [`HybridTensor/triton/heuristics/autotune_configs.py`](HybridTensor/triton/heuristics/autotune_configs.py) script to compile the kernels for different batch sizes and activation to speedup inference. + +Benchmark the speedup for a full sparse transformer block with different batch sizes and sequence lengths: + +```bash +python run_sparse_transformer_block.py \ + --in_features \ + --batch_size \ + --seq_len \ + --index_size \ + --attn_topk +``` + +> **Note:** +> The `run_sparse_transformer_block.py` script can also be used to simulate large-scale inferencing setups with large batch sizes and sequence lengths on a single GPU if multi-GPU system is not available, since only a single transformer layer is executed in this script. + + +### 6. Sparse Generation + +Run end-to-end sparse generation using trained routers. This example shows how to build the sparse model for end-to-end generation using the optimized kernels and batched inference. + +```bash +python -m HybridTensor.benchmarks.generation.model_sparse_generation \ + --model_index \ + --mlp_ckpt_dir \ + --attn_ckpt_dir \ + --batch_stats_dir \ + --attn_topk +``` +- `--batch_stats_dir`: used for sparse MLP models, path to the output from dynamic top-k optimization. Saved in configs/ + + +--- + +## Citation + +If you find our work helpful, please cite us: + +```bibtex +@misc{shrestha2025polarsparsityhighthroughput, + title={Polar Sparsity: High Throughput Batched LLM Inferencing with Scalable Contextual Sparsity}, + author={Susav Shrestha and Brad Settlemyer and Nikoli Dryden and Narasimha Reddy}, + year={2025}, + eprint={2505.14884}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2505.14884}, +} +``` \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/mlp_router/opt-30b/mlp_router_configs_bsize_1.csv b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_1.csv new file mode 100644 index 0000000000000000000000000000000000000000..bfe452a592f9d79867c9bac59af15fc53397cb8c --- /dev/null +++ b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_1.csv @@ -0,0 +1,49 @@ +layer,batch_size,optimal_k +0,1,512 +1,1,512 +2,1,512 +3,1,512 +4,1,512 +5,1,512 +6,1,512 +7,1,512 +8,1,512 +9,1,512 +10,1,512 +11,1,512 +12,1,512 +13,1,512 +14,1,512 +15,1,512 +16,1,640 +17,1,512 +18,1,640 +19,1,512 +20,1,640 +21,1,896 +22,1,1152 +23,1,1408 +24,1,1792 +25,1,2048 +26,1,2560 +27,1,2816 +28,1,3072 +29,1,3328 +30,1,3712 +31,1,3968 +32,1,4096 +33,1,4096 +34,1,4096 +35,1,4096 +36,1,4096 +37,1,4096 +38,1,4096 +39,1,4096 +40,1,4096 +41,1,4096 +42,1,4096 +43,1,4096 +44,1,4096 +45,1,4096 +46,1,4096 +47,1,4096 diff --git a/configs/mlp_router/opt-30b/mlp_router_configs_bsize_128.csv b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_128.csv new file mode 100644 index 0000000000000000000000000000000000000000..9fc935159d3066437fe3fd7c9e2abb9ad9c4019d --- /dev/null +++ b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_128.csv @@ -0,0 +1,49 @@ +layer,batch_size,optimal_k +0,128,2304 +1,128,2560 +2,128,2048 +3,128,1664 +4,128,1536 +5,128,1664 +6,128,1792 +7,128,2048 +8,128,2176 +9,128,2176 +10,128,2816 +11,128,3840 +12,128,5760 +13,128,7168 +14,128,8576 +15,128,10752 +16,128,12160 +17,128,12416 +18,128,13824 +19,128,14592 +20,128,16640 +21,128,18816 +22,128,20992 +23,128,22528 +24,128,23808 +25,128,24960 +26,128,26112 +27,128,26624 +28,128,26880 +29,128,27008 +30,128,27264 +31,128,27264 +32,128,27392 +33,128,27520 +34,128,27648 +35,128,27520 +36,128,27520 +37,128,27520 +38,128,27776 +39,128,27648 +40,128,27776 +41,128,27904 +42,128,28032 +43,128,28160 +44,128,28160 +45,128,28160 +46,128,28032 +47,128,27008 diff --git a/configs/mlp_router/opt-30b/mlp_router_configs_bsize_16.csv b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_16.csv new file mode 100644 index 0000000000000000000000000000000000000000..a95f2458b1a1e17c5690319f7e58662aa57304cc --- /dev/null +++ b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_16.csv @@ -0,0 +1,49 @@ +layer,batch_size,optimal_k +0,16,768 +1,16,1024 +2,16,896 +3,16,512 +4,16,512 +5,16,512 +6,16,512 +7,16,640 +8,16,640 +9,16,640 +10,16,768 +11,16,896 +12,16,1408 +13,16,1920 +14,16,2176 +15,16,2944 +16,16,3456 +17,16,3456 +18,16,4352 +19,16,4480 +20,16,5504 +21,16,6272 +22,16,7040 +23,16,7808 +24,16,8960 +25,16,9984 +26,16,11264 +27,16,11648 +28,16,12160 +29,16,12416 +30,16,13312 +31,16,13696 +32,16,14464 +33,16,15104 +34,16,15232 +35,16,14976 +36,16,15104 +37,16,15232 +38,16,15488 +39,16,15744 +40,16,16256 +41,16,16896 +42,16,17664 +43,16,18432 +44,16,19328 +45,16,19840 +46,16,19712 +47,16,15872 diff --git a/configs/mlp_router/opt-30b/mlp_router_configs_bsize_2.csv b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_2.csv new file mode 100644 index 0000000000000000000000000000000000000000..b3ab42529dec86aef1c030249bf66fd0c416b4ba --- /dev/null +++ b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_2.csv @@ -0,0 +1,49 @@ +layer,batch_size,optimal_k +0,2,384 +1,2,384 +2,2,384 +3,2,384 +4,2,384 +5,2,384 +6,2,384 +7,2,384 +8,2,384 +9,2,384 +10,2,384 +11,2,384 +12,2,512 +13,2,512 +14,2,512 +15,2,768 +16,2,896 +17,2,896 +18,2,1024 +19,2,1024 +20,2,1280 +21,2,1408 +22,2,1920 +23,2,2432 +24,2,3328 +25,2,3584 +26,2,3840 +27,2,3968 +28,2,4096 +29,2,4224 +30,2,4480 +31,2,4608 +32,2,4864 +33,2,4992 +34,2,4992 +35,2,4864 +36,2,4864 +37,2,4864 +38,2,4864 +39,2,4864 +40,2,4992 +41,2,5248 +42,2,5504 +43,2,5760 +44,2,6016 +45,2,6400 +46,2,6400 +47,2,5376 diff --git a/configs/mlp_router/opt-30b/mlp_router_configs_bsize_256.csv b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_256.csv new file mode 100644 index 0000000000000000000000000000000000000000..71cb8a14ee15ab89cb29b122d6878be45f1106cb --- /dev/null +++ b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_256.csv @@ -0,0 +1,49 @@ +layer,batch_size,optimal_k +0,256,3200 +1,256,3584 +2,256,2944 +3,256,2560 +4,256,2432 +5,256,2304 +6,256,2688 +7,256,2944 +8,256,3328 +9,256,3712 +10,256,4352 +11,256,6016 +12,256,8704 +13,256,10624 +14,256,12288 +15,256,14848 +16,256,16384 +17,256,16640 +18,256,18432 +19,256,19456 +20,256,21888 +21,256,24192 +22,256,25984 +23,256,26880 +24,256,27136 +25,256,27392 +26,256,27520 +27,256,27904 +28,256,27776 +29,256,27904 +30,256,28032 +31,256,27904 +32,256,28160 +33,256,28032 +34,256,28032 +35,256,28032 +36,256,28160 +37,256,28032 +38,256,28160 +39,256,28288 +40,256,28288 +41,256,28416 +42,256,28544 +43,256,28544 +44,256,28544 +45,256,28544 +46,256,28544 +47,256,27776 diff --git a/configs/mlp_router/opt-30b/mlp_router_configs_bsize_32.csv b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_32.csv new file mode 100644 index 0000000000000000000000000000000000000000..b9a048748d9d1bb8c8ec48de3f3099f8096e123f --- /dev/null +++ b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_32.csv @@ -0,0 +1,49 @@ +layer,batch_size,optimal_k +0,32,1152 +1,32,1280 +2,32,1152 +3,32,896 +4,32,640 +5,32,640 +6,32,768 +7,32,768 +8,32,896 +9,32,896 +10,32,1152 +11,32,1664 +12,32,2304 +13,32,2816 +14,32,3584 +15,32,4736 +16,32,5632 +17,32,5760 +18,32,6912 +19,32,7168 +20,32,8064 +21,32,9216 +22,32,10496 +23,32,11648 +24,32,13056 +25,32,14464 +26,32,16000 +27,32,16512 +28,32,17024 +29,32,17536 +30,32,18432 +31,32,18816 +32,32,19712 +33,32,20480 +34,32,20608 +35,32,20352 +36,32,20736 +37,32,20736 +38,32,21120 +39,32,21504 +40,32,22016 +41,32,22656 +42,32,23424 +43,32,24064 +44,32,24832 +45,32,25344 +46,32,25216 +47,32,21120 diff --git a/configs/mlp_router/opt-30b/mlp_router_configs_bsize_4.csv b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_4.csv new file mode 100644 index 0000000000000000000000000000000000000000..d97f86b46904675d8b7af4186e9a8a4d17453b71 --- /dev/null +++ b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_4.csv @@ -0,0 +1,49 @@ +layer,batch_size,optimal_k +0,4,640 +1,4,640 +2,4,384 +3,4,384 +4,4,384 +5,4,384 +6,4,384 +7,4,384 +8,4,384 +9,4,384 +10,4,512 +11,4,512 +12,4,640 +13,4,640 +14,4,896 +15,4,1024 +16,4,1408 +17,4,1408 +18,4,1664 +19,4,1536 +20,4,1920 +21,4,2304 +22,4,3200 +23,4,3968 +24,4,4352 +25,4,4736 +26,4,5248 +27,4,5504 +28,4,5760 +29,4,5888 +30,4,6400 +31,4,6528 +32,4,6912 +33,4,7168 +34,4,7168 +35,4,6912 +36,4,6912 +37,4,6912 +38,4,7040 +39,4,7168 +40,4,7296 +41,4,7680 +42,4,8064 +43,4,8448 +44,4,9088 +45,4,9472 +46,4,9472 +47,4,7552 diff --git a/configs/mlp_router/opt-30b/mlp_router_configs_bsize_48.csv b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_48.csv new file mode 100644 index 0000000000000000000000000000000000000000..04283b6e4692624a2e45d298c7c8100696e9c424 --- /dev/null +++ b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_48.csv @@ -0,0 +1,49 @@ +layer,batch_size,optimal_k +0,48,1280 +1,48,1536 +2,48,1280 +3,48,1152 +4,48,1024 +5,48,1024 +6,48,896 +7,48,1024 +8,48,1152 +9,48,1152 +10,48,1664 +11,48,2048 +12,48,2944 +13,48,3840 +14,48,4736 +15,48,6144 +16,48,7168 +17,48,7424 +18,48,8576 +19,48,8832 +20,48,9984 +21,48,11520 +22,48,13184 +23,48,14464 +24,48,16128 +25,48,17664 +26,48,19200 +27,48,19840 +28,48,20224 +29,48,20608 +30,48,21504 +31,48,21888 +32,48,22784 +33,48,23424 +34,48,23680 +35,48,23424 +36,48,23808 +37,48,23808 +38,48,24192 +39,48,24448 +40,48,24960 +41,48,25472 +42,48,26240 +43,48,26752 +44,48,27008 +45,48,27264 +46,48,27008 +47,48,23936 diff --git a/configs/mlp_router/opt-30b/mlp_router_configs_bsize_64.csv b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_64.csv new file mode 100644 index 0000000000000000000000000000000000000000..9bafb2153518496ac4d328f032d1d0e247b5c9a0 --- /dev/null +++ b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_64.csv @@ -0,0 +1,49 @@ +layer,batch_size,optimal_k +0,64,1536 +1,64,1792 +2,64,1408 +3,64,1280 +4,64,1152 +5,64,1152 +6,64,1280 +7,64,1152 +8,64,1280 +9,64,1408 +10,64,1920 +11,64,2432 +12,64,3712 +13,64,4736 +14,64,5760 +15,64,7296 +16,64,8576 +17,64,8704 +18,64,9856 +19,64,10368 +20,64,11776 +21,64,13440 +22,64,15360 +23,64,16768 +24,64,18432 +25,64,19840 +26,64,21376 +27,64,22016 +28,64,22400 +29,64,22784 +30,64,23552 +31,64,23936 +32,64,24704 +33,64,25344 +34,64,25472 +35,64,25344 +36,64,25728 +37,64,25728 +38,64,25984 +39,64,26368 +40,64,26752 +41,64,26880 +42,64,27264 +43,64,27264 +44,64,27520 +45,64,27520 +46,64,27392 +47,64,25728 diff --git a/configs/mlp_router/opt-30b/mlp_router_configs_bsize_8.csv b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_8.csv new file mode 100644 index 0000000000000000000000000000000000000000..96ac68ab2a1cf2eeeb8a6ad65acebeb7ca1c408d --- /dev/null +++ b/configs/mlp_router/opt-30b/mlp_router_configs_bsize_8.csv @@ -0,0 +1,49 @@ +layer,batch_size,optimal_k +0,8,640 +1,8,768 +2,8,768 +3,8,384 +4,8,384 +5,8,384 +6,8,512 +7,8,512 +8,8,512 +9,8,512 +10,8,640 +11,8,640 +12,8,768 +13,8,1152 +14,8,1280 +15,8,1792 +16,8,2176 +17,8,2304 +18,8,2688 +19,8,2688 +20,8,3200 +21,8,4224 +22,8,4992 +23,8,5376 +24,8,6144 +25,8,6784 +26,8,7680 +27,8,7936 +28,8,8320 +29,8,8576 +30,8,9216 +31,8,9472 +32,8,10112 +33,8,10496 +34,8,10496 +35,8,10240 +36,8,10368 +37,8,10368 +38,8,10496 +39,8,10752 +40,8,11136 +41,8,11520 +42,8,12160 +43,8,12800 +44,8,13568 +45,8,14080 +46,8,14080 +47,8,11136 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_1.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_1.csv new file mode 100644 index 0000000000000000000000000000000000000000..0e2161182ce668362b7310c68e03e80037e87ca7 --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_1.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,1,896 +1,1,512 +2,1,512 +3,1,512 +4,1,512 +5,1,512 +6,1,512 +7,1,512 +8,1,512 +9,1,512 +10,1,512 +11,1,512 +12,1,640 +13,1,896 +14,1,1024 +15,1,1152 +16,1,1536 +17,1,1792 +18,1,2048 +19,1,2304 +20,1,2304 +21,1,2432 +22,1,2304 +23,1,2304 +24,1,2304 +25,1,2304 +26,1,2432 +27,1,2688 +28,1,2944 +29,1,3200 +30,1,3712 +31,1,3200 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_128.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_128.csv new file mode 100644 index 0000000000000000000000000000000000000000..64070229d2653aea727e50048effc66a3cf0db82 --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_128.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,128,7424 +1,128,2176 +2,128,2304 +3,128,2176 +4,128,2304 +5,128,3072 +6,128,4864 +7,128,6016 +8,128,7168 +9,128,7936 +10,128,9472 +11,128,10880 +12,128,12032 +13,128,13568 +14,128,14080 +15,128,14592 +16,128,14976 +17,128,15104 +18,128,15360 +19,128,15488 +20,128,15488 +21,128,15616 +22,128,15616 +23,128,15744 +24,128,15744 +25,128,15872 +26,128,15872 +27,128,16000 +28,128,16000 +29,128,16128 +30,128,16128 +31,128,15616 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_16.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_16.csv new file mode 100644 index 0000000000000000000000000000000000000000..2ec4eebd768552fcc546215c29ab86d3b498a914 --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_16.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,16,3584 +1,16,640 +2,16,640 +3,16,640 +4,16,640 +5,16,896 +6,16,1408 +7,16,1664 +8,16,2176 +9,16,2560 +10,16,3200 +11,16,3712 +12,16,4480 +13,16,5760 +14,16,6144 +15,16,6528 +16,16,7552 +17,16,8192 +18,16,8576 +19,16,9088 +20,16,9344 +21,16,9600 +22,16,9728 +23,16,9856 +24,16,10112 +25,16,10368 +26,16,10880 +27,16,11392 +28,16,12160 +29,16,12800 +30,16,13184 +31,16,11008 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_1_{nominmax}.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_1_{nominmax}.csv new file mode 100644 index 0000000000000000000000000000000000000000..04410ce77932050c86cfc59d5dd671f84cc8c3d9 --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_1_{nominmax}.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,1,896 +1,1,256 +2,1,256 +3,1,256 +4,1,256 +5,1,256 +6,1,256 +7,1,256 +8,1,384 +9,1,512 +10,1,512 +11,1,512 +12,1,640 +13,1,896 +14,1,1024 +15,1,1152 +16,1,1536 +17,1,1792 +18,1,2048 +19,1,2304 +20,1,2304 +21,1,2432 +22,1,2304 +23,1,2304 +24,1,2304 +25,1,2304 +26,1,2432 +27,1,2688 +28,1,2944 +29,1,3200 +30,1,3712 +31,1,3200 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_2.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_2.csv new file mode 100644 index 0000000000000000000000000000000000000000..ef301ca3fdc4b1dde49cc6cb1e3fe1f31753b0c4 --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_2.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,2,2560 +1,2,384 +2,2,384 +3,2,384 +4,2,384 +5,2,384 +6,2,384 +7,2,512 +8,2,512 +9,2,640 +10,2,640 +11,2,896 +12,2,1024 +13,2,1536 +14,2,1920 +15,2,2176 +16,2,3328 +17,2,3584 +18,2,3712 +19,2,3840 +20,2,3968 +21,2,3968 +22,2,3968 +23,2,3968 +24,2,3968 +25,2,3968 +26,2,4096 +27,2,4224 +28,2,4480 +29,2,4736 +30,2,5120 +31,2,4352 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_256.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_256.csv new file mode 100644 index 0000000000000000000000000000000000000000..dc3a54cc7815c95514377560a28c1867090a9a42 --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_256.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,256,9344 +1,256,3200 +2,256,3584 +3,256,3584 +4,256,3712 +5,256,4736 +6,256,6784 +7,256,8192 +8,256,9216 +9,256,10240 +10,256,11776 +11,256,13184 +12,256,14208 +13,256,15104 +14,256,15360 +15,256,15616 +16,256,15744 +17,256,15872 +18,256,15872 +19,256,15872 +20,256,16000 +21,256,16000 +22,256,16128 +23,256,16000 +24,256,16000 +25,256,16128 +26,256,16256 +27,256,16256 +28,256,16256 +29,256,16384 +30,256,16384 +31,256,16000 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_32.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_32.csv new file mode 100644 index 0000000000000000000000000000000000000000..0feca1eef2995a3d21a70983a2a566ebd404dbb3 --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_32.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,32,4480 +1,32,896 +2,32,896 +3,32,896 +4,32,1024 +5,32,1408 +6,32,1920 +7,32,2560 +8,32,3328 +9,32,3840 +10,32,4736 +11,32,5632 +12,32,6784 +13,32,7808 +14,32,8448 +15,32,8832 +16,32,10112 +17,32,10752 +18,32,11136 +19,32,11776 +20,32,12032 +21,32,12288 +22,32,12672 +23,32,12928 +24,32,13184 +25,32,13440 +26,32,14080 +27,32,14464 +28,32,14720 +29,32,15104 +30,32,15104 +31,32,13824 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_4.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_4.csv new file mode 100644 index 0000000000000000000000000000000000000000..096ff4e2844d0c2b1708a37b97783c70504d8b22 --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_4.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,4,2816 +1,4,384 +2,4,384 +3,4,384 +4,4,384 +5,4,512 +6,4,512 +7,4,640 +8,4,1024 +9,4,1024 +10,4,1152 +11,4,1408 +12,4,1792 +13,4,2304 +14,4,2944 +15,4,3200 +16,4,4224 +17,4,4480 +18,4,4736 +19,4,4992 +20,4,5120 +21,4,5248 +22,4,5248 +23,4,5248 +24,4,5248 +25,4,5376 +26,4,5632 +27,4,5888 +28,4,6272 +29,4,6656 +30,4,7040 +31,4,5888 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_48.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_48.csv new file mode 100644 index 0000000000000000000000000000000000000000..e3b842315d90f6749bf2ec9571adf60961f01916 --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_48.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,48,5248 +1,48,1024 +2,48,1152 +3,48,1152 +4,48,1280 +5,48,1792 +6,48,2688 +7,48,3456 +8,48,4352 +9,48,4864 +10,48,6016 +11,48,7040 +12,48,8192 +13,48,9344 +14,48,10112 +15,48,10496 +16,48,11776 +17,48,12416 +18,48,12800 +19,48,13312 +20,48,13696 +21,48,13952 +22,48,14208 +23,48,14464 +24,48,14592 +25,48,14592 +26,48,14976 +27,48,15104 +28,48,15488 +29,48,15616 +30,48,15616 +31,48,14464 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_64.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_64.csv new file mode 100644 index 0000000000000000000000000000000000000000..4d653398729b09aaf53d0b6838d8ce98350de9ca --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_64.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,64,5760 +1,64,1280 +2,64,1408 +3,64,1408 +4,64,1408 +5,64,2048 +6,64,3200 +7,64,3968 +8,64,4992 +9,64,5632 +10,64,7040 +11,64,8064 +12,64,9344 +13,64,10624 +14,64,11392 +15,64,11648 +16,64,13056 +17,64,13568 +18,64,13952 +19,64,14336 +20,64,14720 +21,64,14720 +22,64,14848 +23,64,14848 +24,64,14976 +25,64,15104 +26,64,15360 +27,64,15616 +28,64,15744 +29,64,15872 +30,64,15744 +31,64,14976 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_8.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_8.csv new file mode 100644 index 0000000000000000000000000000000000000000..4c198c40af929111a5c31b5531124dc8a432e021 --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_8.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,8,3072 +1,8,512 +2,8,512 +3,8,512 +4,8,512 +5,8,640 +6,8,768 +7,8,1152 +8,8,1408 +9,8,1664 +10,8,1920 +11,8,2304 +12,8,2816 +13,8,3840 +14,8,4480 +15,8,4864 +16,8,5632 +17,8,6016 +18,8,6400 +19,8,6784 +20,8,6912 +21,8,7040 +22,8,7168 +23,8,7296 +24,8,7296 +25,8,7424 +26,8,7936 +27,8,8320 +28,8,8960 +29,8,9472 +30,8,9984 +31,8,8192 diff --git a/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_8_{0.99th_nolimit}.csv b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_8_{0.99th_nolimit}.csv new file mode 100644 index 0000000000000000000000000000000000000000..0b2bfd6d7b97548974bd4587d68b23cb8c3de6ad --- /dev/null +++ b/configs/mlp_router/opt-6.7b/mlp_router_configs_bsize_8_{0.99th_nolimit}.csv @@ -0,0 +1,33 @@ +layer,batch_size,optimal_k +0,8,7680 +1,8,384 +2,8,384 +3,8,384 +4,8,384 +5,8,512 +6,8,768 +7,8,1024 +8,8,1408 +9,8,1536 +10,8,1920 +11,8,2176 +12,8,2816 +13,8,3712 +14,8,4480 +15,8,4992 +16,8,8192 +17,8,9600 +18,8,10624 +19,8,11648 +20,8,11648 +21,8,11904 +22,8,11648 +23,8,11648 +24,8,11392 +25,8,11648 +26,8,12160 +27,8,12288 +28,8,12800 +29,8,13312 +30,8,13696 +31,8,12288 diff --git a/configs/mlp_router/opt-66b/mlp_router_configs_bsize_1.csv b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_1.csv new file mode 100644 index 0000000000000000000000000000000000000000..46a95f4d9ad2bf7a698d26579a3eb763ef9d06c6 --- /dev/null +++ b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_1.csv @@ -0,0 +1,65 @@ +layer,batch_size,optimal_k +0,1,384 +1,1,256 +2,1,256 +3,1,256 +4,1,256 +5,1,256 +6,1,256 +7,1,256 +8,1,256 +9,1,256 +10,1,256 +11,1,256 +12,1,256 +13,1,256 +14,1,256 +15,1,256 +16,1,256 +17,1,256 +18,1,256 +19,1,384 +20,1,512 +21,1,512 +22,1,512 +23,1,512 +24,1,512 +25,1,512 +26,1,640 +27,1,896 +28,1,1024 +29,1,1152 +30,1,1536 +31,1,1792 +32,1,2176 +33,1,2560 +34,1,2688 +35,1,2944 +36,1,3584 +37,1,3712 +38,1,4096 +39,1,4608 +40,1,4608 +41,1,4736 +42,1,4864 +43,1,5248 +44,1,5376 +45,1,5760 +46,1,5376 +47,1,5504 +48,1,5248 +49,1,4992 +50,1,4864 +51,1,4864 +52,1,5248 +53,1,4864 +54,1,4992 +55,1,5248 +56,1,5376 +57,1,5632 +58,1,6144 +59,1,6528 +60,1,7168 +61,1,7680 +62,1,8064 +63,1,6912 diff --git a/configs/mlp_router/opt-66b/mlp_router_configs_bsize_128.csv b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_128.csv new file mode 100644 index 0000000000000000000000000000000000000000..3ca00458e57d7aad2f878fe2ce4c3431feb043b8 --- /dev/null +++ b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_128.csv @@ -0,0 +1,65 @@ +layer,batch_size,optimal_k +0,128,2176 +1,128,2432 +2,128,2304 +3,128,1920 +4,128,1536 +5,128,1408 +6,128,1280 +7,128,1280 +8,128,1024 +9,128,1408 +10,128,1280 +11,128,1280 +12,128,1408 +13,128,1664 +14,128,2304 +15,128,3072 +16,128,4992 +17,128,5120 +18,128,7168 +19,128,9344 +20,128,11136 +21,128,10624 +22,128,11264 +23,128,11776 +24,128,12928 +25,128,15744 +26,128,18048 +27,128,20864 +28,128,23040 +29,128,24576 +30,128,25856 +31,128,27520 +32,128,28800 +33,128,30464 +34,128,31104 +35,128,32000 +36,128,32640 +37,128,33280 +38,128,33664 +39,128,34304 +40,128,34304 +41,128,34560 +42,128,34688 +43,128,35200 +44,128,35328 +45,128,35456 +46,128,35328 +47,128,35200 +48,128,35200 +49,128,35328 +50,128,35328 +51,128,35328 +52,128,35328 +53,128,35584 +54,128,35456 +55,128,35456 +56,128,35712 +57,128,35712 +58,128,35840 +59,128,35712 +60,128,35968 +61,128,35840 +62,128,35840 +63,128,34688 diff --git a/configs/mlp_router/opt-66b/mlp_router_configs_bsize_16.csv b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_16.csv new file mode 100644 index 0000000000000000000000000000000000000000..367f85ebad32a2e77958e9720d71fad89da0c2aa --- /dev/null +++ b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_16.csv @@ -0,0 +1,65 @@ +layer,batch_size,optimal_k +0,16,896 +1,16,768 +2,16,640 +3,16,640 +4,16,512 +5,16,512 +6,16,512 +7,16,512 +8,16,384 +9,16,512 +10,16,512 +11,16,512 +12,16,512 +13,16,512 +14,16,640 +15,16,768 +16,16,1280 +17,16,1408 +18,16,1920 +19,16,2560 +20,16,3200 +21,16,2816 +22,16,3200 +23,16,3072 +24,16,3584 +25,16,4608 +26,16,5760 +27,16,6528 +28,16,7168 +29,16,7680 +30,16,8320 +31,16,9216 +32,16,9984 +33,16,11264 +34,16,11520 +35,16,12288 +36,16,13312 +37,16,14080 +38,16,14592 +39,16,15360 +40,16,15232 +41,16,15616 +42,16,15744 +43,16,16512 +44,16,16768 +45,16,17152 +46,16,17024 +47,16,17152 +48,16,17024 +49,16,17024 +50,16,17024 +51,16,17152 +52,16,17536 +53,16,17664 +54,16,17920 +55,16,18560 +56,16,19072 +57,16,19584 +58,16,20608 +59,16,21120 +60,16,21760 +61,16,22144 +62,16,21888 +63,16,18560 diff --git a/configs/mlp_router/opt-66b/mlp_router_configs_bsize_2.csv b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_2.csv new file mode 100644 index 0000000000000000000000000000000000000000..4c951f310cc9f73a08f13b0c474c04e7fae66626 --- /dev/null +++ b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_2.csv @@ -0,0 +1,65 @@ +layer,batch_size,optimal_k +0,2,384 +1,2,512 +2,2,384 +3,2,384 +4,2,384 +5,2,384 +6,2,256 +7,2,384 +8,2,256 +9,2,384 +10,2,256 +11,2,256 +12,2,384 +13,2,384 +14,2,384 +15,2,384 +16,2,384 +17,2,384 +18,2,512 +19,2,512 +20,2,896 +21,2,896 +22,2,896 +23,2,640 +24,2,896 +25,2,1024 +26,2,1024 +27,2,1408 +28,2,1664 +29,2,2048 +30,2,2560 +31,2,3072 +32,2,3456 +33,2,3840 +34,2,3840 +35,2,3968 +36,2,4224 +37,2,4480 +38,2,4608 +39,2,4864 +40,2,4736 +41,2,4864 +42,2,4864 +43,2,4992 +44,2,5120 +45,2,5248 +46,2,5120 +47,2,5120 +48,2,5120 +49,2,4992 +50,2,4992 +51,2,4992 +52,2,5120 +53,2,5120 +54,2,5120 +55,2,5376 +56,2,5504 +57,2,5632 +58,2,5888 +59,2,6144 +60,2,6400 +61,2,6656 +62,2,6784 +63,2,5888 diff --git a/configs/mlp_router/opt-66b/mlp_router_configs_bsize_256.csv b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_256.csv new file mode 100644 index 0000000000000000000000000000000000000000..9c103cf03185389815c2edc73794adeacdb89d1c --- /dev/null +++ b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_256.csv @@ -0,0 +1,65 @@ +layer,batch_size,optimal_k +0,256,3072 +1,256,3584 +2,256,3328 +3,256,2688 +4,256,2304 +5,256,2304 +6,256,1792 +7,256,1792 +8,256,1792 +9,256,2048 +10,256,2048 +11,256,1920 +12,256,2432 +13,256,2816 +14,256,3456 +15,256,4736 +16,256,7168 +17,256,7808 +18,256,10624 +19,256,13056 +20,256,15232 +21,256,14720 +22,256,15488 +23,256,16128 +24,256,17920 +25,256,21376 +26,256,24576 +27,256,27648 +28,256,29952 +29,256,31232 +30,256,32384 +31,256,33664 +32,256,34432 +33,256,34944 +34,256,35456 +35,256,35584 +36,256,35584 +37,256,35712 +38,256,35712 +39,256,35968 +40,256,35968 +41,256,35840 +42,256,35968 +43,256,35968 +44,256,35968 +45,256,36096 +46,256,35968 +47,256,36096 +48,256,36096 +49,256,36096 +50,256,36096 +51,256,36224 +52,256,36224 +53,256,36096 +54,256,36096 +55,256,36224 +56,256,36224 +57,256,36352 +58,256,36480 +59,256,36480 +60,256,36480 +61,256,36480 +62,256,36480 +63,256,35712 diff --git a/configs/mlp_router/opt-66b/mlp_router_configs_bsize_32.csv b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_32.csv new file mode 100644 index 0000000000000000000000000000000000000000..9d913592da95228e2863da1367a7da55077daa1d --- /dev/null +++ b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_32.csv @@ -0,0 +1,65 @@ +layer,batch_size,optimal_k +0,32,1152 +1,32,1024 +2,32,1152 +3,32,1024 +4,32,640 +5,32,640 +6,32,512 +7,32,512 +8,32,512 +9,32,640 +10,32,640 +11,32,640 +12,32,640 +13,32,768 +14,32,896 +15,32,1408 +16,32,2048 +17,32,2048 +18,32,2944 +19,32,3840 +20,32,5120 +21,32,4736 +22,32,5120 +23,32,4992 +24,32,5888 +25,32,7552 +26,32,8448 +27,32,9600 +28,32,10752 +29,32,11648 +30,32,12672 +31,32,13952 +32,32,15104 +33,32,16768 +34,32,17152 +35,32,18176 +36,32,19456 +37,32,20352 +38,32,20992 +39,32,22016 +40,32,21888 +41,32,22144 +42,32,22528 +43,32,23424 +44,32,23680 +45,32,24192 +46,32,24064 +47,32,24320 +48,32,24192 +49,32,24192 +50,32,24192 +51,32,24448 +52,32,24704 +53,32,24960 +54,32,25216 +55,32,25856 +56,32,26496 +57,32,27008 +58,32,28032 +59,32,28544 +60,32,29056 +61,32,29440 +62,32,29184 +63,32,25344 diff --git a/configs/mlp_router/opt-66b/mlp_router_configs_bsize_4.csv b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_4.csv new file mode 100644 index 0000000000000000000000000000000000000000..705e2f4a5a106cea7ba2b87c85e0775294164da5 --- /dev/null +++ b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_4.csv @@ -0,0 +1,65 @@ +layer,batch_size,optimal_k +0,4,512 +1,4,512 +2,4,512 +3,4,384 +4,4,384 +5,4,384 +6,4,384 +7,4,384 +8,4,384 +9,4,384 +10,4,384 +11,4,384 +12,4,384 +13,4,384 +14,4,384 +15,4,384 +16,4,512 +17,4,512 +18,4,640 +19,4,1024 +20,4,1152 +21,4,1152 +22,4,1152 +23,4,1152 +24,4,1280 +25,4,1664 +26,4,1920 +27,4,2432 +28,4,3072 +29,4,3456 +30,4,4096 +31,4,4352 +32,4,4608 +33,4,5120 +34,4,5248 +35,4,5504 +36,4,6016 +37,4,6272 +38,4,6528 +39,4,6912 +40,4,6784 +41,4,6912 +42,4,7040 +43,4,7296 +44,4,7424 +45,4,7552 +46,4,7424 +47,4,7552 +48,4,7424 +49,4,7296 +50,4,7296 +51,4,7296 +52,4,7424 +53,4,7552 +54,4,7680 +55,4,7936 +56,4,8192 +57,4,8448 +58,4,8960 +59,4,9344 +60,4,9728 +61,4,9984 +62,4,9984 +63,4,8576 diff --git a/configs/mlp_router/opt-66b/mlp_router_configs_bsize_48.csv b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_48.csv new file mode 100644 index 0000000000000000000000000000000000000000..413ea56cfd2322ed03ded616deef203bbafcb6cc --- /dev/null +++ b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_48.csv @@ -0,0 +1,65 @@ +layer,batch_size,optimal_k +0,48,1280 +1,48,1408 +2,48,1280 +3,48,1152 +4,48,1024 +5,48,1024 +6,48,640 +7,48,640 +8,48,640 +9,48,768 +10,48,768 +11,48,768 +12,48,896 +13,48,896 +14,48,1152 +15,48,1664 +16,48,2560 +17,48,2816 +18,48,3968 +19,48,5248 +20,48,6528 +21,48,6016 +22,48,6528 +23,48,6784 +24,48,7552 +25,48,9344 +26,48,10624 +27,48,12288 +28,48,13696 +29,48,14848 +30,48,16000 +31,48,17536 +32,48,18816 +33,48,20736 +34,48,21248 +35,48,22272 +36,48,23552 +37,48,24448 +38,48,25088 +39,48,26112 +40,48,25984 +41,48,26240 +42,48,26624 +43,48,27520 +44,48,27776 +45,48,28288 +46,48,28160 +47,48,28416 +48,48,28288 +49,48,28416 +50,48,28416 +51,48,28672 +52,48,28928 +53,48,29056 +54,48,29312 +55,48,29952 +56,48,30464 +57,48,30976 +58,48,31872 +59,48,32256 +60,48,32640 +61,48,33024 +62,48,32640 +63,48,29184 diff --git a/configs/mlp_router/opt-66b/mlp_router_configs_bsize_64.csv b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_64.csv new file mode 100644 index 0000000000000000000000000000000000000000..2c913fd338c51645918fdc535330b48bef9a5438 --- /dev/null +++ b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_64.csv @@ -0,0 +1,65 @@ +layer,batch_size,optimal_k +0,64,1664 +1,64,1664 +2,64,1408 +3,64,1280 +4,64,1152 +5,64,1024 +6,64,1024 +7,64,768 +8,64,768 +9,64,768 +10,64,896 +11,64,896 +12,64,1024 +13,64,1024 +14,64,1536 +15,64,1920 +16,64,3200 +17,64,3200 +18,64,4608 +19,64,6144 +20,64,7680 +21,64,7296 +22,64,7808 +23,64,8064 +24,64,9088 +25,64,11008 +26,64,12544 +27,64,14464 +28,64,16128 +29,64,17536 +30,64,18816 +31,64,20352 +32,64,21760 +33,64,23680 +34,64,24192 +35,64,25216 +36,64,26496 +37,64,27264 +38,64,27904 +39,64,28928 +40,64,28800 +41,64,29056 +42,64,29440 +43,64,30336 +44,64,30464 +45,64,30976 +46,64,30848 +47,64,31104 +48,64,30976 +49,64,30976 +50,64,31104 +51,64,31360 +52,64,31616 +53,64,31744 +54,64,31872 +55,64,32384 +56,64,32896 +57,64,33280 +58,64,34048 +59,64,34304 +60,64,34688 +61,64,34944 +62,64,34688 +63,64,31616 diff --git a/configs/mlp_router/opt-66b/mlp_router_configs_bsize_8.csv b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_8.csv new file mode 100644 index 0000000000000000000000000000000000000000..17950d869835b5ad1e6614cf32f0ea900c74447c --- /dev/null +++ b/configs/mlp_router/opt-66b/mlp_router_configs_bsize_8.csv @@ -0,0 +1,65 @@ +layer,batch_size,optimal_k +0,8,512 +1,8,640 +2,8,512 +3,8,512 +4,8,384 +5,8,384 +6,8,384 +7,8,384 +8,8,384 +9,8,384 +10,8,384 +11,8,384 +12,8,384 +13,8,384 +14,8,512 +15,8,512 +16,8,768 +17,8,768 +18,8,1152 +19,8,1664 +20,8,1920 +21,8,1792 +22,8,1920 +23,8,1920 +24,8,2048 +25,8,2816 +26,8,3328 +27,8,4096 +28,8,4864 +29,8,5248 +30,8,5632 +31,8,6144 +32,8,6656 +33,8,7424 +34,8,7680 +35,8,8064 +36,8,8832 +37,8,9344 +38,8,9728 +39,8,10368 +40,8,10240 +41,8,10496 +42,8,10496 +43,8,11008 +44,8,11264 +45,8,11520 +46,8,11392 +47,8,11392 +48,8,11264 +49,8,11264 +50,8,11264 +51,8,11264 +52,8,11520 +53,8,11648 +54,8,11904 +55,8,12288 +56,8,12672 +57,8,13056 +58,8,13824 +59,8,14336 +60,8,14848 +61,8,15104 +62,8,15104 +63,8,12672 diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..4e0a01fba2e1aee73f41a51ea50b971a7bfd8834 --- /dev/null +++ b/environment.yml @@ -0,0 +1,479 @@ +name: nvresearch +channels: + - pytorch + - nvidia + - nvidia/label/cuda-12.4.0 + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - anyio=4.2.0=py39h06a4308_0 + - argon2-cffi=21.3.0=pyhd3eb1b0_0 + - argon2-cffi-bindings=21.2.0=py39h7f8727e_0 + - asttokens=2.0.5=pyhd3eb1b0_0 + - async-lru=2.0.4=py39h06a4308_0 + - attrs=23.1.0=py39h06a4308_0 + - babel=2.11.0=py39h06a4308_0 + - backcall=0.2.0=pyhd3eb1b0_0 + - beautifulsoup4=4.12.3=py39h06a4308_0 + - blas=1.0=mkl + - bleach=4.1.0=pyhd3eb1b0_0 + - bottleneck=1.3.7=py39ha9d4c09_0 + - brotli-python=1.0.9=py39h6a678d5_8 + - bzip2=1.0.8=h5eee18b_6 + - ca-certificates=2025.4.26=hbd8a1cb_0 + - certifi=2025.4.26=pyhd8ed1ab_0 + - cffi=1.16.0=py39h5eee18b_1 + - charset-normalizer=3.3.2=pyhd3eb1b0_0 + - comm=0.2.1=py39h06a4308_0 + - cuda=12.4.0=0 + - cuda-cccl=12.4.99=0 + - cuda-command-line-tools=12.4.0=0 + - cuda-compiler=12.4.0=0 + - cuda-cudart=12.4.99=0 + - cuda-cudart-dev=12.4.99=0 + - cuda-cudart-static=12.4.99=0 + - cuda-cuobjdump=12.4.99=0 + - cuda-cupti=12.4.99=0 + - cuda-cupti-static=12.4.99=0 + - cuda-cuxxfilt=12.4.99=0 + - cuda-demo-suite=12.4.99=0 + - cuda-documentation=12.4.99=0 + - cuda-driver-dev=12.4.99=0 + - cuda-gdb=12.4.99=0 + - cuda-libraries=12.4.0=0 + - cuda-libraries-dev=12.4.0=0 + - cuda-libraries-static=12.4.0=0 + - cuda-nsight=12.4.99=0 + - cuda-nsight-compute=12.4.0=0 + - cuda-nvcc=12.4.99=0 + - cuda-nvdisasm=12.4.99=0 + - cuda-nvml-dev=12.4.99=0 + - cuda-nvprof=12.4.99=0 + - cuda-nvprune=12.4.99=0 + - cuda-nvrtc=12.4.99=0 + - cuda-nvrtc-dev=12.4.99=0 + - cuda-nvrtc-static=12.4.99=0 + - cuda-nvtx=12.4.99=0 + - cuda-nvvp=12.4.99=0 + - cuda-opencl=12.4.99=0 + - cuda-opencl-dev=12.4.99=0 + - cuda-profiler-api=12.4.99=0 + - cuda-runtime=12.4.0=0 + - cuda-sanitizer-api=12.4.99=0 + - cuda-toolkit=12.4.0=0 + - cuda-tools=12.4.0=0 + - cuda-visual-tools=12.4.0=0 + - cyrus-sasl=2.1.28=h52b45da_1 + - dbus=1.13.18=hb2f20db_0 + - debugpy=1.6.7=py39h6a678d5_0 + - decorator=5.1.1=pyhd3eb1b0_0 + - defusedxml=0.7.1=pyhd3eb1b0_0 + - exceptiongroup=1.2.0=py39h06a4308_0 + - executing=0.8.3=pyhd3eb1b0_0 + - expat=2.6.2=h6a678d5_0 + - ffmpeg=4.3=hf484d3e_0 + - fontconfig=2.14.1=h55d465d_3 + - freetype=2.12.1=h4a9f257_0 + - gds-tools=1.9.0.20=0 + - glib=2.78.4=h6a678d5_0 + - glib-tools=2.78.4=h6a678d5_0 + - gmp=6.2.1=h295c915_3 + - gmpy2=2.1.2=py39heeb90bb_0 + - gnutls=3.6.15=he1e5248_0 + - gst-plugins-base=1.14.1=h6a678d5_1 + - gstreamer=1.14.1=h5eee18b_1 + - icu=73.1=h6a678d5_0 + - idna=3.7=py39h06a4308_0 + - importlib-metadata=7.0.1=py39h06a4308_0 + - importlib_metadata=7.0.1=hd3eb1b0_0 + - intel-openmp=2023.1.0=hdb19cb5_46306 + - ipykernel=6.28.0=py39h06a4308_0 + - ipython=8.15.0=py39h06a4308_0 + - ipywidgets=8.1.2=py39h06a4308_0 + - jedi=0.19.1=py39h06a4308_0 + - jpeg=9e=h5eee18b_3 + - json5=0.9.6=pyhd3eb1b0_0 + - jsonschema-specifications=2023.7.1=py39h06a4308_0 + - jupyter=1.0.0=py39h06a4308_9 + - jupyter-lsp=2.2.0=py39h06a4308_0 + - jupyter_client=8.6.0=py39h06a4308_0 + - jupyter_console=6.6.3=py39h06a4308_0 + - jupyter_core=5.7.2=py39h06a4308_0 + - jupyter_events=0.10.0=py39h06a4308_0 + - jupyter_server=2.14.1=py39h06a4308_0 + - jupyter_server_terminals=0.4.4=py39h06a4308_1 + - jupyterlab=4.0.11=py39h06a4308_0 + - jupyterlab_pygments=0.1.2=py_0 + - jupyterlab_server=2.25.1=py39h06a4308_0 + - jupyterlab_widgets=3.0.10=py39h06a4308_0 + - krb5=1.20.1=h143b758_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libclang=14.0.6=default_hc6dbbc7_1 + - libclang13=14.0.6=default_he11475f_1 + - libcublas=12.4.2.65=0 + - libcublas-dev=12.4.2.65=0 + - libcublas-static=12.4.2.65=0 + - libcufft=11.2.0.44=0 + - libcufft-dev=11.2.0.44=0 + - libcufft-static=11.2.0.44=0 + - libcufile=1.9.0.20=0 + - libcufile-dev=1.9.0.20=0 + - libcufile-static=1.9.0.20=0 + - libcups=2.4.2=h2d74bed_1 + - libcurand=10.3.5.119=0 + - libcurand-dev=10.3.5.119=0 + - libcurand-static=10.3.5.119=0 + - libcusolver=11.6.0.99=0 + - libcusolver-dev=11.6.0.99=0 + - libcusolver-static=11.6.0.99=0 + - libcusparse=12.3.0.142=0 + - libcusparse-dev=12.3.0.142=0 + - libcusparse-static=12.3.0.142=0 + - libdeflate=1.17=h5eee18b_1 + - libedit=3.1.20230828=h5eee18b_0 + - libffi=3.4.4=h6a678d5_1 + - libgcc=12.4.0=h767d61c_2 + - libgcc-ng=12.4.0=h69a702a_2 + - libglib=2.78.4=hdc74915_0 + - libgomp=12.4.0=h767d61c_2 + - libiconv=1.16=h5eee18b_3 + - libidn2=2.3.4=h5eee18b_0 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - libllvm14=14.0.6=hdb19cb5_3 + - libnpp=12.2.5.2=0 + - libnpp-dev=12.2.5.2=0 + - libnpp-static=12.2.5.2=0 + - libnvfatbin=12.4.99=0 + - libnvfatbin-dev=12.4.99=0 + - libnvjitlink=12.4.99=0 + - libnvjitlink-dev=12.4.99=0 + - libnvjpeg=12.3.1.89=0 + - libnvjpeg-dev=12.3.1.89=0 + - libnvjpeg-static=12.3.1.89=0 + - libpng=1.6.39=h5eee18b_0 + - libpq=12.17=hdbd6064_0 + - libsodium=1.0.18=h7b6447c_0 + - libstdcxx=12.4.0=h8f9b012_2 + - libstdcxx-ng=12.4.0=h4852527_2 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.41.5=h5eee18b_0 + - libwebp-base=1.3.2=h5eee18b_0 + - libxcb=1.15=h7f8727e_0 + - libxkbcommon=1.0.1=h097e994_2 + - libxml2=2.13.1=hfdd30dd_2 + - llvm-openmp=14.0.6=h9e868ea_0 + - lz4-c=1.9.4=h6a678d5_1 + - markupsafe=2.1.3=py39h5eee18b_0 + - matplotlib-inline=0.1.6=py39h06a4308_0 + - mistune=2.0.4=py39h06a4308_0 + - mkl=2023.1.0=h213fc3f_46344 + - mkl-service=2.4.0=py39h5eee18b_1 + - mkl_fft=1.3.8=py39h5eee18b_0 + - mkl_random=1.2.4=py39hdb19cb5_0 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - mpmath=1.3.0=py39h06a4308_0 + - mysql=5.7.24=h721c034_2 + - nbclient=0.8.0=py39h06a4308_0 + - nbconvert=7.10.0=py39h06a4308_0 + - nbformat=5.9.2=py39h06a4308_0 + - ncurses=6.5=he02047a_1 + - nest-asyncio=1.6.0=py39h06a4308_0 + - nettle=3.7.3=hbbd107a_1 + - networkx=3.2.1=py39h06a4308_0 + - notebook=7.0.8=py39h06a4308_2 + - notebook-shim=0.2.3=py39h06a4308_0 + - nsight-compute=2024.1.0.13=0 + - numexpr=2.8.7=py39h85018f9_0 + - numpy=1.26.4=py39h5f9d8c6_0 + - numpy-base=1.26.4=py39hb5e798b_0 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.5.2=he7f1fd0_0 + - openssl=3.3.1=h4bc722e_2 + - overrides=7.4.0=py39h06a4308_0 + - packaging=24.1=py39h06a4308_0 + - pandas=1.5.3=py39h417a72b_0 + - pandocfilters=1.5.0=pyhd3eb1b0_0 + - parso=0.8.3=pyhd3eb1b0_0 + - pcre2=10.42=hebb0a14_1 + - pexpect=4.8.0=pyhd3eb1b0_3 + - pickleshare=0.7.5=pyhd3eb1b0_1003 + - pillow=10.4.0=py39h5eee18b_0 + - pip=24.0=py39h06a4308_0 + - platformdirs=3.10.0=py39h06a4308_0 + - ply=3.11=py39h06a4308_0 + - prompt-toolkit=3.0.43=py39h06a4308_0 + - prompt_toolkit=3.0.43=hd3eb1b0_0 + - psutil=5.9.0=py39h5eee18b_0 + - ptyprocess=0.7.0=pyhd3eb1b0_2 + - pure_eval=0.2.2=pyhd3eb1b0_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pygments=2.15.1=py39h06a4308_1 + - pyqt=5.15.10=py39h6a678d5_0 + - pyqt5-sip=12.13.0=py39h5eee18b_0 + - pysocks=1.7.1=py39h06a4308_0 + - python=3.9.19=h955ad1f_1 + - python-dateutil=2.9.0post0=py39h06a4308_2 + - python-fastjsonschema=2.16.2=py39h06a4308_0 + - python-json-logger=2.0.7=py39h06a4308_0 + - python-tzdata=2023.3=pyhd3eb1b0_0 + - pytorch-cuda=12.4=hc786d27_6 + - pytorch-mutex=1.0=cuda + - pytz=2024.1=py39h06a4308_0 + - pyyaml=6.0.1=py39h5eee18b_0 + - pyzmq=25.1.2=py39h6a678d5_0 + - qt-main=5.15.2=h53bd1ea_10 + - qtconsole=5.5.1=py39h06a4308_0 + - qtpy=2.4.1=py39h06a4308_0 + - readline=8.2=h5eee18b_0 + - referencing=0.30.2=py39h06a4308_0 + - requests=2.32.3=py39h06a4308_0 + - rfc3339-validator=0.1.4=py39h06a4308_0 + - rfc3986-validator=0.1.1=py39h06a4308_0 + - rpds-py=0.10.6=py39hb02cf49_0 + - send2trash=1.8.2=py39h06a4308_0 + - setuptools=72.1.0=py39h06a4308_0 + - sip=6.7.12=py39h6a678d5_0 + - six=1.16.0=pyhd3eb1b0_1 + - sniffio=1.3.0=py39h06a4308_0 + - soupsieve=2.5=py39h06a4308_0 + - sqlite=3.45.3=h5eee18b_0 + - stack_data=0.2.0=pyhd3eb1b0_0 + - tbb=2021.8.0=hdb19cb5_0 + - terminado=0.17.1=py39h06a4308_0 + - tinycss2=1.2.1=py39h06a4308_0 + - tk=8.6.14=h39e8969_0 + - tomli=2.0.1=py39h06a4308_0 + - tornado=6.4.1=py39h5eee18b_0 + - traitlets=5.14.3=py39h06a4308_0 + - tzdata=2024a=h04d1e81_0 + - urllib3=2.2.2=py39h06a4308_0 + - wcwidth=0.2.5=pyhd3eb1b0_0 + - webencodings=0.5.1=py39h06a4308_1 + - websocket-client=1.8.0=py39h06a4308_0 + - wheel=0.43.0=py39h06a4308_0 + - widgetsnbextension=4.0.10=py39h06a4308_0 + - xz=5.4.6=h5eee18b_1 + - yaml=0.2.5=h7b6447c_0 + - zeromq=4.3.5=h6a678d5_0 + - zipp=3.17.0=py39h06a4308_0 + - zlib=1.2.13=h5eee18b_1 + - zstd=1.5.5=hc292b87_2 + - pip: + - absl-py==2.1.0 + - accelerate==0.33.0 + - adjusttext==1.3.0 + - aiohappyeyeballs==2.3.4 + - aiohttp==3.10.1 + - aiosignal==1.3.1 + - airportsdata==20250224 + - altair==5.4.1 + - ampselgemm==0.0.0 + - annotated-types==0.7.0 + - anykeystore==0.2 + - apex==0.1 + - astor==0.8.1 + - async-timeout==4.0.3 + - balanced-kmeans==0.1.0 + - basicgemm==0.0.0 + - bitsandbytes==0.43.3 + - blake3==1.0.4 + - blinker==1.9.0 + - cachetools==5.5.0 + - chardet==5.2.0 + - click==8.1.7 + - cloudpickle==3.1.1 + - colorama==0.4.6 + - compressed-tensors==0.9.2 + - contexttimer==0.3.3 + - contourpy==1.2.1 + - cryptacular==1.6.2 + - cuda-python==12.6.0 + - cupy-cuda12x==13.4.1 + - cycler==0.12.1 + - dataproperty==1.1.0 + - datasets==2.20.0 + - depyf==0.18.0 + - dill==0.3.8 + - diskcache==5.6.3 + - distro==1.9.0 + - dnspython==2.7.0 + - einops==0.8.0 + - email-validator==2.2.0 + - evaluate==0.4.3 + - fastapi==0.115.12 + - fastapi-cli==0.0.7 + - fastrlock==0.8.3 + - filelock==3.18.0 + - flash-attn==2.6.3 + - fonttools==4.53.1 + - frozenlist==1.4.1 + - fsspec==2024.5.0 + - fused-dense-lib==0.0.0 + - fuzzywuzzy==0.18.0 + - gguf==0.10.0 + - gitdb==4.0.11 + - gitpython==3.1.43 + - greenlet==3.0.3 + - grouped-gemm-v1==0.0.0 + - h11==0.14.0 + - httpcore==1.0.7 + - httptools==0.6.4 + - httpx==0.28.1 + - huggingface-hub==0.28.1 + - hupper==1.12.1 + - immutabledict==4.2.1 + - importlib-resources==6.4.0 + - iniconfig==2.0.0 + - interegular==0.3.3 + - jieba==0.42.1 + - jinja2==3.1.6 + - jiter==0.9.0 + - joblib==1.4.2 + - jsonlines==4.0.0 + - jsonschema==4.23.0 + - k-means-constrained==0.7.3 + - kiwisolver==1.4.5 + - lark==1.2.2 + - lightning-utilities==0.11.9 + - llguidance==0.7.11 + - llvmlite==0.43.0 + - lm-eval==0.4.8 + - lm-format-enforcer==0.10.11 + - lxml==5.3.0 + - markdown-it-py==3.0.0 + - matplotlib==3.9.0 + - mbstrdecoder==1.1.4 + - mdurl==0.1.2 + - mistral-common==1.5.4 + - more-itertools==10.6.0 + - msgpack==1.1.0 + - msgspec==0.19.0 + - multidict==6.0.5 + - multiprocess==0.70.16 + - narwhals==1.13.5 + - ninja==1.11.1.4 + - nltk==3.9.1 + - numba==0.60.0 + - nvidia-cublas-cu12==12.4.5.8 + - nvidia-cuda-cupti-cu12==12.4.127 + - nvidia-cuda-nvrtc-cu12==12.4.127 + - nvidia-cuda-runtime-cu12==12.4.127 + - nvidia-cudnn-cu12==9.1.0.70 + - nvidia-cufft-cu12==11.2.1.3 + - nvidia-curand-cu12==10.3.5.147 + - nvidia-cusolver-cu12==11.6.1.9 + - nvidia-cusparse-cu12==12.3.1.170 + - nvidia-cusparselt-cu12==0.6.2 + - nvidia-cutlass==3.5.0.0 + - nvidia-nccl-cu12==2.21.5 + - nvidia-nvjitlink-cu12==12.4.127 + - nvidia-nvtx-cu12==12.4.127 + - oauthlib==3.2.2 + - openai==1.70.0 + - opencv-python-headless==4.11.0.86 + - ortools==9.11.4210 + - outlines==0.1.11 + - outlines-core==0.1.26 + - partial-json-parser==0.2.1.1.post5 + - pastedeploy==3.1.0 + - pathvalidate==3.2.3 + - pbkdf2==1.3 + - peft==0.14.0 + - plaster==1.1.2 + - plaster-pastedeploy==1.0.1 + - pluggy==1.5.0 + - portalocker==3.1.1 + - prometheus-client==0.21.1 + - prometheus-fastapi-instrumentator==7.1.0 + - protobuf==5.26.1 + - py-cpuinfo==9.0.0 + - pyarrow==17.0.0 + - pyarrow-hotfix==0.6 + - pybind11==2.13.6 + - pycountry==24.6.1 + - pydantic==2.11.1 + - pydantic-core==2.33.0 + - pydeck==0.9.1 + - pydot==3.0.1 + - pyparsing==3.1.2 + - pyramid==1.5.8 + - pyramid-mailer==0.15.1 + - pytablewriter==1.2.1 + - pytest==8.3.3 + - python-dotenv==1.1.0 + - python-multipart==0.0.20 + - python3-openid==3.2.0 + - ray==2.44.1 + - regex==2024.7.24 + - repoze-lru==0.7 + - repoze-sendmail==4.4.1 + - requests-oauthlib==2.0.0 + - rich==13.9.4 + - rich-toolkit==0.14.1 + - rouge==1.0.1 + - rouge-score==0.1.2 + - sacrebleu==2.5.1 + - safetensors==0.4.4 + - scikit-learn==1.5.2 + - scipy==1.13.1 + - seaborn==0.13.2 + - sentencepiece==0.2.0 + - shellingham==1.5.4 + - smmap==5.0.1 + - sqlalchemy==2.0.32 + - sqlitedict==2.1.0 + - starlette==0.46.1 + - streamlit==1.40.1 + - sympy==1.13.1 + - tabledata==1.3.4 + - tabulate==0.9.0 + - tcolorpy==0.1.7 + - tenacity==9.0.0 + - threadpoolctl==3.5.0 + - tiktoken==0.9.0 + - tokenizers==0.21.1 + - toml==0.10.2 + - torch==2.6.0 + - torchaudio==2.6.0 + - torchmetrics==1.6.0 + - torchvision==0.21.0 + - tqdm==4.66.5 + - tqdm-multiprocess==0.0.11 + - transaction==4.0 + - transformers==4.50.3 + - translationstring==1.4 + - treelib==1.7.0 + - triton==3.2.0 + - triton-nightly==3.0.0.post20240716052845 + - typepy==1.3.4 + - typer==0.15.2 + - typing-extensions==4.13.0 + - typing-inspection==0.4.0 + - uvicorn==0.34.0 + - uvloop==0.21.0 + - velruse==1.1.1 + - venusian==3.1.0 + - vllm==0.8.2 + - watchdog==6.0.0 + - watchfiles==1.0.4 + - webob==1.8.7 + - websockets==15.0.1 + - word2number==1.1 + - wtforms==3.1.2 + - wtforms-recaptcha==0.3.2 + - xformers==0.0.29.post2 + - xgrammar==0.1.16 + - xxhash==3.4.1 + - yarl==1.9.4 + - zope-deprecation==5.0 + - zope-interface==7.0.1 + - zope-sqlalchemy==3.1 + - zstandard==0.23.0 diff --git a/hf_models/llama/__init__.py b/hf_models/llama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..41febe935576e984146d660895148fd04af43736 --- /dev/null +++ b/hf_models/llama/__init__.py @@ -0,0 +1,116 @@ +# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from transformers.utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_sentencepiece_available, + is_tokenizers_available, + is_torch_available, +) + + +_import_structure = { + "configuration_llama": ["LlamaConfig"], +} + +try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_llama"] = ["LlamaTokenizer"] + +try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_llama"] = [ + "LlamaForCausalLM", + "LlamaModel", + "LlamaPreTrainedModel", + "LlamaForSequenceClassification", + "LlamaForQuestionAnswering", + "LlamaForTokenClassification", + ] + +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_llama"] = ["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"] + + +if TYPE_CHECKING: + from .configuration_llama import LlamaConfig + + try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_llama import LlamaTokenizer + + try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_llama_fast import LlamaTokenizerFast + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_llama import ( + LlamaForCausalLM, + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaModel, + LlamaPreTrainedModel, + ) + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_llama import FlaxLlamaForCausalLM, FlaxLlamaModel, FlaxLlamaPreTrainedModel + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/hf_models/llama/configuration_llama.py b/hf_models/llama/configuration_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..2b4f4db62d84d34e2fb8541ababbfe17280e664e --- /dev/null +++ b/hf_models/llama/configuration_llama.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LLaMA model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class LlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/hf_models/llama/convert_llama_weights_to_hf.py b/hf_models/llama/convert_llama_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..a75ce5245eee60a170e24e480841ee8e268d98db --- /dev/null +++ b/hf_models/llama/convert_llama_weights_to_hf.py @@ -0,0 +1,479 @@ +# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gc +import json +import os +import shutil +import warnings +from typing import List + +import torch + +from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast +from transformers.convert_slow_tokenizer import TikTokenConverter + + +try: + from transformers import LlamaTokenizerFast +except ImportError as e: + warnings.warn(e) + warnings.warn( + "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" + ) + LlamaTokenizerFast = None + +""" +Sample usage: + +``` +python src/transformers/models/llama/convert_llama_weights_to_hf.py \ + --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import LlamaForCausalLM, LlamaTokenizer + +model = LlamaForCausalLM.from_pretrained("/output/path") +tokenizer = LlamaTokenizer.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). + +If you want you tokenizer to add a bos automatically you should update the tokenizer._tokenizers.post_processor: + +```py +from tokenizers import processors +bos = "<|begin_of_text|>" +tokenizer._tokenizers.post_processor = processors.Sequence( + [ + processors.ByteLevel(trim_offsets=False), + processors.TemplateProcessing( + single=f"{bos}:0 $A:0", + pair=f"{bos}:0 $A:0 {bos}:1 $B:1", + special_tokens=[ + (bos, tokenizer.encode(bos)), + ], + ), + ] +) +``` +""" + +NUM_SHARDS = { + "7B": 1, + "8B": 1, + "8Bf": 1, + "7Bf": 1, + "13B": 2, + "13Bf": 2, + "34B": 4, + "30B": 4, + "65B": 8, + "70B": 8, + "70Bf": 8, + "405B": 8, + "405B-MP16": 16, +} + +CONTEXT_LENGTH_FOR_VERSION = {"3.1": 131072, "3": 8192, "2": 4096, "1": 2048} + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model( + model_path, + input_base_path, + model_size=None, + safe_serialization=True, + llama_version="1", + vocab_size=None, + num_shards=None, + instruct=False, +): + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + params = read_json(os.path.join(input_base_path, "params.json")) + num_shards = NUM_SHARDS[model_size] if num_shards is None else num_shards + params = params.get("model", params) + n_layers = params["n_layers"] + n_heads = params["n_heads"] + n_heads_per_shard = n_heads // num_shards + dim = params["dim"] + dims_per_head = dim // n_heads + base = params.get("rope_theta", 10000.0) + inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + if base > 10000.0 and float(llama_version) < 3: + max_position_embeddings = 16384 + else: + max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[llama_version] + + if params.get("n_kv_heads", None) is not None: + num_key_value_heads = params["n_kv_heads"] # for GQA / MQA + num_key_value_heads_per_shard = num_key_value_heads // num_shards + key_value_dim = dims_per_head * num_key_value_heads + else: # compatibility with other checkpoints + num_key_value_heads = n_heads + num_key_value_heads_per_shard = n_heads_per_shard + key_value_dim = dim + + # permute for sliced rotary + def permute(w, n_heads, dim1=dim, dim2=dim): + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + # Load weights + if num_shards == 1: + # Not sharded + # (The sharded implementation would also work, but this is simpler.) + loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu") + else: + # Sharded + checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")]) + print("Loading in order:", checkpoint_list) + loaded = [torch.load(os.path.join(input_base_path, file), map_location="cpu") for file in checkpoint_list] + param_count = 0 + index_dict = {"weight_map": {}} + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + if num_shards == 1: + # Unsharded + state_dict = { + f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( + loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads + ), + f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( + loaded[f"layers.{layer_i}.attention.wk.weight"], + n_heads=num_key_value_heads, + dim1=key_value_dim, + ), + f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], + f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], + f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], + f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], + f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"], + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"], + } + else: + # Sharded + # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share + # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is + # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. + + state_dict = { + f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ + f"layers.{layer_i}.attention_norm.weight" + ].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ + f"layers.{layer_i}.ffn_norm.weight" + ].clone(), + } + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( + torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) + for i in range(len(loaded)) + ], + dim=0, + ).reshape(dim, dim), + n_heads=n_heads, + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( + torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( + num_key_value_heads_per_shard, dims_per_head, dim + ) + for i in range(len(loaded)) + ], + dim=0, + ).reshape(key_value_dim, dim), + num_key_value_heads, + key_value_dim, + dim, + ) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( + num_key_value_heads_per_shard, dims_per_head, dim + ) + for i in range(len(loaded)) + ], + dim=0, + ).reshape(key_value_dim, dim) + + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(len(loaded))], dim=1 + ) + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(len(loaded))], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(len(loaded))], dim=1 + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(len(loaded))], dim=0 + ) + + state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + if num_shards == 1: + # Unsharded + state_dict = { + "model.embed_tokens.weight": loaded["tok_embeddings.weight"], + "model.norm.weight": loaded["norm.weight"], + "lm_head.weight": loaded["output.weight"], + } + else: + concat_dim = 0 if llama_version in ["3", "3.1"] else 1 + state_dict = { + "model.norm.weight": loaded[0]["norm.weight"], + "model.embed_tokens.weight": torch.cat( + [loaded[i]["tok_embeddings.weight"] for i in range(len(loaded))], dim=concat_dim + ), + "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(len(loaded))], dim=0), + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + # Write configs + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 + multiple_of = params["multiple_of"] if "multiple_of" in params else 256 + + if llama_version in ["3", "3.1"]: + bos_token_id = 128000 + + if instruct: + eos_token_id = [128001, 128008, 128009] + else: + eos_token_id = 128001 + else: + bos_token_id = 1 + eos_token_id = 2 + + config = LlamaConfig( + hidden_size=dim, + intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), + num_attention_heads=params["n_heads"], + num_hidden_layers=params["n_layers"], + rms_norm_eps=params["norm_eps"], + num_key_value_heads=num_key_value_heads, + vocab_size=vocab_size, + rope_theta=base, + max_position_embeddings=max_position_embeddings, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + ) + config.save_pretrained(tmp_model_path) + + if instruct: + generation_config = GenerationConfig( + do_sample=True, + temperature=0.6, + top_p=0.9, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + ) + generation_config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + gc.collect() + + print("Loading the checkpoint in a Llama model.") + model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) + # Avoid saving this as part of the config. + del model.config._name_or_path + model.config.torch_dtype = torch.float16 + print("Saving in the Transformers format.") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + shutil.rmtree(tmp_model_path, ignore_errors=True) + + +class Llama3Converter(TikTokenConverter): + def __init__(self, vocab_file, special_tokens=None, instruct=False, model_max_length=None, **kwargs): + super().__init__(vocab_file, **kwargs) + tokenizer = self.converted() + chat_template = ( + "{% set loop_messages = messages %}" + "{% for message in loop_messages %}" + "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}" + "{% if loop.index0 == 0 %}" + "{% set content = bos_token + content %}" + "{% endif %}" + "{{ content }}" + "{% endfor %}" + "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" + ) + tokenizer.add_special_tokens(special_tokens) + + self.tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + bos_token="<|begin_of_text|>", + eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>", + chat_template=chat_template if instruct else None, + model_input_names=["input_ids", "attention_mask"], + model_max_length=model_max_length, + ) + + +def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version="2", special_tokens=None, instruct=False): + tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast + if llama_version in ["3", "3.1"]: + tokenizer = Llama3Converter( + input_tokenizer_path, special_tokens, instruct, model_max_length=CONTEXT_LENGTH_FOR_VERSION[llama_version] + ).tokenizer + else: + tokenizer = tokenizer_class(input_tokenizer_path) + print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") + tokenizer.save_pretrained(tokenizer_path) + return tokenizer + + +DEFAULT_LLAMA_SPECIAL_TOKENS = { + "3": [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)], + "3.1": [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|reserved_special_token_2|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + ] + + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)], +} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + help="Location of LLaMA weights, which contains tokenizer.model and model folders", + ) + parser.add_argument( + "--model_size", + default=None, + help="'f' Deprecated in favor of `num_shards`: models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama", + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`." + ) + # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used. + parser.add_argument( + "--llama_version", + choices=["1", "2", "3", "3.1"], + default="1", + type=str, + help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size", + ) + parser.add_argument( + "--num_shards", + default=None, + type=int, + help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth", + ) + parser.add_argument( + "--special_tokens", + default=None, + type=List[str], + help="The list of special tokens that should be added to the model.", + ) + parser.add_argument( + "--instruct", + default=False, + type=bool, + help="Whether the model is an instruct model or not. Will affect special tokens for llama 3.1.", + ) + args = parser.parse_args() + if args.model_size is None and args.num_shards is None: + raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`") + if args.special_tokens is None: + # no special tokens by default + args.special_tokens = DEFAULT_LLAMA_SPECIAL_TOKENS.get(str(args.llama_version), []) + + spm_path = os.path.join(args.input_dir, "tokenizer.model") + vocab_size = len( + write_tokenizer( + args.output_dir, + spm_path, + llama_version=args.llama_version, + special_tokens=args.special_tokens, + instruct=args.instruct, + ) + ) + if args.model_size != "tokenizer_only": + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + model_size=args.model_size, + safe_serialization=args.safe_serialization, + llama_version=args.llama_version, + vocab_size=vocab_size, + num_shards=args.num_shards, + instruct=args.instruct, + ) + + +if __name__ == "__main__": + main() diff --git a/hf_models/llama/modeling_flax_llama.py b/hf_models/llama/modeling_flax_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..90e60a15785b85a8d4a6dcce440bcea671cbaf05 --- /dev/null +++ b/hf_models/llama/modeling_flax_llama.py @@ -0,0 +1,750 @@ +# coding=utf-8 +# Copyright 2023 Meta AI, EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flax LLaMA model.""" + +from functools import partial +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput +from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" +_CHECKPOINT_FOR_DOC = "afmck/testing-llama-tiny" +_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2" + +LLAMA_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`LlamaConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or + `jax.numpy.bfloat16`. + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def create_sinusoidal_positions(num_pos, dim): + inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim)) + freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") + + emb = np.concatenate((freqs, freqs), axis=-1) + out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) + return jnp.array(out[:, :, :num_pos]) + + +def rotate_half(tensor): + """Rotates half the hidden dims of the input.""" + rotate_half_tensor = jnp.concatenate( + (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1 + ) + return rotate_half_tensor + + +def apply_rotary_pos_emb(tensor, sin_pos, cos_pos): + return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos) + + +class FlaxLlamaRMSNorm(nn.Module): + config: LlamaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.epsilon = self.config.rms_norm_eps + self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size) + + def __call__(self, hidden_states): + variance = jnp.asarray(hidden_states, dtype=jnp.float32) + variance = jnp.power(variance, 2) + variance = variance.mean(-1, keepdims=True) + # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt` + hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon) + + return self.weight * jnp.asarray(hidden_states, dtype=self.dtype) + + +class FlaxLlamaRotaryEmbedding(nn.Module): + config: LlamaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + head_dim = self.config.hidden_size // self.config.num_attention_heads + self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim) + + def __call__(self, key, query, position_ids): + sincos = self.sincos[position_ids] + sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1) + + key = apply_rotary_pos_emb(key, sin_pos, cos_pos) + query = apply_rotary_pos_emb(query, sin_pos, cos_pos) + + key = jnp.asarray(key, dtype=self.dtype) + query = jnp.asarray(query, dtype=self.dtype) + + return key, query + + +class FlaxLlamaAttention(nn.Module): + config: LlamaConfig + dtype: jnp.dtype = jnp.float32 + causal: bool = True + is_cross_attention: bool = False + + def setup(self): + config = self.config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.attention_softmax_in_fp32 = self.dtype is not jnp.float32 + + dense = partial( + nn.Dense, + use_bias=config.attention_bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + self.q_proj = dense(self.num_heads * self.head_dim) + self.k_proj = dense(self.num_key_value_heads * self.head_dim) + self.v_proj = dense(self.num_key_value_heads * self.head_dim) + self.o_proj = dense(self.embed_dim) + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + + self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") + self.rotary_emb = FlaxLlamaRotaryEmbedding(config, dtype=self.dtype) + + def _split_heads(self, hidden_states, num_heads): + return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + position_ids, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + ): + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_heads) + key = self._split_heads(key, self.num_key_value_heads) + value = self._split_heads(value, self.num_key_value_heads) + + key, query = self.rotary_emb(key, query, position_ids) + + query_length, key_length = query.shape[1], key.shape[1] + + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + + batch_size = hidden_states.shape[0] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + + dropout_rng = None + if not deterministic and self.config.attention_dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.has_variable("cache", "cached_key") or init_cache: + key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) + + key = jnp.repeat(key, self.num_key_value_groups, axis=2) + value = jnp.repeat(value, self.num_key_value_groups, axis=2) + + # transform boolean mask into float mask + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + + # usual dot product attention + attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype + attn_weights = dot_product_attention_weights( + query, + key, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_dropout, + deterministic=deterministic, + dtype=attention_dtype, + ) + + if self.attention_softmax_in_fp32: + attn_weights = attn_weights.astype(self.dtype) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) + attn_output = self._merge_heads(attn_output) + attn_output = self.o_proj(attn_output) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxLlamaMLP(nn.Module): + config: LlamaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + embed_dim = self.config.hidden_size + inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim + + kernel_init = jax.nn.initializers.normal(self.config.initializer_range) + self.act = ACT2FN[self.config.hidden_act] + + self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) + self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) + self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) + + def __call__(self, hidden_states): + up_proj_states = self.up_proj(hidden_states) + gate_states = self.act(self.gate_proj(hidden_states)) + + hidden_states = self.down_proj(up_proj_states * gate_states) + return hidden_states + + +class FlaxLlamaDecoderLayer(nn.Module): + config: LlamaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.input_layernorm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype) + self.self_attn = FlaxLlamaAttention(self.config, dtype=self.dtype) + self.post_attention_layernorm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype) + self.mlp = FlaxLlamaMLP(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + ): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + outputs = self.self_attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + ) + # residual connection + attn_output = outputs[0] + hidden_states = residual + attn_output + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + hidden_states + + return (hidden_states,) + outputs[1:] + + +# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Llama, GPT_NEO->LLAMA, transformer->model +class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LlamaConfig + base_model_prefix = "model" + module_class: nn.Module = None + + def __init__( + self, + config: LlamaConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + params: dict = None, + past_key_values: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, sequence_length = input_ids.shape + + if position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") + + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxLlamaAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + False, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxLlamaLayerCollection(nn.Module): + config: LlamaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.blocks = [ + FlaxLlamaDecoderLayer(self.config, dtype=self.dtype, name=str(i)) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for block in self.blocks: + if output_hidden_states: + all_hidden_states += (hidden_states,) + layer_outputs = block( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + # this contains possible `None` values - `FlaxLlamaModule` will filter them out + outputs = (hidden_states, all_hidden_states, all_attentions) + + return outputs + + +class FlaxLlamaModule(nn.Module): + config: LlamaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.hidden_size = self.config.hidden_size + embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) + self.embed_tokens = nn.Embed( + self.config.vocab_size, + self.hidden_size, + embedding_init=embedding_init, + dtype=self.dtype, + ) + self.layers = FlaxLlamaLayerCollection(self.config, dtype=self.dtype) + self.norm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + deterministic=True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + input_embeds = self.embed_tokens(input_ids.astype("i4")) + + outputs = self.layers( + input_embeds, + position_ids=position_ids, + attention_mask=attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = outputs[1] + (hidden_states,) + outputs = (hidden_states, all_hidden_states) + outputs[2:] + else: + outputs = (hidden_states,) + outputs[1:] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs[1], + attentions=outputs[-1], + ) + + +@add_start_docstrings( + "The bare Llama Model transformer outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class FlaxLlamaModel(FlaxLlamaPreTrainedModel): + module_class = FlaxLlamaModule + + +append_call_sample_docstring( + FlaxLlamaModel, + _CHECKPOINT_FOR_DOC, + FlaxBaseModelOutput, + _CONFIG_FOR_DOC, + real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, +) + + +class FlaxLlamaForCausalLMModule(nn.Module): + config: LlamaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.model = FlaxLlamaModule(self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + outputs = self.model( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + +@add_start_docstrings( + """ + The Llama Model transformer with a language modeling head (linear layer) on top. + """, + LLAMA_START_DOCSTRING, +) +# Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Llama +class FlaxLlamaForCausalLM(FlaxLlamaPreTrainedModel): + module_class = FlaxLlamaForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since Llama uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxLlamaForCausalLM, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutput, + _CONFIG_FOR_DOC, + real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, +) diff --git a/hf_models/llama/modeling_llama.py b/hf_models/llama/modeling_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5adc071510fc72ff05682a421b6acf85d491ee --- /dev/null +++ b/hf_models/llama/modeling_llama.py @@ -0,0 +1,1728 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + # BaseModelOutputWithPast, + # CausalLMOutputWithPast, + # QuestionAnsweringModelOutput, + # SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) + +from .modeling_outputs import ExtendedBaseModelOutputWithPast as BaseModelOutputWithPast +from .modeling_outputs import ExtendedCausalLMOutputWithPast as CausalLMOutputWithPast +from .modeling_outputs import ExtendedQuestionAnsweringModelOutput as QuestionAnsweringModelOutput +from .modeling_outputs import ExtendedSequenceClassifierOutputWithPast as SequenceClassifierOutputWithPast + + +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." + ) + kwargs["rope_type"] = "linear" + super().__init__(*args, **kwargs) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." + ) + kwargs["rope_type"] = "dynamic" + super().__init__(*args, **kwargs) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + # self.act_fn = ACT2FN[config.hidden_act] + if config.hidden_act in ACT2FN: + self.act_fn = ACT2FN[config.hidden_act] + elif config.hidden_act == "shiftrelu": + def shifted_relu(x): + return torch.nn.functional.relu(x - config.hidden_act_param) + self.act_fn = shifted_relu + elif config.hidden_act == "fatrelu": + def fat_relu(x): + new_x = torch.zeros_like(x) + mask = torch.ge(x, config.hidden_act_param) + new_x[mask] = x[mask] + return new_x + self.act_fn = fat_relu + + def forward(self, x, output_mlp_activation=False): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + activation = self.act_fn(self.gate_proj(x)) + product = activation * self.up_proj(x) + down_proj = self.down_proj(product) + if output_mlp_activation: + return down_proj, activation + return down_proj, None + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + output_attn_output: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + if output_attn_output: + attn_head_output = attn_output + else: + attn_head_output = None + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value, attn_head_output + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + output_mlp_activation: Optional[bool] = False, + output_attn_output: Optional[bool] = False, + output_router_inputs: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + if output_router_inputs: + router_input = hidden_states.detach().cpu() # this moves the tensor to cpu + + # Self Attention + hidden_states, self_attn_weights, present_key_value, attn_head_output = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + output_attn_output=output_attn_output, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + if output_attn_output: + attn_head_output = attn_head_output.detach().cpu() # use this for collecting attn router training data + # attn_head_output = hidden_states.detach().cpu() # use this for attn importance data + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, mlp_activation = self.mlp(hidden_states, output_mlp_activation=output_mlp_activation) + + if output_mlp_activation: + mlp_activation = mlp_activation.detach().cpu() + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_mlp_activation: + outputs += (mlp_activation,) + + if output_attn_output: + outputs += (attn_head_output,) + + if output_router_inputs: + outputs += (router_input,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_mlp_activations = () if output_mlp_activation else None + all_attn_outputs = () if output_attn_output else None + all_router_inputs = () if output_router_inputs else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + index = 1 + if output_attentions: + all_self_attns += (layer_outputs[index],) + index += 1 + if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + next_decoder_cache = layer_outputs[index] + index += 1 + + if output_mlp_activation: + # calculate the correct index depending on the use_cache, output_attentions + all_mlp_activations += (layer_outputs[index],) + index += 1 + + if output_attn_output: + all_attn_outputs += (layer_outputs[index],) + index += 1 + + if output_router_inputs: + # assert index < len(layer_outputs), "router_input index out of range" + all_router_inputs += (layer_outputs[index],) + index += 1 + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_mlp_activations] if v is not None) + + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=next_cache, + # hidden_states=all_hidden_states, + # attentions=all_self_attns, + # ) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mlp_activations=all_mlp_activations, + attn_outputs = all_attn_outputs, + router_inputs = all_router_inputs, + ) + # if output_mlp_activation: + # output.mlp_activations = all_mlp_activations + + # if output_attn_output: + # output.attn_outputs = all_attn_outputs + + return output + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithPast( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + + result = CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + mlp_activations=outputs.mlp_activations, + attn_outputs=outputs.attn_outputs, + router_inputs=outputs.router_inputs, + ) + # if output_mlp_activation: + # result.mlp_activations = outputs.mlp_activations + + return result + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if inputs_embeds is not None: + batch_size, sequence_length = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/hf_models/llama/modeling_llama_attn_imp.py b/hf_models/llama/modeling_llama_attn_imp.py new file mode 100644 index 0000000000000000000000000000000000000000..58fda3a6c20746bd15ef09e809ad7b147d2483ff --- /dev/null +++ b/hf_models/llama/modeling_llama_attn_imp.py @@ -0,0 +1,1728 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + # BaseModelOutputWithPast, + # CausalLMOutputWithPast, + # QuestionAnsweringModelOutput, + # SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) + +from .modeling_outputs import ExtendedBaseModelOutputWithPast as BaseModelOutputWithPast +from .modeling_outputs import ExtendedCausalLMOutputWithPast as CausalLMOutputWithPast +from .modeling_outputs import ExtendedQuestionAnsweringModelOutput as QuestionAnsweringModelOutput +from .modeling_outputs import ExtendedSequenceClassifierOutputWithPast as SequenceClassifierOutputWithPast + + +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." + ) + kwargs["rope_type"] = "linear" + super().__init__(*args, **kwargs) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." + ) + kwargs["rope_type"] = "dynamic" + super().__init__(*args, **kwargs) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + # self.act_fn = ACT2FN[config.hidden_act] + if config.hidden_act in ACT2FN: + self.act_fn = ACT2FN[config.hidden_act] + elif config.hidden_act == "shiftrelu": + def shifted_relu(x): + return torch.nn.functional.relu(x - config.hidden_act_param) + self.act_fn = shifted_relu + elif config.hidden_act == "fatrelu": + def fat_relu(x): + new_x = torch.zeros_like(x) + mask = torch.ge(x, config.hidden_act_param) + new_x[mask] = x[mask] + return new_x + self.act_fn = fat_relu + + def forward(self, x, output_mlp_activation=False): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + activation = self.act_fn(self.gate_proj(x)) + product = activation * self.up_proj(x) + down_proj = self.down_proj(product) + if output_mlp_activation: + return down_proj, activation + return down_proj, None + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + output_attn_output: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + if output_attn_output: + attn_head_output = attn_output + else: + attn_head_output = None + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value, attn_head_output + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + output_mlp_activation: Optional[bool] = False, + output_attn_output: Optional[bool] = False, + output_router_inputs: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + if output_router_inputs: + router_input = hidden_states.detach().cpu() # this moves the tensor to cpu + + # Self Attention + hidden_states, self_attn_weights, present_key_value, attn_head_output = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + output_attn_output=output_attn_output, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + if output_attn_output: + # attn_head_output = attn_head_output.detach().cpu() # use this for collecting attn router training data + attn_head_output = hidden_states.detach().cpu() # use this for attn importance data + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, mlp_activation = self.mlp(hidden_states, output_mlp_activation=output_mlp_activation) + + if output_mlp_activation: + mlp_activation = mlp_activation.detach().cpu() + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_mlp_activation: + outputs += (mlp_activation,) + + if output_attn_output: + outputs += (attn_head_output,) + + if output_router_inputs: + outputs += (router_input,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_mlp_activations = () if output_mlp_activation else None + all_attn_outputs = () if output_attn_output else None + all_router_inputs = () if output_router_inputs else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + index = 1 + if output_attentions: + all_self_attns += (layer_outputs[index],) + index += 1 + if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + next_decoder_cache = layer_outputs[index] + index += 1 + + if output_mlp_activation: + # calculate the correct index depending on the use_cache, output_attentions + all_mlp_activations += (layer_outputs[index],) + index += 1 + + if output_attn_output: + all_attn_outputs += (layer_outputs[index],) + index += 1 + + if output_router_inputs: + # assert index < len(layer_outputs), "router_input index out of range" + all_router_inputs += (layer_outputs[index],) + index += 1 + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_mlp_activations] if v is not None) + + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=next_cache, + # hidden_states=all_hidden_states, + # attentions=all_self_attns, + # ) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mlp_activations=all_mlp_activations, + attn_outputs = all_attn_outputs, + router_inputs = all_router_inputs, + ) + # if output_mlp_activation: + # output.mlp_activations = all_mlp_activations + + # if output_attn_output: + # output.attn_outputs = all_attn_outputs + + return output + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithPast( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + + result = CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + mlp_activations=outputs.mlp_activations, + attn_outputs=outputs.attn_outputs, + router_inputs=outputs.router_inputs, + ) + # if output_mlp_activation: + # result.mlp_activations = outputs.mlp_activations + + return result + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if inputs_embeds is not None: + batch_size, sequence_length = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/hf_models/llama/modeling_outputs.py b/hf_models/llama/modeling_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..56b954de3966e698ae6c4649327623097512a91d --- /dev/null +++ b/hf_models/llama/modeling_outputs.py @@ -0,0 +1,138 @@ +from dataclasses import dataclass +from typing import Optional, Tuple +import torch + +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) + +@dataclass +class ExtendedBaseModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + router_inputs: Optional[Tuple[torch.FloatTensor, ...]] = None + mlp_activations: Optional[Tuple[torch.FloatTensor, ...]] = None + attn_outputs: Optional[Tuple[torch.FloatTensor, ...]] = None + +@dataclass +class ExtendedCausalLMOutputWithPast(CausalLMOutputWithPast): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + router_inputs: Optional[Tuple[torch.FloatTensor, ...]] = None + mlp_activations: Optional[Tuple[torch.FloatTensor, ...]] = None + attn_outputs: Optional[Tuple[torch.FloatTensor, ...]] = None + +@dataclass +class ExtendedQuestionAnsweringModelOutput(QuestionAnsweringModelOutput): + """ + Base class for outputs of question answering models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + router_inputs: Optional[Tuple[torch.FloatTensor, ...]] = None + mlp_activations: Optional[Tuple[torch.FloatTensor, ...]] = None + attn_outputs: Optional[Tuple[torch.FloatTensor, ...]] = None + +@dataclass +class ExtendedSequenceClassifierOutputWithPast(SequenceClassifierOutputWithPast): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + router_inputs: Optional[Tuple[torch.FloatTensor, ...]] = None + mlp_activations: Optional[Tuple[torch.FloatTensor, ...]] = None + attn_outputs: Optional[Tuple[torch.FloatTensor, ...]] = None \ No newline at end of file diff --git a/hf_models/llama/modeling_sparse_llama_mha_topk.py b/hf_models/llama/modeling_sparse_llama_mha_topk.py new file mode 100644 index 0000000000000000000000000000000000000000..abc9063331b7950f628bc2d09f8ff83c4d2cd65b --- /dev/null +++ b/hf_models/llama/modeling_sparse_llama_mha_topk.py @@ -0,0 +1,1803 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + # BaseModelOutputWithPast, + # CausalLMOutputWithPast, + # QuestionAnsweringModelOutput, + # SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) + +from .modeling_outputs import ExtendedBaseModelOutputWithPast as BaseModelOutputWithPast +from .modeling_outputs import ExtendedCausalLMOutputWithPast as CausalLMOutputWithPast +from .modeling_outputs import ExtendedQuestionAnsweringModelOutput as QuestionAnsweringModelOutput +from .modeling_outputs import ExtendedSequenceClassifierOutputWithPast as SequenceClassifierOutputWithPast + + +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." + ) + kwargs["rope_type"] = "linear" + super().__init__(*args, **kwargs) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." + ) + kwargs["rope_type"] = "dynamic" + super().__init__(*args, **kwargs) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + # self.act_fn = ACT2FN[config.hidden_act] + if config.hidden_act in ACT2FN: + self.act_fn = ACT2FN[config.hidden_act] + elif config.hidden_act == "shiftrelu": + def shifted_relu(x): + return torch.nn.functional.relu(x - config.hidden_act_param) + self.act_fn = shifted_relu + elif config.hidden_act == "fatrelu": + def fat_relu(x): + new_x = torch.zeros_like(x) + mask = torch.ge(x, config.hidden_act_param) + new_x[mask] = x[mask] + return new_x + self.act_fn = fat_relu + + def forward(self, x, output_mlp_activation=False): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + activation = self.act_fn(self.gate_proj(x)) + product = activation * self.up_proj(x) + down_proj = self.down_proj(product) + if output_mlp_activation: + return down_proj, activation + return down_proj, None + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None, + sp_threshold: float = 0.0 # sparisity threshold is used to select top-k heads based on the output norms + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + self.sp_threshold = sp_threshold + self.use_sparse_head_attention = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + output_attn_output: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + is_prefill = past_key_value is None or q_len > 1 # toggle this to enable/disable sparse forward + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + if output_attn_output: + attn_head_output = attn_output + else: + attn_head_output = None + + + gqa_version = 1 # 1: Group Sparsity, zero out the outputs of the non-selected groups + + # **Head-wise Sparsity Implementation Starts Here ** + if self.use_sparse_head_attention and not is_prefill: + use_mha_branch = self.num_key_value_heads == self.num_heads + # use_mha_branch = True + if use_mha_branch: # MHA , Head Sparsity + # Determine the number of heads to keep + k = int(self.sp_threshold * self.num_heads) + # k = max(k, 1) # Ensure at least one head is kept + + norms = attn_output.norm(dim=-1) # Shape: (batch_size, seq_len, num_heads) + + # Step 2: For each token, select the top k heads based on the norms + # Use torch.topk to get the indices of the top k heads + _, topk_indices = norms.topk(k, dim=-1) # Shapes: (batch_size, seq_len, k) + + # Step 3: Mask out all non-selected heads + # Create a mask initialized to False (zeros) + mask = torch.zeros_like(norms, dtype=torch.bool) # Shape: (batch_size, seq_len, num_heads) + + # Use scatter_ to set True at the positions of the top k heads + mask.scatter_(-1, topk_indices, True) # Now mask has True at top k head positions + + # Expand the mask to match the dimensions of attn_output + mask = mask.unsqueeze(-1) # Shape: (batch_size, seq_len, num_heads, 1) + + # Apply the mask to the attention output + # This will zero out the outputs of the non-selected heads + attn_output = attn_output * mask # Element-wise multiplication + + else: # GQA, Group Sparsity + # 1. Calculate the norm of each head + norms = attn_output.norm(dim=-1) # Shape: (batch_size, seq_len, num_heads) + + # 2. Reshape norms into groups and compute average norm for each group + norms_grouped = norms.view(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups) # Shape: (batch_size, seq_len, groups, num_heads_per_group) + #group_avg_norm = norms_grouped.mean(dim=-1) # Shape: (batch_size, seq_len, groups) + + # 3. Compute the max norm per group (resulting shape: [bsz, q_len, self.num_key_value_heads]) + group_avg_norm = norms_grouped.max(dim=-1).values + + # 3. Select the top-k groups based on the average norm of each group + # k = int(self.sp_threshold * self.num_key_value_heads) + k = math.ceil(self.sp_threshold * self.num_key_value_heads) + _, topk_indices = torch.topk(group_avg_norm, k=k, dim=-1) + + if gqa_version == 1: + # version 1 : zero out the outputs of the non-selected groups + + # 4. Build masks for the selected groups + mask_groups = torch.zeros_like(group_avg_norm) + mask_groups.scatter_(dim=-1, index=topk_indices, value=1.0) + + # 5. Expand the mask to the shape of norms + attn_output_grouped = attn_output.view(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups, self.head_dim) + mask_groups = mask_groups.unsqueeze(-1).unsqueeze(-1) # Shape: (batch_size, seq_len, groups, 1, 1) + + # 6. Apply the mask to the attention output + attn_out_masked = attn_output_grouped * mask_groups + attn_output = attn_out_masked.view(bsz, q_len, self.num_heads, self.head_dim) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value, attn_head_output + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int, sp_threshold: float = 0.0): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx, sp_threshold=sp_threshold) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + output_mlp_activation: Optional[bool] = False, + output_attn_output: Optional[bool] = False, + output_router_inputs: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + if output_router_inputs: + router_input = hidden_states.detach().cpu() # this moves the tensor to cpu + + # Self Attention + hidden_states, self_attn_weights, present_key_value, attn_head_output = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + output_attn_output=output_attn_output, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + if output_attn_output: + attn_head_output = attn_head_output.detach().cpu() # this moves the tensor to cpu + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, mlp_activation = self.mlp(hidden_states, output_mlp_activation=output_mlp_activation) + + if output_mlp_activation: + mlp_activation = mlp_activation.detach().cpu() + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_mlp_activation: + outputs += (mlp_activation,) + + if output_attn_output: + outputs += (attn_head_output,) + + if output_router_inputs: + outputs += (router_input,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig, sp_thresholds = None): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if sp_thresholds is None: + print("Sparsity thresholds not provided. Using dense attention.") + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + else: + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx, sp_threshold = sp_thresholds[layer_idx]) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_mlp_activations = () if output_mlp_activation else None + all_attn_outputs = () if output_attn_output else None + all_router_inputs = () if output_router_inputs else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + index = 1 + if output_attentions: + all_self_attns += (layer_outputs[index],) + index += 1 + if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + next_decoder_cache = layer_outputs[index] + index += 1 + + if output_mlp_activation: + # calculate the correct index depending on the use_cache, output_attentions + all_mlp_activations += (layer_outputs[index],) + index += 1 + + if output_attn_output: + all_attn_outputs += (layer_outputs[index],) + index += 1 + + if output_router_inputs: + # assert index < len(layer_outputs), "router_input index out of range" + all_router_inputs += (layer_outputs[index],) + index += 1 + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_mlp_activations] if v is not None) + + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=next_cache, + # hidden_states=all_hidden_states, + # attentions=all_self_attns, + # ) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mlp_activations=all_mlp_activations, + attn_outputs = all_attn_outputs, + router_inputs = all_router_inputs, + ) + # if output_mlp_activation: + # output.mlp_activations = all_mlp_activations + + # if output_attn_output: + # output.attn_outputs = all_attn_outputs + + return output + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class SparseLlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, sp_thresholds = None): + super().__init__(config) + self.model = LlamaModel(config, sp_thresholds = sp_thresholds) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithPast( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + + result = CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + mlp_activations=outputs.mlp_activations, + attn_outputs=outputs.attn_outputs, + router_inputs=outputs.router_inputs, + ) + # if output_mlp_activation: + # result.mlp_activations = outputs.mlp_activations + + return result + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if inputs_embeds is not None: + batch_size, sequence_length = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/hf_models/llama/modeling_sparse_llama_routers.py b/hf_models/llama/modeling_sparse_llama_routers.py new file mode 100644 index 0000000000000000000000000000000000000000..e095428cf7b8f8987b24bb9be95267dcfec3c606 --- /dev/null +++ b/hf_models/llama/modeling_sparse_llama_routers.py @@ -0,0 +1,1865 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + # BaseModelOutputWithPast, + # CausalLMOutputWithPast, + # QuestionAnsweringModelOutput, + # SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) + +from .modeling_outputs import ExtendedBaseModelOutputWithPast as BaseModelOutputWithPast +from .modeling_outputs import ExtendedCausalLMOutputWithPast as CausalLMOutputWithPast +from .modeling_outputs import ExtendedQuestionAnsweringModelOutput as QuestionAnsweringModelOutput +from .modeling_outputs import ExtendedSequenceClassifierOutputWithPast as SequenceClassifierOutputWithPast + + +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_llama import LlamaConfig +from HybridTensor.modules.SelectiveMHA import MHARouter + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." + ) + kwargs["rope_type"] = "linear" + super().__init__(*args, **kwargs) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." + ) + kwargs["rope_type"] = "dynamic" + super().__init__(*args, **kwargs) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + # self.act_fn = ACT2FN[config.hidden_act] + if config.hidden_act in ACT2FN: + self.act_fn = ACT2FN[config.hidden_act] + elif config.hidden_act == "shiftrelu": + def shifted_relu(x): + return torch.nn.functional.relu(x - config.hidden_act_param) + self.act_fn = shifted_relu + elif config.hidden_act == "fatrelu": + def fat_relu(x): + new_x = torch.zeros_like(x) + mask = torch.ge(x, config.hidden_act_param) + new_x[mask] = x[mask] + return new_x + self.act_fn = fat_relu + + def forward(self, x, output_mlp_activation=False): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + activation = self.act_fn(self.gate_proj(x)) + product = activation * self.up_proj(x) + down_proj = self.down_proj(product) + if output_mlp_activation: + return down_proj, activation + return down_proj, None + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None, + sp_threshold: float = 0.0 # sparisity threshold is used to select top-k heads based on the output norms + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + self.sp_threshold = sp_threshold + self.use_sparse_head_attention = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + output_attn_output: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + attn_router_output = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + if output_attn_output: + attn_head_output = attn_output + else: + attn_head_output = None + + + gqa_version = 1 # 1: Group Sparsity, zero out the outputs of the non-selected groups, + #2: Group Sparsity, the non activated heads get processed with the activated groups + + # **Head-wise Sparsity Implementation Starts Here ** + if self.use_sparse_head_attention: + use_mha_branch = self.num_key_value_heads == self.num_heads + # use_mha_branch = True + if use_mha_branch: # MHA , Head Sparsity + # Determine the number of heads to keep + k = int(self.sp_threshold * self.num_heads) + # k = max(k, 1) # Ensure at least one head is kept + + # norms = attn_output.norm(dim=-1) # Shape: (batch_size, seq_len, num_heads) + norms = attn_router_output + + # Step 2: For each token, select the top k heads based on the norms + # Use torch.topk to get the indices of the top k heads + _, topk_indices = norms.topk(k, dim=-1) # Shapes: (batch_size, seq_len, k) + + # Step 3: Mask out all non-selected heads + # Create a mask initialized to False (zeros) + mask = torch.zeros_like(norms, dtype=torch.bool) # Shape: (batch_size, seq_len, num_heads) + + # Use scatter_ to set True at the positions of the top k heads + mask.scatter_(-1, topk_indices, True) # Now mask has True at top k head positions + + # Expand the mask to match the dimensions of attn_output + mask = mask.unsqueeze(-1) # Shape: (batch_size, seq_len, num_heads, 1) + + # Apply the mask to the attention output + # This will zero out the outputs of the non-selected heads + attn_output = attn_output * mask # Element-wise multiplication + + else: # GQA, Group Sparsity + # 1. Calculate the norm of each head + # norms = attn_output.norm(dim=-1) # Shape: (batch_size, seq_len, num_heads) + norms = attn_router_output + + # 2. Reshape norms into groups and compute average norm for each group + norms_grouped = norms.view(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups) # Shape: (batch_size, seq_len, groups, num_heads_per_group) + group_avg_norm = norms_grouped.mean(dim=-1) # Shape: (batch_size, seq_len, groups) + + # 3. Compute the max norm per group (resulting shape: [bsz, q_len, self.num_key_value_heads]) + group_avg_norm = norms_grouped.max(dim=-1).values + + # 3. Select the top-k groups based on the average norm of each group + k = int(self.sp_threshold * self.num_key_value_heads) + _, topk_indices = torch.topk(group_avg_norm, k=k, dim=-1) + + if gqa_version == 1: + # version 1 : zero out the outputs of the non-selected groups + + # 4. Build masks for the selected groups + mask_groups = torch.zeros_like(group_avg_norm) + mask_groups.scatter_(dim=-1, index=topk_indices, value=1.0) + + # 5. Expand the mask to the shape of norms + attn_output_grouped = attn_output.view(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups, self.head_dim) + mask_groups = mask_groups.unsqueeze(-1).unsqueeze(-1) # Shape: (batch_size, seq_len, groups, 1, 1) + + # 6. Apply the mask to the attention output + attn_out_masked = attn_output_grouped * mask_groups + attn_output = attn_out_masked.view(bsz, q_len, self.num_heads, self.head_dim) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value, attn_head_output + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int, sp_threshold: float = 0.0): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx, sp_threshold=sp_threshold) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.attn_router = MHARouter( + embed_dim = config.hidden_size, + out_dim = self.self_attn.num_heads, + top_k= sp_threshold, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + output_mlp_activation: Optional[bool] = False, + output_attn_output: Optional[bool] = False, + output_router_inputs: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + if output_router_inputs: + router_input = hidden_states.detach().cpu() # this moves the tensor to cpu + + # --- Router for Self Attention --- + attn_router_output = self.attn_router(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value, attn_head_output = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + output_attn_output=output_attn_output, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + attn_router_output=attn_router_output, + **kwargs, + ) + hidden_states = residual + hidden_states + + if output_attn_output: + attn_head_output = attn_head_output.detach().cpu() # this moves the tensor to cpu + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, mlp_activation = self.mlp(hidden_states, output_mlp_activation=output_mlp_activation) + + if output_mlp_activation: + mlp_activation = mlp_activation.detach().cpu() + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_mlp_activation: + outputs += (mlp_activation,) + + if output_attn_output: + outputs += (attn_head_output,) + + if output_router_inputs: + outputs += (router_input,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig, sp_thresholds = None): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if sp_thresholds is None: + print("Sparsity thresholds not provided. Using dense attention.") + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + else: + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx, sp_threshold = sp_thresholds[layer_idx]) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_mlp_activations = () if output_mlp_activation else None + all_attn_outputs = () if output_attn_output else None + all_router_inputs = () if output_router_inputs else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + index = 1 + if output_attentions: + all_self_attns += (layer_outputs[index],) + index += 1 + if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + next_decoder_cache = layer_outputs[index] + index += 1 + + if output_mlp_activation: + # calculate the correct index depending on the use_cache, output_attentions + all_mlp_activations += (layer_outputs[index],) + index += 1 + + if output_attn_output: + all_attn_outputs += (layer_outputs[index],) + index += 1 + + if output_router_inputs: + # assert index < len(layer_outputs), "router_input index out of range" + all_router_inputs += (layer_outputs[index],) + index += 1 + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_mlp_activations] if v is not None) + + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=next_cache, + # hidden_states=all_hidden_states, + # attentions=all_self_attns, + # ) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mlp_activations=all_mlp_activations, + attn_outputs = all_attn_outputs, + router_inputs = all_router_inputs, + ) + # if output_mlp_activation: + # output.mlp_activations = all_mlp_activations + + # if output_attn_output: + # output.attn_outputs = all_attn_outputs + + return output + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class SparseLlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, sp_thresholds = None): + super().__init__(config) + self.model = LlamaModel(config, sp_thresholds = sp_thresholds) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithPast( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + + result = CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + mlp_activations=outputs.mlp_activations, + attn_outputs=outputs.attn_outputs, + router_inputs=outputs.router_inputs, + ) + # if output_mlp_activation: + # result.mlp_activations = outputs.mlp_activations + + return result + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if inputs_embeds is not None: + batch_size, sequence_length = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +import os +import re +from collections import OrderedDict + +def create_hf_attn_router_state_dict(router_files_dir): + """ + Loads attn_router state dicts from files and remaps keys to: + "model.decoder.layers.{layer_num}.mha_router.{param_name}" + """ + pattern = re.compile(r'attn_router_(\d+)-[\d.]+-[\d.]+\.pt$') + state_dict = OrderedDict() + loaded_layers = set() + + try: + files = os.listdir(router_files_dir) + except FileNotFoundError: + print(f"Error: Directory '{router_files_dir}' does not exist.") + return None + + router_files = [f for f in files if pattern.match(f)] + if not router_files: + print(f"No attn_router files found in '{router_files_dir}'.") + return None + + for file_name in sorted(router_files, key=lambda x: int(pattern.match(x).group(1))): + match = pattern.match(file_name) + if not match: + continue + layer_num = int(match.group(1)) + if layer_num in loaded_layers: + print(f"Warning: Duplicate router for layer {layer_num} in '{file_name}'. Skipping.") + continue + file_path = os.path.join(router_files_dir, file_name) + try: + router_weights = torch.load(file_path, map_location='cpu') + if not isinstance(router_weights, dict): + print(f"Warning: {file_path} does not contain a state dict. Skipping.") + continue + except Exception as e: + print(f"Error loading {file_path}: {e}") + continue + + for param_name, tensor in router_weights.items(): + key = f"model.layers.{layer_num}.attn_router.{param_name}" + state_dict[key] = tensor + + loaded_layers.add(layer_num) + + print(f"Loaded mha routers for {len(loaded_layers)} layers.") + return state_dict \ No newline at end of file diff --git a/hf_models/llama/tokenization_llama.py b/hf_models/llama/tokenization_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..3057665ff3087c5cfd9ce816bb2cfbbbaa5cf30c --- /dev/null +++ b/hf_models/llama/tokenization_llama.py @@ -0,0 +1,412 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for LLaMA.""" + +import os +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from transformers.convert_slow_tokenizer import import_protobuf +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging + + +if TYPE_CHECKING: + from transformers.tokenization_utils_base import TextInput + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +SPIECE_UNDERLINE = "▁" + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class LlamaTokenizer(PreTrainedTokenizer): + """ + Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is + no padding token in the original model. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Llama should be used. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. + legacy (`bool`, *optional*): + Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622 + and #25224 which includes fixes to properly handle tokens that appear after special tokens. + Make sure to also set `from_slow` to `True`. + A simple example: + + - `legacy=True`: + ```python + >>> from transformers import LlamaTokenizerFast + + >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True) + >>> tokenizer.encode("Hello .") # 869 is '▁.' + [1, 15043, 29871, 1, 869] + ``` + - `legacy=False`: + ```python + >>> from transformers import LlamaTokenizerFast + + >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=False, from_slow=True) + >>> tokenizer.encode("Hello .") # 29889 is '.' + [1, 15043, 29871, 1, 29889] + ``` + Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details. + add_prefix_space (`bool`, *optional*, defaults to `True`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. Again, this should be set with `from_slow=True` to make sure it's taken into account. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + use_default_system_prompt=False, + spaces_between_special_tokens=False, + legacy=None, + add_prefix_space=True, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + + if legacy is None: + logger.warning_once( + f"You are using the default legacy behaviour of the {self.__class__}. This is" + " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." + " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it" + " means, and thoroughly read the reason why this was added as explained in" + " https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file" + " you can ignore this message" + ) + legacy = True + + self.legacy = legacy + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.use_default_system_prompt = use_default_system_prompt + self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) + self.add_prefix_space = add_prefix_space + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + use_default_system_prompt=use_default_system_prompt, + spaces_between_special_tokens=spaces_between_special_tokens, + legacy=legacy, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor + def get_spm_processor(self, from_slow=False): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy or from_slow: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return self.sp_model.encode(text, out_type=str) + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special and i != 0 and self.legacy: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + if prev_is_special and i == 1 and self.add_prefix_space and not token.startswith(SPIECE_UNDERLINE): + out_string += " " + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return ( + bos_token_id + + ([0] * len(token_ids_0)) + + eos_token_id + + bos_token_id + + ([0] * len(token_ids_1)) + + eos_token_id + ) + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output diff --git a/hf_models/llama/tokenization_llama_fast.py b/hf_models/llama/tokenization_llama_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..8576cbef40baa9c36f4607ca4ec3f5c0abab4950 --- /dev/null +++ b/hf_models/llama/tokenization_llama_fast.py @@ -0,0 +1,255 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from shutil import copyfile +from typing import Optional, Tuple + +from tokenizers import processors + +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from transformers.utils import is_sentencepiece_available, logging +from transformers.utils.versions import require_version + + +require_version("tokenizers>=0.13.3") + +if is_sentencepiece_available(): + from .tokenization_llama import LlamaTokenizer +else: + LlamaTokenizer = None + +logger = logging.get_logger(__name__) +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"} + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class LlamaTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. + + This uses notably ByteFallback and no normalization. + + ```python + >>> from transformers import LlamaTokenizerFast + + >>> tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") + >>> tokenizer.encode("Hello this is a test") + [1, 15043, 445, 338, 263, 1243] + ``` + + If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or + call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the + values of the first token and final token of an encoded sequence will not be correct). For more details, checkout + [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation. + + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`, *optional*): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that + contains the vocabulary necessary to instantiate a tokenizer. + tokenizer_file (`str`, *optional*): + [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that + contains everything needed to load the tokenizer. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Llama should be used + legacy (`bool`, *optional*): + Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622 + and #25224 which includes fixes to properly handle tokens that appear after special tokens. + Make sure to also set `from_slow` to `True`. + A simple example: + + - `legacy=True`: + ```python + >>> from transformers import LlamaTokenizerFast + + >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True) + >>> tokenizer.encode("Hello .") # 869 is '▁.' + [1, 15043, 29871, 1, 869] + ``` + - `legacy=False`: + ```python + >>> from transformers import LlamaTokenizerFast + + >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=False, from_slow=True) + >>> tokenizer.encode("Hello .") # 29889 is '.' + [1, 15043, 29871, 1, 29889] + ``` + Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details. + add_prefix_space (`bool`, *optional*): + Whether or not the tokenizer should automatically add a prefix space + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = LlamaTokenizer + padding_side = "left" + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + clean_up_tokenization_spaces=False, + unk_token="", + bos_token="", + eos_token="", + add_bos_token=True, + add_eos_token=False, + use_default_system_prompt=False, + legacy=None, + add_prefix_space=None, + **kwargs, + ): + if legacy is None: + logger.warning_once( + f"You are using the default legacy behaviour of the {self.__class__}. This is" + " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." + " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it" + " means, and thoroughly read the reason why this was added as explained in" + " https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file" + " you can ignore this message." + ) + legacy = True + self.legacy = legacy + + if add_prefix_space is not None: + kwargs["from_slow"] = True + + super().__init__( + vocab_file=vocab_file, + tokenizer_file=tokenizer_file, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + use_default_system_prompt=use_default_system_prompt, + add_prefix_space=add_prefix_space, + legacy=legacy, + **kwargs, + ) + self._add_bos_token = add_bos_token + self._add_eos_token = add_eos_token + self.update_post_processor() + self.use_default_system_prompt = use_default_system_prompt + self.vocab_file = vocab_file + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + def update_post_processor(self): + """ + Updates the underlying post processor with the current `bos_token` and `eos_token`. + """ + bos = self.bos_token + bos_token_id = self.bos_token_id + if bos is None and self.add_bos_token: + raise ValueError("add_bos_token = True but bos_token = None") + + eos = self.eos_token + eos_token_id = self.eos_token_id + if eos is None and self.add_eos_token: + raise ValueError("add_eos_token = True but eos_token = None") + + single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" + pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" + + special_tokens = [] + if self.add_bos_token: + special_tokens.append((bos, bos_token_id)) + if self.add_eos_token: + special_tokens.append((eos, eos_token_id)) + self._tokenizer.post_processor = processors.TemplateProcessing( + single=single, pair=pair, special_tokens=special_tokens + ) + + @property + def add_eos_token(self): + return self._add_eos_token + + @property + def add_bos_token(self): + return self._add_bos_token + + @add_eos_token.setter + def add_eos_token(self, value): + self._add_eos_token = value + self.update_post_processor() + + @add_bos_token.setter + def add_bos_token(self, value): + self._add_bos_token = value + self.update_post_processor() + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output diff --git a/hf_models/opt/__init__.py b/hf_models/opt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b6315a6599cce84793674718717b7589bb2711c --- /dev/null +++ b/hf_models/opt/__init__.py @@ -0,0 +1,99 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from transformers.utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_tf_available, + is_tokenizers_available, + is_torch_available, +) + + +_import_structure = {"configuration_opt": ["OPTConfig"]} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_opt"] = [ + "OPTForCausalLM", + "OPTModel", + "OPTPreTrainedModel", + "OPTForSequenceClassification", + "OPTForQuestionAnswering", + ] + +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_opt"] = ["TFOPTForCausalLM", "TFOPTModel", "TFOPTPreTrainedModel"] + +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_opt"] = [ + "FlaxOPTForCausalLM", + "FlaxOPTModel", + "FlaxOPTPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_opt import OPTConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_opt import ( + OPTForCausalLM, + OPTForQuestionAnswering, + OPTForSequenceClassification, + OPTModel, + OPTPreTrainedModel, + ) + + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/hf_models/opt/configuration_opt.py b/hf_models/opt/configuration_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..c92659eea2f6fa2e6b7bda22e1aa161ec9bfb609 --- /dev/null +++ b/hf_models/opt/configuration_opt.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2022 The Metaseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OPT model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class OPTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OPTModel`]. It is used to instantiate a OPT model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the OPT + [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50272): + Vocabulary size of the OPT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`OPTModel`] + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of decoder layers. + ffn_dim (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer decoder. + activation_function (`str` or `function`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + do_layer_norm_before (`bool`, *optional*, defaults to `True`): + Whether to perform layer normalization before the attention block. + word_embed_proj_dim (`int`, *optional*): + `word_embed_proj_dim` can be set to down-project word embeddings, *e.g.* `opt-350m`. Defaults to + `hidden_size`. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more + details. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + enable_bias (`bool`, *optional*, defaults to `True`): + Whether or not if the linear layers in the attention blocks should use the bias term. + layer_norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether or not if the layer norms should have learnable parameters. + + Example: + + ```python + >>> from transformers import OPTConfig, OPTModel + + >>> # Initializing a OPT facebook/opt-large style configuration + >>> configuration = OPTConfig() + + >>> # Initializing a model (with random weights) from the facebook/opt-large style configuration + >>> model = OPTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "opt" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50272, + hidden_size=768, + num_hidden_layers=12, + ffn_dim=3072, + max_position_embeddings=2048, + do_layer_norm_before=True, + _remove_final_layer_norm=False, + word_embed_proj_dim=None, + dropout=0.1, + attention_dropout=0.0, + num_attention_heads=12, + activation_function="relu", + layerdrop=0.0, + init_std=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=2, + eos_token_id=2, + enable_bias=True, + layer_norm_elementwise_affine=True, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.num_attention_heads = num_attention_heads + self.word_embed_proj_dim = word_embed_proj_dim if word_embed_proj_dim is not None else hidden_size + self.ffn_dim = ffn_dim + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_function = activation_function + self.init_std = init_std + self.layerdrop = layerdrop + self.use_cache = use_cache + self.do_layer_norm_before = do_layer_norm_before + # We keep these variables at `True` for backward compatibility. + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + + # Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + self._remove_final_layer_norm = _remove_final_layer_norm diff --git a/hf_models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py b/hf_models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..486b477f973f3530e24a4a7a95b96358254fce1f --- /dev/null +++ b/hf_models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,113 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OPT checkpoint.""" + +import argparse +from pathlib import Path + +import torch + +from transformers import OPTConfig, OPTModel +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def load_checkpoint(checkpoint_path): + """Checkpoint path should end in model.pt""" + sd = torch.load(checkpoint_path, map_location="cpu") + if "model" in sd.keys(): + sd = torch.load(checkpoint_path, map_location="cpu")["model"] + + # pop unnecessary weights + keys_to_delete = [ + "decoder.version", + "decoder.output_projection.weight", + ] + for key in keys_to_delete: + if key in sd: + sd.pop(key) + + keys_to_rename = { + "decoder.project_in_dim.weight": "decoder.project_in.weight", + "decoder.project_out_dim.weight": "decoder.project_out.weight", + "decoder.layer_norm.weight": "decoder.final_layer_norm.weight", + "decoder.layer_norm.bias": "decoder.final_layer_norm.bias", + } + for old_key, new_key in keys_to_rename.items(): + if old_key in sd: + sd[new_key] = sd.pop(old_key) + + keys = list(sd.keys()) + for key in keys: + if ".qkv_proj." in key: + value = sd[key] + # We split QKV in separate Q,K,V + + q_name = key.replace(".qkv_proj.", ".q_proj.") + k_name = key.replace(".qkv_proj.", ".k_proj.") + v_name = key.replace(".qkv_proj.", ".v_proj.") + + depth = value.shape[0] + assert depth % 3 == 0 + # `SequeuceParallelTransformerBlock` has QKV weight is separated in K,V,Q despite the naming: + # https://cs.github.com/facebookresearch/metaseq/blob/51871bd73cd04c038f239ea2a26db1d7f6b37927/metaseq/modules/sequence_parallel_transformer_layer.py#L97 + k, v, q = torch.split(value, depth // 3, dim=0) + + sd[q_name] = q + sd[k_name] = k + sd[v_name] = v + del sd[key] + + return sd + + +@torch.no_grad() +def convert_opt_checkpoint(checkpoint_path, pytorch_dump_folder_path, config=None): + """ + Copy/paste/tweak model's weights to our BERT structure. + """ + state_dict = load_checkpoint(checkpoint_path) + + if config is not None: + config = OPTConfig.from_pretrained(config) + else: + config = OPTConfig() + + model = OPTModel(config).half().eval() + model.load_state_dict(state_dict) + + # Check results + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--fairseq_path", + type=str, + help=( + "path to fairseq checkpoint in correct format. You can find all checkpoints in the correct format here:" + " https://huggingface.co/models?other=opt_metasq" + ), + ) + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--hf_config", default=None, type=str, help="Define HF config.") + args = parser.parse_args() + convert_opt_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, config=args.hf_config) diff --git a/hf_models/opt/help.txt b/hf_models/opt/help.txt new file mode 100644 index 0000000000000000000000000000000000000000..06ab98956cffd372b17a8dc0e6442883d1e8bcff --- /dev/null +++ b/hf_models/opt/help.txt @@ -0,0 +1,8 @@ +modeling_opt: contains code to return mlp_activations +modeling_sparse_opt: contains code to simulate sparse attention based on threshold +modeling_sparse_opt_topk: contains code to simulate sparse attention based on top k activation +modling_opt_router: contains code to return mlp_activations, attention output norms, router inputs +modling_opt_attn: contains code to return input and output of the attention layer to understand importance + use output_router_inputs: input to attention layer + use output_attn_output: output of attention layer after adding residual + \ No newline at end of file diff --git a/hf_models/opt/modeling_flax_opt.py b/hf_models/opt/modeling_flax_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..c6296e4eeae0014fdea5fe5bc8850e9da6680ddc --- /dev/null +++ b/hf_models/opt/modeling_flax_opt.py @@ -0,0 +1,799 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flax OPT model.""" + +from functools import partial +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxMaskedLMOutput +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring +from ...utils import add_start_docstrings, logging +from .configuration_opt import OPTConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" + + +OPT_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`OPTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->OPT +class FlaxOPTAttention(nn.Module): + config: OPTConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() + self.out_proj = dense() + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + # self_attention + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class FlaxOPTDecoderLayer(nn.Module): + config: OPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.hidden_size + self.self_attn = FlaxOPTAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.num_attention_heads, + dropout=self.config.attention_dropout, + causal=True, + dtype=self.dtype, + ) + self.do_layer_norm_before = self.config.do_layer_norm_before + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.fc1 = nn.Dense( + self.config.ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + init_cache=init_cache, + deterministic=deterministic, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + hidden_states = (residual + hidden_states).reshape(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class FlaxOPTDecoderLayerCollection(nn.Module): + config: OPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxOPTDecoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + self.layerdrop = self.config.layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + outputs = [hidden_states, all_hidden_states, all_self_attns] + return outputs + + +class FlaxOPTLearnedPositionalEmbedding(nn.Embed): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def setup(self): + self.offset = 2 + self.embedding = self.param( + "embedding", self.embedding_init, (self.num_embeddings + self.offset, self.features), self.param_dtype + ) + + def __call__(self, positions): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + return super().__call__(positions + self.offset) + + +class FlaxOPTDecoder(nn.Module): + config: OPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + offset: int = 2 + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.hidden_size + self.padding_idx = self.config.pad_token_id + self.max_target_positions = self.config.max_position_embeddings + + self.embed_tokens = nn.Embed( + self.config.vocab_size, + self.config.word_embed_proj_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + + self.embed_positions = FlaxOPTLearnedPositionalEmbedding( + self.config.max_position_embeddings, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + + if self.config.word_embed_proj_dim != self.config.hidden_size: + self.project_in = nn.Dense(self.config.hidden_size, use_bias=False) + self.project_out = nn.Dense(self.config.word_embed_proj_dim, use_bias=False) + + else: + self.project_in = None + self.project_out = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + else: + self.final_layer_norm = None + + self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + positions = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + positions + + hidden_state, all_hidden_states, attentions = self.layers( + hidden_states, + attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if self.final_layer_norm is not None: + hidden_state = self.final_layer_norm(hidden_state) + + if self.project_out is not None: + hidden_state = self.project_out(hidden_state) + + if output_hidden_states: + all_hidden_states += (hidden_state,) + + outputs = [hidden_state, all_hidden_states, attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_state, + hidden_states=all_hidden_states, + attentions=attentions, + ) + + +class FlaxOPTPreTrainedModel(FlaxPreTrainedModel): + config_class = OPTConfig + base_model_prefix: str = "model" + module_class: nn.Module = None + + def __init__( + self, + config: OPTConfig, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + position_ids, + return_dict=False, + ) + + random_params = module_init_outputs["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + params: dict = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + dropout_rng: PRNGKey = None, + deterministic: bool = True, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if position_ids is None: + position_ids = (attention_mask.cumsum(axis=1) * attention_mask) - 1 + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxOPTAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxOPTModule(nn.Module): + config: OPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.decoder = FlaxOPTDecoder(self.config, dtype=self.dtype) + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + init_cache=False, + ): + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + init_cache=init_cache, + ) + + if not return_dict: + return decoder_outputs + + return FlaxBaseModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + ) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModel with Bart->OPT +class FlaxOPTModel(FlaxOPTPreTrainedModel): + config: OPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + module_class = FlaxOPTModule + + +append_call_sample_docstring(FlaxOPTModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC) + + +@add_start_docstrings( + "The bare OPT Model transformer outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class FlaxOPTForCausalLMModule(nn.Module): + config: OPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.model = FlaxOPTModule(config=self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_ids, + attention_mask, + position_ids, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=lm_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + OPT Model with a language modeling head on top (linear layer with weights tied to the input embeddings) e.g for + autoregressive tasks. + """, + OPT_START_DOCSTRING, +) +class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel): + module_class = FlaxOPTForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. + # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxOPTForCausalLM, + _CHECKPOINT_FOR_DOC, + FlaxBaseModelOutput, + _CONFIG_FOR_DOC, +) diff --git a/hf_models/opt/modeling_opt.py b/hf_models/opt/modeling_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..55b34ddb64f95251f05a78354c010e10663ab590 --- /dev/null +++ b/hf_models/opt/modeling_opt.py @@ -0,0 +1,1811 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OPT model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + +''' +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +''' + +from .modeling_outputs import ExtendedBaseModelOutputWithPast as BaseModelOutputWithPast +from .modeling_outputs import ExtendedCausalLMOutputWithPast as CausalLMOutputWithPast +from .modeling_outputs import ExtendedQuestionAnsweringModelOutput as QuestionAnsweringModelOutput +from .modeling_outputs import ExtendedSequenceClassifierOutputWithPast as SequenceClassifierOutputWithPast + +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_opt import OPTConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" +_SEQ_CLASS_EXPECTED_LOSS = 1.71 +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: OPTConfig, + is_decoder: bool = False, + activation_threshold: float = 0.0, # Threshold for head-wise sparsity + **kwargs, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.dropout = config.attention_dropout + self.enable_bias = config.enable_bias + + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + + self.activation_threshold = activation_threshold + if activation_threshold > 0.0: + self.use_sparse_head_attention = True + + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + + # we need to return the attn_output and observe the vector norm of each head + # ... + attn_head_output = attn_output + + # apply sparse head activation here, mask the hidden states of heads whose norm is below a certain threshold + + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + def sparse_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Forward pass with head-wise sparsity based on output norms.""" + # if key_value_states are provided this layer is used as a cross-attention layer + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_head_output = attn_output + + if self.use_sparse_head_attention: + # **Head-wise Sparsity Implementation Starts Here** + + # Compute the L2 norm for each head's output across the head_dim for each token + # Shape: (batch_size, num_heads, tgt_len) + head_norms = attn_output.norm(p=2, dim=-1) # L2 norm + + # Create a mask where heads with norm >= threshold are kept + # Shape: (batch_size, num_heads, tgt_len) + head_mask = head_norms >= self.activation_threshold + + # Expand the mask to match the attn_output dimensions + # Shape: (batch_size, num_heads, tgt_len, 1) + head_mask = head_mask.unsqueeze(-1).type_as(attn_output) + + # Apply the mask to zero out the heads below the threshold for each token + attn_output = attn_output * head_mask # Zero out inactive heads + + # **Optional: Store the mask for analysis or further processing** + # self.current_head_mask = head_mask # Uncomment if you want to store the mask + + # Recompute hidden_states from the masked attn_output + # Permute to (batch_size, tgt_len, num_heads, head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + + # Combine heads by reshaping + # Shape: (batch_size, tgt_len, num_heads * head_dim) + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + + + + # Apply the final linear projection + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + +class OptFlashAttention2(OPTAttention): + """ + OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. + The only required change would be on the forward pass where it needs to correctly call the public API of flash + attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, _, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_length = query_states.shape[1] + tgt_len = key_states.shape[-2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout + ) + + # need to return attn_output and the vector norm of each head + attn_head_output = attn_output + + # apply sparse head activation here, mask the hidden states of heads whose norm is below a certain threshold + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + + if not output_attentions: + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + + def sparse_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, _, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_length = query_states.shape[1] + tgt_len = key_states.shape[-2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout + ) + + # need to return attn_output and the vector norm of each head + attn_head_output = attn_output + + # apply sparse head activation here, mask the hidden states of heads whose norm is below a certain threshold + if self.use_sparse_head_attention: + # **Head-wise Sparsity Implementation Starts Here** + + # Compute the L2 norm for each head's output across the head_dim for each token + # Shape: (batch_size, num_heads, tgt_len) + head_norms = attn_output.norm(p=2, dim=-1) # L2 norm + + # Create a mask where heads with norm >= threshold are kept + # Shape: (batch_size, num_heads, tgt_len) + head_mask = head_norms >= self.activation_threshold + + # Expand the mask to match the attn_output dimensions + # Shape: (batch_size, num_heads, tgt_len, 1) + head_mask = head_mask.unsqueeze(-1).type_as(attn_output) + + # Apply the mask to zero out the heads below the threshold for each token + attn_output = attn_output * head_mask # Zero out inactive heads + + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + + if not output_attentions: + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +OPT_ATTENTION_CLASSES = { + "eager": OPTAttention, + "flash_attention_2": OptFlashAttention2, +} + + +class OPTDecoderLayer(nn.Module): + def __init__(self, config: OPTConfig): + super().__init__() + self.embed_dim = config.hidden_size + + self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True) + + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm( + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + output_mlp_activation: Optional[bool] = False, + output_attn_output: Optional[bool] = False, + output_router_inputs: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # need to record the hidden_states after layer norm as input to the router + # record hidden_states + if output_router_inputs: + router_input = hidden_states.detach().cpu() # this moves the tensor to cpu + + # Self Attention + hidden_states, self_attn_weights, present_key_value, attn_head_output = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # need to change this to record the activation output of each head + if output_attn_output: + attn_output = attn_head_output.detach().cpu() # this moves the tensor to cpu + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + # record activation output here + if output_mlp_activation: + mlp_hidden_states = hidden_states.detach().cpu() # this moves the tensor to cpu + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + # for data collection for router, it might be beneficial to record the hidden_states after layer norm + outputs = (hidden_states,) # index 0 + + if output_attentions: + outputs += (self_attn_weights,) # always index 1 + + if use_cache: + outputs += (present_key_value,) # can be 1 or 2 + + if output_mlp_activation: + outputs += (mlp_hidden_states,) # can be 1, 2 or 3 + + if output_attn_output: + outputs += (attn_output,) # can be 1, 2, 3 or 4 + + if output_router_inputs: + outputs += (router_input,) # can be 1, 2, 3, 4 or 5 + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + config_class = OPTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OPTDecoderLayer"] + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) + self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine + ) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + + # embed positions + if self._use_flash_attention_2: + # 2d mask is passed through the layers + causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = ( + torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if attention_mask is None + else attention_mask + ) + else: + # 4d mask is passed through the layers + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + causal_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + all_mlp_activations = () if output_mlp_activation else None + all_attn_outputs = () if output_attn_output else None + all_router_inputs = () if output_router_inputs else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + ) + + # record states + ''' + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_mlp_activation: + all_mlp_activations += (layer_outputs[-1],) + + if output_attn_output: + all_attn_outputs += (layer_outputs[-1],) + + ''' + + hidden_states = layer_outputs[0] + + current_index = 1 # Starts after hidden_states + + if output_attentions: + assert current_index < len(layer_outputs), "self_attn_weights index out of range" + all_self_attns += (layer_outputs[current_index],) + current_index += 1 + + if use_cache: + assert current_index < len(layer_outputs), "present_key_value index out of range" + next_decoder_cache += (layer_outputs[current_index],) + current_index += 1 + + if output_mlp_activation: + assert current_index < len(layer_outputs), "mlp_hidden_states index out of range" + all_mlp_activations += (layer_outputs[current_index],) + current_index += 1 + + if output_attn_output: + assert current_index < len(layer_outputs), "attn_output index out of range" + all_attn_outputs += (layer_outputs[current_index],) + current_index += 1 + + if output_router_inputs: + assert current_index < len(layer_outputs), "router_input index out of range" + all_router_inputs += (layer_outputs[current_index],) + current_index += 1 + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mlp_activations=all_mlp_activations, + attn_outputs = all_attn_outputs, + router_inputs = all_router_inputs, + ) + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.decoder = OPTDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activations: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activations, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + mlp_activations=decoder_outputs.mlp_activations, + attn_outputs = decoder_outputs.attn_outputs, + router_inputs = decoder_outputs.router_inputs, + ) + + +class OPTForCausalLM(OPTPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = OPTModel(config) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + mlp_activations=outputs.mlp_activations, + attn_outputs=outputs.attn_outputs, + router_inputs=outputs.router_inputs, + ) + + + + + +@add_start_docstrings( + """ + The OPT Model transformer with a sequence classification head on top (linear layer). + + [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + OPT_START_DOCSTRING, +) +class OPTForSequenceClassification(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OPTModel(config) + self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + +@add_start_docstrings( + """ + The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + OPT_START_DOCSTRING, +) +class OPTForQuestionAnswering(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.model = OPTModel(config) + self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForQuestionAnswering + >>> import torch + + >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> # note: we are loading a OPTForQuestionAnswering from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random + >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + + >>> inputs = tokenizer(question, text, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + + >>> answer_offset = len(tokenizer(question)[0]) + + >>> predict_answer_tokens = inputs.input_ids[ + ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 + ... ] + >>> predicted = tokenizer.decode(predict_answer_tokens) + >>> predicted + ' a nice puppet' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index).to(logits.device) + end_positions = end_positions.clamp(0, ignored_index).to(logits.device) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value diff --git a/hf_models/opt/modeling_opt_attn.py b/hf_models/opt/modeling_opt_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..58a3c01db8ac0c4bf116dd74a937d2113435e96b --- /dev/null +++ b/hf_models/opt/modeling_opt_attn.py @@ -0,0 +1,1792 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OPT model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + +''' +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +''' + +from .modeling_outputs import ExtendedBaseModelOutputWithPast as BaseModelOutputWithPast +from .modeling_outputs import ExtendedCausalLMOutputWithPast as CausalLMOutputWithPast +from .modeling_outputs import ExtendedQuestionAnsweringModelOutput as QuestionAnsweringModelOutput +from .modeling_outputs import ExtendedSequenceClassifierOutputWithPast as SequenceClassifierOutputWithPast + +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_opt import OPTConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" +_SEQ_CLASS_EXPECTED_LOSS = 1.71 +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: OPTConfig, + is_decoder: bool = False, + activation_threshold: float = 0.0, # Threshold for head-wise sparsity + **kwargs, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.dropout = config.attention_dropout + self.enable_bias = config.enable_bias + + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + + self.activation_threshold = activation_threshold + if activation_threshold > 0.0: + self.use_sparse_head_attention = True + + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + + # we need to return the attn_output and observe the vector norm of each head + # ... + attn_head_output = attn_output + + # apply sparse head activation here, mask the hidden states of heads whose norm is below a certain threshold + + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + def sparse_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Forward pass with head-wise sparsity based on output norms.""" + # if key_value_states are provided this layer is used as a cross-attention layer + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_head_output = attn_output + + if self.use_sparse_head_attention: + # **Head-wise Sparsity Implementation Starts Here** + + # Compute the L2 norm for each head's output across the head_dim for each token + # Shape: (batch_size, num_heads, tgt_len) + head_norms = attn_output.norm(p=2, dim=-1) # L2 norm + + # Create a mask where heads with norm >= threshold are kept + # Shape: (batch_size, num_heads, tgt_len) + head_mask = head_norms >= self.activation_threshold + + # Expand the mask to match the attn_output dimensions + # Shape: (batch_size, num_heads, tgt_len, 1) + head_mask = head_mask.unsqueeze(-1).type_as(attn_output) + + # Apply the mask to zero out the heads below the threshold for each token + attn_output = attn_output * head_mask # Zero out inactive heads + + # **Optional: Store the mask for analysis or further processing** + # self.current_head_mask = head_mask # Uncomment if you want to store the mask + + # Recompute hidden_states from the masked attn_output + # Permute to (batch_size, tgt_len, num_heads, head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + + # Combine heads by reshaping + # Shape: (batch_size, tgt_len, num_heads * head_dim) + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + + + + # Apply the final linear projection + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + +class OptFlashAttention2(OPTAttention): + """ + OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. + The only required change would be on the forward pass where it needs to correctly call the public API of flash + attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, _, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_length = query_states.shape[1] + tgt_len = key_states.shape[-2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout + ) + + # need to return attn_output and the vector norm of each head + attn_head_output = attn_output + + # apply sparse head activation here, mask the hidden states of heads whose norm is below a certain threshold + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + + if not output_attentions: + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + + def sparse_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, _, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_length = query_states.shape[1] + tgt_len = key_states.shape[-2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout + ) + + # need to return attn_output and the vector norm of each head + attn_head_output = attn_output + + # apply sparse head activation here, mask the hidden states of heads whose norm is below a certain threshold + if self.use_sparse_head_attention: + # **Head-wise Sparsity Implementation Starts Here** + + # Compute the L2 norm for each head's output across the head_dim for each token + # Shape: (batch_size, num_heads, tgt_len) + head_norms = attn_output.norm(p=2, dim=-1) # L2 norm + + # Create a mask where heads with norm >= threshold are kept + # Shape: (batch_size, num_heads, tgt_len) + head_mask = head_norms >= self.activation_threshold + + # Expand the mask to match the attn_output dimensions + # Shape: (batch_size, num_heads, tgt_len, 1) + head_mask = head_mask.unsqueeze(-1).type_as(attn_output) + + # Apply the mask to zero out the heads below the threshold for each token + attn_output = attn_output * head_mask # Zero out inactive heads + + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + + if not output_attentions: + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +OPT_ATTENTION_CLASSES = { + "eager": OPTAttention, + "flash_attention_2": OptFlashAttention2, +} + + +class OPTDecoderLayer(nn.Module): + def __init__(self, config: OPTConfig): + super().__init__() + self.embed_dim = config.hidden_size + + self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True) + + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm( + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + output_mlp_activation: Optional[bool] = False, + output_attn_output: Optional[bool] = False, + output_router_inputs: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # need to record the hidden_states before layer norm as input to the attention layer + # record hidden_states + if output_router_inputs: + router_input = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value, attn_head_output = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # need to change this to record the activation output of each head + if output_attn_output: + attn_output = hidden_states.detach().cpu() # this moves the tensor to cpu + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + # record activation output here + if output_mlp_activation: + mlp_hidden_states = hidden_states.detach().cpu() # this moves the tensor to cpu + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + # for data collection for router, it might be beneficial to record the hidden_states after layer norm + outputs = (hidden_states,) # index 0 + + if output_attentions: + outputs += (self_attn_weights,) # always index 1 + + if use_cache: + outputs += (present_key_value,) # can be 1 or 2 + + if output_mlp_activation: + outputs += (mlp_hidden_states,) # can be 1, 2 or 3 + + if output_attn_output: + outputs += (attn_output,) # can be 1, 2, 3 or 4 + + if output_router_inputs: + outputs += (router_input,) # can be 1, 2, 3, 4 or 5 + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + config_class = OPTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OPTDecoderLayer"] + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) + self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine + ) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + + # embed positions + if self._use_flash_attention_2: + # 2d mask is passed through the layers + causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = ( + torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if attention_mask is None + else attention_mask + ) + else: + # 4d mask is passed through the layers + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + causal_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + all_mlp_activations = () if output_mlp_activation else None + all_attn_outputs = () if output_attn_output else None + all_router_inputs = () if output_router_inputs else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + ) + + # record states + + hidden_states = layer_outputs[0] + + current_index = 1 # Starts after hidden_states + + if output_attentions: + assert current_index < len(layer_outputs), "self_attn_weights index out of range" + all_self_attns += (layer_outputs[current_index],) + current_index += 1 + + if use_cache: + assert current_index < len(layer_outputs), "present_key_value index out of range" + next_decoder_cache += (layer_outputs[current_index],) + current_index += 1 + + if output_mlp_activation: + assert current_index < len(layer_outputs), "mlp_hidden_states index out of range" + all_mlp_activations += (layer_outputs[current_index],) + current_index += 1 + + if output_attn_output: + assert current_index < len(layer_outputs), "attn_output index out of range" + all_attn_outputs += (layer_outputs[current_index],) + current_index += 1 + + if output_router_inputs: + assert current_index < len(layer_outputs), "router_input index out of range" + all_router_inputs += (layer_outputs[current_index],) + current_index += 1 + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mlp_activations=all_mlp_activations, + attn_outputs = all_attn_outputs, + router_inputs = all_router_inputs, + ) + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.decoder = OPTDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activations: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activations, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + mlp_activations=decoder_outputs.mlp_activations, + attn_outputs = decoder_outputs.attn_outputs, + router_inputs = decoder_outputs.router_inputs, + ) + + +class OPTForCausalLM(OPTPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = OPTModel(config) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + output_router_inputs: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + output_router_inputs=output_router_inputs, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + mlp_activations=outputs.mlp_activations, + attn_outputs=outputs.attn_outputs, + router_inputs=outputs.router_inputs, + ) + + +@add_start_docstrings( + """ + The OPT Model transformer with a sequence classification head on top (linear layer). + + [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + OPT_START_DOCSTRING, +) +class OPTForSequenceClassification(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OPTModel(config) + self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + +@add_start_docstrings( + """ + The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + OPT_START_DOCSTRING, +) +class OPTForQuestionAnswering(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.model = OPTModel(config) + self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForQuestionAnswering + >>> import torch + + >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> # note: we are loading a OPTForQuestionAnswering from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random + >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + + >>> inputs = tokenizer(question, text, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + + >>> answer_offset = len(tokenizer(question)[0]) + + >>> predict_answer_tokens = inputs.input_ids[ + ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 + ... ] + >>> predicted = tokenizer.decode(predict_answer_tokens) + >>> predicted + ' a nice puppet' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index).to(logits.device) + end_positions = end_positions.clamp(0, ignored_index).to(logits.device) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value diff --git a/hf_models/opt/modeling_opt_routers.py b/hf_models/opt/modeling_opt_routers.py new file mode 100644 index 0000000000000000000000000000000000000000..86e0999d71482da6ab5804d471f604c605b619cc --- /dev/null +++ b/hf_models/opt/modeling_opt_routers.py @@ -0,0 +1,1712 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OPT model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + +import os +import re +from collections import OrderedDict + +''' +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +''' + +from .modeling_outputs import ExtendedBaseModelOutputWithPast as BaseModelOutputWithPast +from .modeling_outputs import ExtendedCausalLMOutputWithPast as CausalLMOutputWithPast +from .modeling_outputs import ExtendedQuestionAnsweringModelOutput as QuestionAnsweringModelOutput +from .modeling_outputs import ExtendedSequenceClassifierOutputWithPast as SequenceClassifierOutputWithPast + +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_opt import OPTConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from HybridTensor.modules.SelectiveMLP import MLPRouter +from HybridTensor.modules.SelectiveMHA import MHARouter + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" +_SEQ_CLASS_EXPECTED_LOSS = 1.71 +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: OPTConfig, + is_decoder: bool = False, + sp_threshold: float = 0.0, # sparisity threshold is used to select top-k heads based on the output norms + **kwargs, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.dropout = config.attention_dropout + self.enable_bias = config.enable_bias + + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + + # Sparsity related attributes + self.sp_threshold = sp_threshold + self.use_sparse_head_attention = True + + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + attn_router_output: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Forward pass with head-wise sparsity based on output norms.""" + # if key_value_states are provided this layer is used as a cross-attention layer + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_head_output = attn_output + + # **Head-wise Sparsity Implementation Starts Here** + # Determine the number of heads to keep + k = int(self.sp_threshold * self.num_heads) + k = max(k, 1) # Ensure at least one head is kept + + # Compute the L2 norm for each head's output across the head_dim + # Shape: (batch_size, num_heads, seq_len) + head_norms = attn_head_output.norm(p=2, dim=-1) # Norm over head_dim + + # Permute to shape (batch_size, seq_len, num_heads) to select top k heads per token + head_norms_permuted = head_norms.permute(0, 2, 1) # Shape: (batch_size, seq_len, num_heads) + + # Get the indices of the top k heads for each (batch, seq_len) + _, indices = torch.topk(head_norms_permuted, k=k, dim=-1, largest=True, sorted=False) # Shape: (batch_size, seq_len, k) + + # Create a mask of zeros + head_mask = torch.zeros_like(head_norms_permuted) # Shape: (batch_size, seq_len, num_heads) + + # Scatter ones at the indices of the top k heads + head_mask.scatter_(-1, indices, 1.0) + + # Permute back to shape (batch_size, num_heads, seq_len) + head_mask = head_mask.permute(0, 2, 1) # Shape: (batch_size, num_heads, seq_len) + + # Expand the mask to match attn_output dimensions + # Shape: (batch_size, num_heads, seq_len, 1) + head_mask = head_mask.unsqueeze(-1).type_as(attn_head_output) + + # Apply the mask to zero out the heads not in the top k for each token + attn_output = attn_head_output * head_mask + + # **Head-wise Sparsity Implementation Ends Here** + + # Permute to (batch_size, tgt_len, num_heads, head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + + # Combine heads by reshaping + # Shape: (batch_size, tgt_len, num_heads * head_dim) + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + + # Apply the final linear projection + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + +class OptFlashAttention2(OPTAttention): + """ + OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. + The only required change would be on the forward pass where it needs to correctly call the public API of flash + attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + attn_router_output: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, seq_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_length = query_states.shape[1] + tgt_len = key_states.shape[-2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout + ) + + # need to return attn_output and the vector norm of each head + attn_head_output = attn_output + + # **Head-wise Sparsity Implementation Starts Here ** + if self.use_sparse_head_attention: + # Determine the number of heads to keep + k = int(self.sp_threshold * self.num_heads) + # k = max(k, 1) # Ensure at least one head is kept + + # norms = attn_output.norm(dim=-1) # Shape: (batch_size, seq_len, num_heads) + norms = attn_router_output # Shape: (batch_size, seq_len, num_heads) + + # Step 2: For each token, select the top k heads based on the norms + # Use torch.topk to get the indices of the top k heads + topk_values, topk_indices = norms.topk(k, dim=-1) # Shapes: (batch_size, seq_len, k) + + # Step 3: Mask out all non-selected heads + # Create a mask initialized to False (zeros) + mask = torch.zeros_like(norms, dtype=torch.bool) # Shape: (batch_size, seq_len, num_heads) + + # Use scatter_ to set True at the positions of the top k heads + mask.scatter_(-1, topk_indices, True) # Now mask has True at top k head positions + + # Expand the mask to match the dimensions of attn_output + mask = mask.unsqueeze(-1) # Shape: (batch_size, seq_len, num_heads, 1) + + # dense prefill sparse decode + prefill_len = int(seq_len * 0.8) + mask[:, :prefill_len, :, :] = True + + # Apply the mask to the attention output + # This will zero out the outputs of the non-selected heads + attn_output = attn_output * mask # Element-wise multiplication + + # **Head-wise Sparsity Implementation Ends Here ** + + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + + if not output_attentions: + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +OPT_ATTENTION_CLASSES = { + "eager": OPTAttention, + "flash_attention_2": OptFlashAttention2, +} + + +class OPTDecoderLayer(nn.Module): + def __init__(self, config: OPTConfig, attn_act_topk: float = 0.0, mlp_act: float = -1): + super().__init__() + self.embed_dim = config.hidden_size + + self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True, sp_threshold=attn_act_topk) + + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm( + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # self.sp_threshold = sp_threshold + self.attn_act_topk = attn_act_topk + self.mlp_act = mlp_act + + # --- Initialize routers (hyperparameters can be tuned as needed) --- + self.mha_router = MHARouter( + embed_dim=self.embed_dim, + low_rank_dim=self.embed_dim // 2, # low rank not used, but still need to specify + out_dim=self.self_attn.num_heads, + top_k= attn_act_topk # e.g. activate 50% of heads + ) + + self.mlp_router = MLPRouter( + embed_dim=self.embed_dim, + low_rank_dim=1024, + out_dim=config.ffn_dim, + act_th=0.5 # threshold for neuron activation + ) + + self.mlp_topk = None # use this to activate only the top k neurons in the MLP + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + output_mlp_activation: Optional[bool] = False, + output_attn_output: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # --- Router for Self Attention --- + mha_router_inputs = hidden_states + attn_router_output = self.mha_router(mha_router_inputs) + + # --- Router for MLP --- + mlp_router_inputs = hidden_states.reshape(-1, hidden_states.size(-1)) + mlp_router_output = self.mlp_router(mlp_router_inputs) + + # Self Attention + hidden_states, self_attn_weights, present_key_value, attn_head_output = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + attn_router_output = attn_router_output, + ) + + # to simulate sparse head activation, zero out the hidden states of heads whose norm is below a certain threshold + # .... + + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # need to change this to record the activation output of each head + if output_attn_output: + attn_output = attn_head_output + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + + # apply the router mask here to simulate sparse neuron activation + mlp_router_mask = mlp_router_output > self.mlp_act + hidden_states = hidden_states * mlp_router_mask + + hidden_states = self.activation_fn(hidden_states) + + # record activation output here + if output_mlp_activation: + mlp_hidden_states = hidden_states + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) # index 0 + + if output_attentions: + outputs += (self_attn_weights,) # always index 1 + + if use_cache: + outputs += (present_key_value,) # can be 1 or 2 + + if output_mlp_activation: + outputs += (mlp_hidden_states,) # can be 1, 2 or 3 + + if output_attn_output: + outputs += (attn_output,) # can be 1, 2, 3 or 4 + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + config_class = OPTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OPTDecoderLayer"] + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig, sp_thresholds = None, mlp_thresholds = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) + self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine + ) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList([OPTDecoderLayer(config, attn_act_topk = sp_thresholds[i]) for i in range(config.num_hidden_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + + # embed positions + if self._use_flash_attention_2: + # 2d mask is passed through the layers + causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = ( + torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if attention_mask is None + else attention_mask + ) + else: + # 4d mask is passed through the layers + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + causal_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + all_mlp_activations = () if output_mlp_activation else None + all_attn_outputs = () if output_attn_output else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + + # record states + ''' + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_mlp_activation: + all_mlp_activations += (layer_outputs[-1],) + + if output_attn_output: + all_attn_outputs += (layer_outputs[-1],) + + ''' + + hidden_states = layer_outputs[0] + + # Adjust the indices based on the conditions + cache_index = 2 if output_attentions else 1 + mlp_activation_index = 3 if output_attentions and use_cache else 2 if output_attentions or use_cache else 1 + attn_output_index = mlp_activation_index + 1 if output_mlp_activation else mlp_activation_index + + if output_attentions: + all_self_attns += (layer_outputs[1],) # Assuming self-attention weights are always at index 1 + + if use_cache: + next_decoder_cache += (layer_outputs[cache_index],) + + if output_mlp_activation: + all_mlp_activations += (layer_outputs[mlp_activation_index],) + + if output_attn_output: + all_attn_outputs += (layer_outputs[attn_output_index],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mlp_activations=all_mlp_activations, + attn_outputs = all_attn_outputs, + ) + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + def __init__(self, config: OPTConfig, sp_thresholds=None): + super().__init__(config) + self.decoder = OPTDecoder(config, sp_thresholds) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activations: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activations, + output_attn_output=output_attn_output, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + mlp_activations=decoder_outputs.mlp_activations, + attn_outputs = decoder_outputs.attn_outputs, + ) + + + + +class SparseOPTForCausalLM(OPTPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, sp_thresholds=None): + super().__init__(config) + self.model = OPTModel(config, sp_thresholds) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + mlp_activations=outputs.mlp_activations, + attn_outputs=outputs.attn_outputs, + ) + + + + + +@add_start_docstrings( + """ + The OPT Model transformer with a sequence classification head on top (linear layer). + + [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + OPT_START_DOCSTRING, +) +class OPTForSequenceClassification(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OPTModel(config) + self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + +@add_start_docstrings( + """ + The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + OPT_START_DOCSTRING, +) +class OPTForQuestionAnswering(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.model = OPTModel(config) + self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForQuestionAnswering + >>> import torch + + >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> # note: we are loading a OPTForQuestionAnswering from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random + >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + + >>> inputs = tokenizer(question, text, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + + >>> answer_offset = len(tokenizer(question)[0]) + + >>> predict_answer_tokens = inputs.input_ids[ + ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 + ... ] + >>> predicted = tokenizer.decode(predict_answer_tokens) + >>> predicted + ' a nice puppet' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index).to(logits.device) + end_positions = end_positions.clamp(0, ignored_index).to(logits.device) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + +def create_hf_mlp_router_state_dict(router_files_dir): + """ + Loads mlp_router state dicts from files and remaps keys to: + "model.decoder.layers.{layer_num}.mlp_router.{param_name}" + """ + pattern = re.compile(r'mlp_router_(\d+)-[\d.]+-[\d.]+-[\d.]+\.pt$') + state_dict = OrderedDict() + + try: + files = os.listdir(router_files_dir) + except FileNotFoundError: + print(f"Error: Directory '{router_files_dir}' does not exist.") + return None + + router_files = [f for f in files if pattern.match(f)] + if not router_files: + print(f"No mlp_router files found in '{router_files_dir}'.") + return None + + for file_name in sorted(router_files, key=lambda x: int(pattern.match(x).group(1))): + match = pattern.match(file_name) + if not match: + continue + layer_num = int(match.group(1)) + file_path = os.path.join(router_files_dir, file_name) + try: + router_weights = torch.load(file_path, map_location='cpu') + if not isinstance(router_weights, dict): + print(f"Warning: {file_path} does not contain a state dict. Skipping.") + continue + except Exception as e: + print(f"Error loading {file_path}: {e}") + continue + + for param_name, tensor in router_weights.items(): + key = f"model.decoder.layers.{layer_num}.mlp_router.{param_name}" + state_dict[key] = tensor + + print(f"Loaded mlp routers for {len(state_dict) // len(router_weights)} layers.") + return state_dict + + +def create_hf_mha_router_state_dict(router_files_dir): + """ + Loads attn_router state dicts from files and remaps keys to: + "model.decoder.layers.{layer_num}.mha_router.{param_name}" + """ + pattern = re.compile(r'attn_router_(\d+)-[\d.]+-[\d.]+\.pt$') + state_dict = OrderedDict() + loaded_layers = set() + + try: + files = os.listdir(router_files_dir) + except FileNotFoundError: + print(f"Error: Directory '{router_files_dir}' does not exist.") + return None + + router_files = [f for f in files if pattern.match(f)] + if not router_files: + print(f"No attn_router files found in '{router_files_dir}'.") + return None + + for file_name in sorted(router_files, key=lambda x: int(pattern.match(x).group(1))): + match = pattern.match(file_name) + if not match: + continue + layer_num = int(match.group(1)) + if layer_num in loaded_layers: + print(f"Warning: Duplicate router for layer {layer_num} in '{file_name}'. Skipping.") + continue + file_path = os.path.join(router_files_dir, file_name) + try: + router_weights = torch.load(file_path, map_location='cpu') + if not isinstance(router_weights, dict): + print(f"Warning: {file_path} does not contain a state dict. Skipping.") + continue + except Exception as e: + print(f"Error loading {file_path}: {e}") + continue + + for param_name, tensor in router_weights.items(): + key = f"model.decoder.layers.{layer_num}.mha_router.{param_name}" + state_dict[key] = tensor + + loaded_layers.add(layer_num) + + print(f"Loaded mha routers for {len(loaded_layers)} layers.") + return state_dict diff --git a/hf_models/opt/modeling_opt_routers_topk.py b/hf_models/opt/modeling_opt_routers_topk.py new file mode 100644 index 0000000000000000000000000000000000000000..5853a49fd7a7603d35dc48a8b4745f2d3e4593b2 --- /dev/null +++ b/hf_models/opt/modeling_opt_routers_topk.py @@ -0,0 +1,1706 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OPT model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + +import os +import re +from collections import OrderedDict + +''' +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +''' + +from .modeling_outputs import ExtendedBaseModelOutputWithPast as BaseModelOutputWithPast +from .modeling_outputs import ExtendedCausalLMOutputWithPast as CausalLMOutputWithPast +from .modeling_outputs import ExtendedQuestionAnsweringModelOutput as QuestionAnsweringModelOutput +from .modeling_outputs import ExtendedSequenceClassifierOutputWithPast as SequenceClassifierOutputWithPast + +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_opt import OPTConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from HybridTensor.modules.SelectiveMLP import MLPRouter +from HybridTensor.modules.SelectiveMHA import MHARouter + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" +_SEQ_CLASS_EXPECTED_LOSS = 1.71 +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: OPTConfig, + is_decoder: bool = False, + sp_threshold: float = 0.0, # sparisity threshold is used to select top-k heads based on the output norms + **kwargs, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.dropout = config.attention_dropout + self.enable_bias = config.enable_bias + + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + + # Sparsity related attributes + self.sp_threshold = sp_threshold + self.use_sparse_head_attention = True + + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + attn_router_output: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Forward pass with head-wise sparsity based on output norms.""" + # if key_value_states are provided this layer is used as a cross-attention layer + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_head_output = attn_output + + # **Head-wise Sparsity Implementation Starts Here** + # Determine the number of heads to keep + k = int(self.sp_threshold * self.num_heads) + k = max(k, 1) # Ensure at least one head is kept + + # Compute the L2 norm for each head's output across the head_dim + # Shape: (batch_size, num_heads, seq_len) + head_norms = attn_head_output.norm(p=2, dim=-1) # Norm over head_dim + + # Permute to shape (batch_size, seq_len, num_heads) to select top k heads per token + head_norms_permuted = head_norms.permute(0, 2, 1) # Shape: (batch_size, seq_len, num_heads) + + # Get the indices of the top k heads for each (batch, seq_len) + _, indices = torch.topk(head_norms_permuted, k=k, dim=-1, largest=True, sorted=False) # Shape: (batch_size, seq_len, k) + + # Create a mask of zeros + head_mask = torch.zeros_like(head_norms_permuted) # Shape: (batch_size, seq_len, num_heads) + + # Scatter ones at the indices of the top k heads + head_mask.scatter_(-1, indices, 1.0) + + # Permute back to shape (batch_size, num_heads, seq_len) + head_mask = head_mask.permute(0, 2, 1) # Shape: (batch_size, num_heads, seq_len) + + # Expand the mask to match attn_output dimensions + # Shape: (batch_size, num_heads, seq_len, 1) + head_mask = head_mask.unsqueeze(-1).type_as(attn_head_output) + + # Apply the mask to zero out the heads not in the top k for each token + attn_output = attn_head_output * head_mask + + # **Head-wise Sparsity Implementation Ends Here** + + # Permute to (batch_size, tgt_len, num_heads, head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + + # Combine heads by reshaping + # Shape: (batch_size, tgt_len, num_heads * head_dim) + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + + # Apply the final linear projection + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + +class OptFlashAttention2(OPTAttention): + """ + OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. + The only required change would be on the forward pass where it needs to correctly call the public API of flash + attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + attn_router_output: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, seq_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_length = query_states.shape[1] + tgt_len = key_states.shape[-2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout + ) + + # need to return attn_output and the vector norm of each head + attn_head_output = attn_output + + # **Head-wise Sparsity Implementation Starts Here ** + if self.use_sparse_head_attention: + # Determine the number of heads to keep + k = int(self.sp_threshold * self.num_heads) + # k = max(k, 1) # Ensure at least one head is kept + + # norms = attn_output.norm(dim=-1) # Shape: (batch_size, seq_len, num_heads) + norms = attn_router_output # Shape: (batch_size, seq_len, num_heads) + + # Step 2: For each token, select the top k heads based on the norms + # Use torch.topk to get the indices of the top k heads + topk_values, topk_indices = norms.topk(k, dim=-1) # Shapes: (batch_size, seq_len, k) + + # Step 3: Mask out all non-selected heads + # Create a mask initialized to False (zeros) + mask = torch.zeros_like(norms, dtype=torch.bool) # Shape: (batch_size, seq_len, num_heads) + + # Use scatter_ to set True at the positions of the top k heads + mask.scatter_(-1, topk_indices, True) # Now mask has True at top k head positions + + # Expand the mask to match the dimensions of attn_output + mask = mask.unsqueeze(-1) # Shape: (batch_size, seq_len, num_heads, 1) + + # Apply the mask to the attention output + # This will zero out the outputs of the non-selected heads + attn_output = attn_output * mask # Element-wise multiplication + + # **Head-wise Sparsity Implementation Ends Here ** + + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + + if not output_attentions: + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +OPT_ATTENTION_CLASSES = { + "eager": OPTAttention, + "flash_attention_2": OptFlashAttention2, +} + + +class OPTDecoderLayer(nn.Module): + def __init__(self, config: OPTConfig, attn_act_topk: float = 0.0, mlp_act: int = -1): + super().__init__() + self.embed_dim = config.hidden_size + + self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True, sp_threshold=attn_act_topk) + + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm( + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # self.sp_threshold = sp_threshold + self.attn_act_topk = attn_act_topk # top k heads to activate, float + self.mlp_act = mlp_act # top k neurons to activate in the MLP, integer + + # --- Initialize routers (hyperparameters can be tuned as needed) --- + self.mha_router = MHARouter( + embed_dim=self.embed_dim, + low_rank_dim=self.embed_dim // 2, # low rank not used, but still need to specify + out_dim=self.self_attn.num_heads, + top_k= attn_act_topk # e.g. activate 50% of heads + ) + + self.mlp_router = MLPRouter( + embed_dim=self.embed_dim, + low_rank_dim=1024, + out_dim=config.ffn_dim, + act_th=0.5 # threshold for neuron activation + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + output_mlp_activation: Optional[bool] = False, + output_attn_output: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # --- Router for Self Attention --- + mha_router_inputs = hidden_states + attn_router_output = self.mha_router(mha_router_inputs) + + # --- Router for MLP --- + mlp_router_inputs = hidden_states.reshape(-1, hidden_states.size(-1)) + mlp_router_output = self.mlp_router(mlp_router_inputs) + + # Self Attention + hidden_states, self_attn_weights, present_key_value, attn_head_output = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + attn_router_output = attn_router_output, + ) + + # to simulate sparse head activation, zero out the hidden states of heads whose norm is below a certain threshold + # .... + + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # need to change this to record the activation output of each head + if output_attn_output: + attn_output = attn_head_output + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + + # apply the router mask here to simulate sparse neuron activation + # mlp_router_mask = mlp_router_output > self.mlp_act + _, index = torch.topk(mlp_router_output, self.mlp_act, dim=1) + mlp_router_mask = torch.zeros_like(mlp_router_output) + mlp_router_mask.scatter_(1, index, 1) + hidden_states = hidden_states * mlp_router_mask + + hidden_states = self.activation_fn(hidden_states) + + # record activation output here + if output_mlp_activation: + mlp_hidden_states = hidden_states + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) # index 0 + + if output_attentions: + outputs += (self_attn_weights,) # always index 1 + + if use_cache: + outputs += (present_key_value,) # can be 1 or 2 + + if output_mlp_activation: + outputs += (mlp_hidden_states,) # can be 1, 2 or 3 + + if output_attn_output: + outputs += (attn_output,) # can be 1, 2, 3 or 4 + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + config_class = OPTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OPTDecoderLayer"] + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig, sp_thresholds = None, mlp_thresholds = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) + self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine + ) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList([OPTDecoderLayer(config, attn_act_topk = sp_thresholds[i], mlp_act= mlp_thresholds[i]) for i in range(config.num_hidden_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + + # embed positions + if self._use_flash_attention_2: + # 2d mask is passed through the layers + causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = ( + torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if attention_mask is None + else attention_mask + ) + else: + # 4d mask is passed through the layers + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + causal_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + all_mlp_activations = () if output_mlp_activation else None + all_attn_outputs = () if output_attn_output else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + + # record states + ''' + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_mlp_activation: + all_mlp_activations += (layer_outputs[-1],) + + if output_attn_output: + all_attn_outputs += (layer_outputs[-1],) + + ''' + + hidden_states = layer_outputs[0] + + # Adjust the indices based on the conditions + cache_index = 2 if output_attentions else 1 + mlp_activation_index = 3 if output_attentions and use_cache else 2 if output_attentions or use_cache else 1 + attn_output_index = mlp_activation_index + 1 if output_mlp_activation else mlp_activation_index + + if output_attentions: + all_self_attns += (layer_outputs[1],) # Assuming self-attention weights are always at index 1 + + if use_cache: + next_decoder_cache += (layer_outputs[cache_index],) + + if output_mlp_activation: + all_mlp_activations += (layer_outputs[mlp_activation_index],) + + if output_attn_output: + all_attn_outputs += (layer_outputs[attn_output_index],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mlp_activations=all_mlp_activations, + attn_outputs = all_attn_outputs, + ) + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + def __init__(self, config: OPTConfig, sp_thresholds=None, mlp_thresholds = None): + super().__init__(config) + self.decoder = OPTDecoder(config, sp_thresholds= sp_thresholds, mlp_thresholds=mlp_thresholds) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activations: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activations, + output_attn_output=output_attn_output, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + mlp_activations=decoder_outputs.mlp_activations, + attn_outputs = decoder_outputs.attn_outputs, + ) + + +class SparseOPTForCausalLM(OPTPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, sp_thresholds=None, mlp_thresholds = None): + super().__init__(config) + self.model = OPTModel(config, sp_thresholds=sp_thresholds, mlp_thresholds=mlp_thresholds) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + mlp_activations=outputs.mlp_activations, + attn_outputs=outputs.attn_outputs, + ) + + + + +@add_start_docstrings( + """ + The OPT Model transformer with a sequence classification head on top (linear layer). + + [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + OPT_START_DOCSTRING, +) +class OPTForSequenceClassification(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OPTModel(config) + self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + +@add_start_docstrings( + """ + The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + OPT_START_DOCSTRING, +) +class OPTForQuestionAnswering(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.model = OPTModel(config) + self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForQuestionAnswering + >>> import torch + + >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> # note: we are loading a OPTForQuestionAnswering from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random + >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + + >>> inputs = tokenizer(question, text, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + + >>> answer_offset = len(tokenizer(question)[0]) + + >>> predict_answer_tokens = inputs.input_ids[ + ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 + ... ] + >>> predicted = tokenizer.decode(predict_answer_tokens) + >>> predicted + ' a nice puppet' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index).to(logits.device) + end_positions = end_positions.clamp(0, ignored_index).to(logits.device) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + +def create_hf_mlp_router_state_dict(router_files_dir): + """ + Loads mlp_router state dicts from files and remaps keys to: + "model.decoder.layers.{layer_num}.mlp_router.{param_name}" + """ + pattern = re.compile(r'mlp_router_(\d+)-[\d.]+-[\d.]+-[\d.]+\.pt$') + state_dict = OrderedDict() + + try: + files = os.listdir(router_files_dir) + except FileNotFoundError: + print(f"Error: Directory '{router_files_dir}' does not exist.") + return None + + router_files = [f for f in files if pattern.match(f)] + if not router_files: + print(f"No mlp_router files found in '{router_files_dir}'.") + return None + + for file_name in sorted(router_files, key=lambda x: int(pattern.match(x).group(1))): + match = pattern.match(file_name) + if not match: + continue + layer_num = int(match.group(1)) + file_path = os.path.join(router_files_dir, file_name) + try: + router_weights = torch.load(file_path, map_location='cpu') + if not isinstance(router_weights, dict): + print(f"Warning: {file_path} does not contain a state dict. Skipping.") + continue + except Exception as e: + print(f"Error loading {file_path}: {e}") + continue + + for param_name, tensor in router_weights.items(): + key = f"model.decoder.layers.{layer_num}.mlp_router.{param_name}" + state_dict[key] = tensor + + print(f"Loaded mlp routers for {len(state_dict) // len(router_weights)} layers.") + return state_dict + + +def create_hf_mha_router_state_dict(router_files_dir): + """ + Loads attn_router state dicts from files and remaps keys to: + "model.decoder.layers.{layer_num}.mha_router.{param_name}" + """ + pattern = re.compile(r'attn_router_(\d+)-[\d.]+-[\d.]+\.pt$') + state_dict = OrderedDict() + loaded_layers = set() + + try: + files = os.listdir(router_files_dir) + except FileNotFoundError: + print(f"Error: Directory '{router_files_dir}' does not exist.") + return None + + router_files = [f for f in files if pattern.match(f)] + if not router_files: + print(f"No attn_router files found in '{router_files_dir}'.") + return None + + for file_name in sorted(router_files, key=lambda x: int(pattern.match(x).group(1))): + match = pattern.match(file_name) + if not match: + continue + layer_num = int(match.group(1)) + if layer_num in loaded_layers: + print(f"Warning: Duplicate router for layer {layer_num} in '{file_name}'. Skipping.") + continue + file_path = os.path.join(router_files_dir, file_name) + try: + router_weights = torch.load(file_path, map_location='cpu') + if not isinstance(router_weights, dict): + print(f"Warning: {file_path} does not contain a state dict. Skipping.") + continue + except Exception as e: + print(f"Error loading {file_path}: {e}") + continue + + for param_name, tensor in router_weights.items(): + key = f"model.decoder.layers.{layer_num}.mha_router.{param_name}" + state_dict[key] = tensor + + loaded_layers.add(layer_num) + + print(f"Loaded mha routers for {len(loaded_layers)} layers.") + return state_dict diff --git a/hf_models/opt/modeling_outputs.py b/hf_models/opt/modeling_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..56b954de3966e698ae6c4649327623097512a91d --- /dev/null +++ b/hf_models/opt/modeling_outputs.py @@ -0,0 +1,138 @@ +from dataclasses import dataclass +from typing import Optional, Tuple +import torch + +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) + +@dataclass +class ExtendedBaseModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + router_inputs: Optional[Tuple[torch.FloatTensor, ...]] = None + mlp_activations: Optional[Tuple[torch.FloatTensor, ...]] = None + attn_outputs: Optional[Tuple[torch.FloatTensor, ...]] = None + +@dataclass +class ExtendedCausalLMOutputWithPast(CausalLMOutputWithPast): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + router_inputs: Optional[Tuple[torch.FloatTensor, ...]] = None + mlp_activations: Optional[Tuple[torch.FloatTensor, ...]] = None + attn_outputs: Optional[Tuple[torch.FloatTensor, ...]] = None + +@dataclass +class ExtendedQuestionAnsweringModelOutput(QuestionAnsweringModelOutput): + """ + Base class for outputs of question answering models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + router_inputs: Optional[Tuple[torch.FloatTensor, ...]] = None + mlp_activations: Optional[Tuple[torch.FloatTensor, ...]] = None + attn_outputs: Optional[Tuple[torch.FloatTensor, ...]] = None + +@dataclass +class ExtendedSequenceClassifierOutputWithPast(SequenceClassifierOutputWithPast): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + router_inputs: Optional[Tuple[torch.FloatTensor, ...]] = None + mlp_activations: Optional[Tuple[torch.FloatTensor, ...]] = None + attn_outputs: Optional[Tuple[torch.FloatTensor, ...]] = None \ No newline at end of file diff --git a/hf_models/opt/modeling_sparse_opt.py b/hf_models/opt/modeling_sparse_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..2314f5bba7ca81a5dc299419724758a9d7147fbc --- /dev/null +++ b/hf_models/opt/modeling_sparse_opt.py @@ -0,0 +1,1790 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OPT model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + +''' +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +''' + +from .modeling_outputs import ExtendedBaseModelOutputWithPast as BaseModelOutputWithPast +from .modeling_outputs import ExtendedCausalLMOutputWithPast as CausalLMOutputWithPast +from .modeling_outputs import ExtendedQuestionAnsweringModelOutput as QuestionAnsweringModelOutput +from .modeling_outputs import ExtendedSequenceClassifierOutputWithPast as SequenceClassifierOutputWithPast + +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_opt import OPTConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" +_SEQ_CLASS_EXPECTED_LOSS = 1.71 +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: OPTConfig, + is_decoder: bool = False, + sp_threshold: float = 0.0, # Threshold for head-wise sparsity + **kwargs, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.dropout = config.attention_dropout + self.enable_bias = config.enable_bias + + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + + # Sparsity related attributes + self.sp_threshold = sp_threshold + if sp_threshold > 0.0: + self.use_sparse_head_attention = True + else: + self.use_sparse_head_attention = False + + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def dense_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + + # we need to return the attn_output and observe the vector norm of each head + # ... + attn_head_output = attn_output + + # apply sparse head activation here, mask the hidden states of heads whose norm is below a certain threshold + + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Forward pass with head-wise sparsity based on output norms.""" + # if key_value_states are provided this layer is used as a cross-attention layer + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_head_output = attn_output + + if self.use_sparse_head_attention or self.sp_threshold > 0.0: + # **Head-wise Sparsity Implementation Starts Here** + + # Compute the L2 norm for each head's output across the head_dim for each token + # Shape: (batch_size, num_heads, tgt_len) + head_norms = attn_output.norm(p=2, dim=-1) # L2 norm + + # Create a mask where heads with norm >= threshold are kept + # Shape: (batch_size, num_heads, tgt_len) + head_mask = head_norms >= self.sp_threshold + + # Expand the mask to match the attn_output dimensions + # Shape: (batch_size, num_heads, tgt_len, 1) + head_mask = head_mask.unsqueeze(-1).type_as(attn_output) + + # Apply the mask to zero out the heads below the threshold for each token + attn_output = attn_output * head_mask # Zero out inactive heads + + # **Optional: Store the mask for analysis or further processing** + # self.current_head_mask = head_mask # Uncomment if you want to store the mask + + # Recompute hidden_states from the masked attn_output + # Permute to (batch_size, tgt_len, num_heads, head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + + # Combine heads by reshaping + # Shape: (batch_size, tgt_len, num_heads * head_dim) + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + + + + # Apply the final linear projection + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + +class OptFlashAttention2(OPTAttention): + """ + OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. + The only required change would be on the forward pass where it needs to correctly call the public API of flash + attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def dense_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, _, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_length = query_states.shape[1] + tgt_len = key_states.shape[-2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout + ) + + # need to return attn_output and the vector norm of each head + attn_head_output = attn_output + + # apply sparse head activation here, mask the hidden states of heads whose norm is below a certain threshold + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + + if not output_attentions: + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, _, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_length = query_states.shape[1] + tgt_len = key_states.shape[-2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout + ) + + # need to return attn_output and the vector norm of each head + attn_head_output = attn_output + + # apply sparse head activation here, mask the hidden states of heads whose norm is below a certain threshold + if self.use_sparse_head_attention or self.sp_threshold > 0.0: + # **Head-wise Sparsity Implementation Starts Here** + + # Compute the L2 norm for each head's output across the head_dim for each token + # Shape: (batch_size, num_heads, tgt_len) + head_norms = attn_output.norm(p=2, dim=-1) # L2 norm + + # Create a mask where heads with norm >= threshold are kept + # Shape: (batch_size, num_heads, tgt_len) + head_mask = head_norms >= self.sp_threshold + + # Expand the mask to match the attn_output dimensions + # Shape: (batch_size, num_heads, tgt_len, 1) + head_mask = head_mask.unsqueeze(-1).type_as(attn_output) + + # Apply the mask to zero out the heads below the threshold for each token + attn_output = attn_output * head_mask # Zero out inactive heads + + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + + if not output_attentions: + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +OPT_ATTENTION_CLASSES = { + "eager": OPTAttention, + "flash_attention_2": OptFlashAttention2, +} + + +class OPTDecoderLayer(nn.Module): + def __init__(self, config: OPTConfig, sp_threshold: float = 0.0): + super().__init__() + self.embed_dim = config.hidden_size + + self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True, sp_threshold=sp_threshold) + + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm( + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + self.sp_threshold = sp_threshold + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + output_mlp_activation: Optional[bool] = False, + output_attn_output: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value, attn_head_output = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # to simulate sparse head activation, zero out the hidden states of heads whose norm is below a certain threshold + # .... + + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # need to change this to record the activation output of each head + if output_attn_output: + attn_output = attn_head_output + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + # record activation output here + if output_mlp_activation: + mlp_hidden_states = hidden_states + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) # index 0 + + if output_attentions: + outputs += (self_attn_weights,) # always index 1 + + if use_cache: + outputs += (present_key_value,) # can be 1 or 2 + + if output_mlp_activation: + outputs += (mlp_hidden_states,) # can be 1, 2 or 3 + + if output_attn_output: + outputs += (attn_output,) # can be 1, 2, 3 or 4 + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + config_class = OPTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OPTDecoderLayer"] + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig, sp_thresholds = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) + self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine + ) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList([OPTDecoderLayer(config, sp_threshold = sp_thresholds[i]) for i in range(config.num_hidden_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + + # embed positions + if self._use_flash_attention_2: + # 2d mask is passed through the layers + causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = ( + torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if attention_mask is None + else attention_mask + ) + else: + # 4d mask is passed through the layers + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + causal_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + all_mlp_activations = () if output_mlp_activation else None + all_attn_outputs = () if output_attn_output else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + + # record states + ''' + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_mlp_activation: + all_mlp_activations += (layer_outputs[-1],) + + if output_attn_output: + all_attn_outputs += (layer_outputs[-1],) + + ''' + + hidden_states = layer_outputs[0] + + # Adjust the indices based on the conditions + cache_index = 2 if output_attentions else 1 + mlp_activation_index = 3 if output_attentions and use_cache else 2 if output_attentions or use_cache else 1 + attn_output_index = mlp_activation_index + 1 if output_mlp_activation else mlp_activation_index + + if output_attentions: + all_self_attns += (layer_outputs[1],) # Assuming self-attention weights are always at index 1 + + if use_cache: + next_decoder_cache += (layer_outputs[cache_index],) + + if output_mlp_activation: + all_mlp_activations += (layer_outputs[mlp_activation_index],) + + if output_attn_output: + all_attn_outputs += (layer_outputs[attn_output_index],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mlp_activations=all_mlp_activations, + attn_outputs = all_attn_outputs, + ) + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + def __init__(self, config: OPTConfig, sp_thresholds=None): + super().__init__(config) + self.decoder = OPTDecoder(config, sp_thresholds) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activations: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activations, + output_attn_output=output_attn_output, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + mlp_activations=decoder_outputs.mlp_activations, + attn_outputs = decoder_outputs.attn_outputs, + ) + + + + +class SparseOPTForCausalLM(OPTPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, sp_thresholds=None): + super().__init__(config) + self.model = OPTModel(config, sp_thresholds) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + mlp_activations=outputs.mlp_activations, + attn_outputs=outputs.attn_outputs, + ) + + + + + +@add_start_docstrings( + """ + The OPT Model transformer with a sequence classification head on top (linear layer). + + [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + OPT_START_DOCSTRING, +) +class OPTForSequenceClassification(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OPTModel(config) + self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + +@add_start_docstrings( + """ + The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + OPT_START_DOCSTRING, +) +class OPTForQuestionAnswering(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.model = OPTModel(config) + self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForQuestionAnswering + >>> import torch + + >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> # note: we are loading a OPTForQuestionAnswering from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random + >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + + >>> inputs = tokenizer(question, text, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + + >>> answer_offset = len(tokenizer(question)[0]) + + >>> predict_answer_tokens = inputs.input_ids[ + ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 + ... ] + >>> predicted = tokenizer.decode(predict_answer_tokens) + >>> predicted + ' a nice puppet' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index).to(logits.device) + end_positions = end_positions.clamp(0, ignored_index).to(logits.device) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value diff --git a/hf_models/opt/modeling_sparse_opt_topk.py b/hf_models/opt/modeling_sparse_opt_topk.py new file mode 100644 index 0000000000000000000000000000000000000000..686759f4a53b982a14c1b6d7cb77db3235b7785b --- /dev/null +++ b/hf_models/opt/modeling_sparse_opt_topk.py @@ -0,0 +1,1579 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OPT model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + +''' +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +''' + +from .modeling_outputs import ExtendedBaseModelOutputWithPast as BaseModelOutputWithPast +from .modeling_outputs import ExtendedCausalLMOutputWithPast as CausalLMOutputWithPast +from .modeling_outputs import ExtendedQuestionAnsweringModelOutput as QuestionAnsweringModelOutput +from .modeling_outputs import ExtendedSequenceClassifierOutputWithPast as SequenceClassifierOutputWithPast + +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_opt import OPTConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" +_SEQ_CLASS_EXPECTED_LOSS = 1.71 +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: OPTConfig, + is_decoder: bool = False, + sp_threshold: float = 0.0, # sparisity threshold is used to select top-k heads based on the output norms + **kwargs, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.dropout = config.attention_dropout + self.enable_bias = config.enable_bias + + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + + # Sparsity related attributes + self.sp_threshold = sp_threshold + self.use_sparse_head_attention = True + + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Forward pass with head-wise sparsity based on output norms.""" + # if key_value_states are provided this layer is used as a cross-attention layer + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_head_output = attn_output + + # **Head-wise Sparsity Implementation Starts Here** + # Determine the number of heads to keep + k = int(self.sp_threshold * self.num_heads) + k = max(k, 1) # Ensure at least one head is kept + + # Compute the L2 norm for each head's output across the head_dim + # Shape: (batch_size, num_heads, seq_len) + head_norms = attn_head_output.norm(p=2, dim=-1) # Norm over head_dim + + # Permute to shape (batch_size, seq_len, num_heads) to select top k heads per token + head_norms_permuted = head_norms.permute(0, 2, 1) # Shape: (batch_size, seq_len, num_heads) + + # Get the indices of the top k heads for each (batch, seq_len) + _, indices = torch.topk(head_norms_permuted, k=k, dim=-1, largest=True, sorted=False) # Shape: (batch_size, seq_len, k) + + # Create a mask of zeros + head_mask = torch.zeros_like(head_norms_permuted) # Shape: (batch_size, seq_len, num_heads) + + # Scatter ones at the indices of the top k heads + head_mask.scatter_(-1, indices, 1.0) + + # Permute back to shape (batch_size, num_heads, seq_len) + head_mask = head_mask.permute(0, 2, 1) # Shape: (batch_size, num_heads, seq_len) + + # Expand the mask to match attn_output dimensions + # Shape: (batch_size, num_heads, seq_len, 1) + head_mask = head_mask.unsqueeze(-1).type_as(attn_head_output) + + # Apply the mask to zero out the heads not in the top k for each token + attn_output = attn_head_output * head_mask + + # **Head-wise Sparsity Implementation Ends Here** + + # Permute to (batch_size, tgt_len, num_heads, head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + + # Combine heads by reshaping + # Shape: (batch_size, tgt_len, num_heads * head_dim) + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + + # Apply the final linear projection + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + +class OptFlashAttention2(OPTAttention): + """ + OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. + The only required change would be on the forward pass where it needs to correctly call the public API of flash + attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, seq_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_length = query_states.shape[1] + tgt_len = key_states.shape[-2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout + ) + + # need to return attn_output and the vector norm of each head + attn_head_output = attn_output + + # **Head-wise Sparsity Implementation Starts Here ** + if self.use_sparse_head_attention: + # Determine the number of heads to keep + k = int(self.sp_threshold * self.num_heads) + # k = max(k, 1) # Ensure at least one head is kept + + norms = attn_output.norm(dim=-1) # Shape: (batch_size, seq_len, num_heads) + + # Step 2: For each token, select the top k heads based on the norms + # Use torch.topk to get the indices of the top k heads + topk_values, topk_indices = norms.topk(k, dim=-1) # Shapes: (batch_size, seq_len, k) + + # Step 3: Mask out all non-selected heads + # Create a mask initialized to False (zeros) + mask = torch.zeros_like(norms, dtype=torch.bool) # Shape: (batch_size, seq_len, num_heads) + + # Use scatter_ to set True at the positions of the top k heads + mask.scatter_(-1, topk_indices, True) # Now mask has True at top k head positions + + # Expand the mask to match the dimensions of attn_output + mask = mask.unsqueeze(-1) # Shape: (batch_size, seq_len, num_heads, 1) + + # dense prefill sparse decode + prefill_len = int(seq_len * 0.8) + mask[:, :prefill_len, :, :] = True + + # Apply the mask to the attention output + # This will zero out the outputs of the non-selected heads + attn_output = attn_output * mask # Element-wise multiplication + + # **Head-wise Sparsity Implementation Ends Here ** + + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + + if not output_attentions: + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value, attn_head_output + + + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +OPT_ATTENTION_CLASSES = { + "eager": OPTAttention, + "flash_attention_2": OptFlashAttention2, +} + + +class OPTDecoderLayer(nn.Module): + def __init__(self, config: OPTConfig, sp_threshold: float = 0.0): + super().__init__() + self.embed_dim = config.hidden_size + + self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True, sp_threshold=sp_threshold) + + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm( + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + self.sp_threshold = sp_threshold + + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + output_mlp_activation: Optional[bool] = False, + output_attn_output: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value, attn_head_output = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # to simulate sparse head activation, zero out the hidden states of heads whose norm is below a certain threshold + # .... + + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # need to change this to record the activation output of each head + if output_attn_output: + attn_output = attn_head_output + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + # record activation output here + if output_mlp_activation: + mlp_hidden_states = hidden_states + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) # index 0 + + if output_attentions: + outputs += (self_attn_weights,) # always index 1 + + if use_cache: + outputs += (present_key_value,) # can be 1 or 2 + + if output_mlp_activation: + outputs += (mlp_hidden_states,) # can be 1, 2 or 3 + + if output_attn_output: + outputs += (attn_output,) # can be 1, 2, 3 or 4 + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + config_class = OPTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OPTDecoderLayer"] + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig, sp_thresholds = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) + self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine + ) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList([OPTDecoderLayer(config, sp_threshold = sp_thresholds[i]) for i in range(config.num_hidden_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + + # embed positions + if self._use_flash_attention_2: + # 2d mask is passed through the layers + causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = ( + torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if attention_mask is None + else attention_mask + ) + else: + # 4d mask is passed through the layers + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + causal_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + all_mlp_activations = () if output_mlp_activation else None + all_attn_outputs = () if output_attn_output else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + + # record states + ''' + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_mlp_activation: + all_mlp_activations += (layer_outputs[-1],) + + if output_attn_output: + all_attn_outputs += (layer_outputs[-1],) + + ''' + + hidden_states = layer_outputs[0] + + # Adjust the indices based on the conditions + cache_index = 2 if output_attentions else 1 + mlp_activation_index = 3 if output_attentions and use_cache else 2 if output_attentions or use_cache else 1 + attn_output_index = mlp_activation_index + 1 if output_mlp_activation else mlp_activation_index + + if output_attentions: + all_self_attns += (layer_outputs[1],) # Assuming self-attention weights are always at index 1 + + if use_cache: + next_decoder_cache += (layer_outputs[cache_index],) + + if output_mlp_activation: + all_mlp_activations += (layer_outputs[mlp_activation_index],) + + if output_attn_output: + all_attn_outputs += (layer_outputs[attn_output_index],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mlp_activations=all_mlp_activations, + attn_outputs = all_attn_outputs, + ) + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + def __init__(self, config: OPTConfig, sp_thresholds=None): + super().__init__(config) + self.decoder = OPTDecoder(config, sp_thresholds) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activations: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activations, + output_attn_output=output_attn_output, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + mlp_activations=decoder_outputs.mlp_activations, + attn_outputs = decoder_outputs.attn_outputs, + ) + +class SparseOPTForCausalLM(OPTPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, sp_thresholds=None): + super().__init__(config) + self.model = OPTModel(config, sp_thresholds) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + mlp_activations=outputs.mlp_activations, + attn_outputs=outputs.attn_outputs, + ) + + + + + +@add_start_docstrings( + """ + The OPT Model transformer with a sequence classification head on top (linear layer). + + [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + OPT_START_DOCSTRING, +) +class OPTForSequenceClassification(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OPTModel(config) + self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + +@add_start_docstrings( + """ + The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + OPT_START_DOCSTRING, +) +class OPTForQuestionAnswering(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.model = OPTModel(config) + self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_mlp_activation: Optional[bool] = None, + output_attn_output: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForQuestionAnswering + >>> import torch + + >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> # note: we are loading a OPTForQuestionAnswering from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random + >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + + >>> inputs = tokenizer(question, text, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + + >>> answer_offset = len(tokenizer(question)[0]) + + >>> predict_answer_tokens = inputs.input_ids[ + ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 + ... ] + >>> predicted = tokenizer.decode(predict_answer_tokens) + >>> predicted + ' a nice puppet' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_mlp_activation=output_mlp_activation, + output_attn_output=output_attn_output, + ) + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index).to(logits.device) + end_positions = end_positions.clamp(0, ignored_index).to(logits.device) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + mlp_activations=transformer_outputs.mlp_activations, + attn_outputs=transformer_outputs.attn_outputs, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value diff --git a/hf_models/opt/modeling_tf_opt.py b/hf_models/opt/modeling_tf_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..9c5dfa4ade61078fd8c871a5c3a5fb8b307a85ea --- /dev/null +++ b/hf_models/opt/modeling_tf_opt.py @@ -0,0 +1,1094 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 OPT model.""" + +from __future__ import annotations + +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast + +# Public API +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + TFSharedEmbeddings, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_opt import OPTConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +# Causal LM output +_CAUSAL_LM_EXPECTED_OUTPUT = ( + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." +) + +LARGE_NEGATIVE = -1e8 + + +def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] + # We need triu with k = 1 but TF expects known compile-time dims for that, so we hack around it + mask = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32)) + mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0) + + if past_key_values_length > 0: + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) + + return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) + + +# Copied from transformers.models.bart.modeling_tf_bart._expand_mask +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + src_len = shape_list(mask)[1] + tgt_len = tgt_len if tgt_len is not None else src_len + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) + + return (one_cst - expanded_mask) * LARGE_NEGATIVE + + +class TFOPTLearnedPositionalEmbedding(keras.layers.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs) + + def call(self, attention_mask, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = tf.cast(attention_mask, tf.int64) + + # create positions depending on attention_mask + positions = tf.math.cumsum(attention_mask, axis=1) * attention_mask - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().call(positions + self.offset) + + +# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->OPT +class TFOPTAttention(keras.layers.Layer): + """Multi-headed attention from "Attention Is All You Need""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = keras.layers.Dropout(dropout) + self.head_dim = embed_dim // num_heads + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") + self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") + self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") + self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + key_value_states: tf.Tensor | None = None, + past_key_value: Tuple[Tuple[tf.Tensor]] | None = None, + attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Tuple[tf.Tensor, tf.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = tf.concat([past_key_value[0], key_states], axis=2) + value_states = tf.concat([past_key_value[1], value_states], axis=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {shape_list(attn_weights)}" + ), + ) + + if attention_mask is not None: + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {shape_list(attention_mask)}" + ), + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = stable_softmax(attn_weights, axis=-1) + + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_probs = self.dropout(attn_weights, training=training) + attn_output = tf.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {shape_list(attn_output)}" + ), + ) + + attn_output = tf.transpose( + tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) + ) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + + return attn_output, attn_weights, past_key_value + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build([None, None, self.embed_dim]) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build([None, None, self.embed_dim]) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build([None, None, self.embed_dim]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.embed_dim]) + + +class TFOPTDecoderLayer(keras.layers.Layer): + def __init__(self, config: OPTConfig, **kwargs): + super().__init__(**kwargs) + self.do_layer_norm_before = config.do_layer_norm_before + self.embed_dim = config.hidden_size + self.self_attn = TFOPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + name="self_attn", + is_decoder=True, + ) + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.fc1 = keras.layers.Dense(config.ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: np.ndarray | tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + training: Optional[bool] = False, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`, *optional*): mask for attention heads in a given layer of size + `(decoder_attention_heads,)` + past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). + """ + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + return (hidden_states, self_attn_weights, present_key_value) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +OPT_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Args: + config ([`OPTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class TFOPTPreTrainedModel(TFPreTrainedModel): + """ + TFOPT Pretrained Model that inheritates from transformers.TFPreTrainedModel + + Args: + config: OPTConfig + """ + + config_class = OPTConfig + base_model_prefix = "model" + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@keras_serializable +class TFOPTDecoder(keras.layers.Layer): + config_class = OPTConfig + + def __init__(self, config: OPTConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.padding_idx = config.pad_token_id + self.layerdrop = config.layerdrop + num_embeddings = config.max_position_embeddings + self.embed_tokens = TFSharedEmbeddings( + config.vocab_size, config.word_embed_proj_dim, config.pad_token_id, name="embed_tokens" + ) + self.embed_positions = TFOPTLearnedPositionalEmbedding( + num_embeddings, + config.hidden_size, + name="embed_positions", + ) + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + else: + self.final_layer_norm = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False) + self.project_in = keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False) + + else: + self.project_in = None + self.project_out = None + + self.layers = [TFOPTDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)] + self.dropout = keras.layers.Dropout(config.dropout) + + def get_embed_tokens(self): + return self.embed_tokens + + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens.vocab_size = new_embeddings.shape[0] + self.embed_tokens.weight = new_embeddings + + def get_input_embeddings(self): + return self.embed_tokens + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length): + # create causal mask + # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + _, seq_length = input_shape + tf.debugging.assert_equal( + seq_length + past_key_values_length, + shape_list(attention_mask)[1], + message="Attention mask shape should be (batch_size, seq_length + past_key_values_length)" + f" but is {shape_list(attention_mask)[1]} with input_ids shape {input_shape} and past length" + f" {past_key_values_length}.", + ) + + expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1]) + if seq_length > 1: + combined_attention_mask = ( + _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + expanded_attn_mask + ) + else: + combined_attention_mask = expanded_attn_mask + + return combined_attention_mask + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]: + r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.embed_tokens.vocab_size) + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is None: + attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.bool) + else: + tf.debugging.assert_equal( + shape_list(attention_mask)[1], + past_key_values_length + input_shape[1], + message=( + f"The provided attention mask has length {tf.shape(attention_mask)[1]}, but its length should be " + f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)" + ), + ) + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + present_key_values = () if use_cache else None + + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired + for attn_mask_name, attn_mask in [("head_mask", head_mask)]: + if attn_mask is not None: + tf.debugging.assert_equal( + shape_list(attn_mask)[0], + len(self.layers), + message=( + f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(attn_mask)[0]}." + ), + ) + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + hidden_states, layer_self_attn, present_key_value = decoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + past_key_value=past_key_value, + ) + + if use_cache: + present_key_values += (present_key_value,) + + if output_attentions: + all_self_attns += (layer_self_attn,) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns] if v is not None + ) + + else: + return TFBaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=present_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_tokens", None) is not None: + with tf.name_scope(self.embed_tokens.name): + self.embed_tokens.build(None) + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.config.hidden_size]) + if getattr(self, "project_out", None) is not None: + with tf.name_scope(self.project_out.name): + self.project_out.build([None, None, self.config.hidden_size]) + if getattr(self, "project_in", None) is not None: + with tf.name_scope(self.project_in.name): + self.project_in.build([None, None, self.config.word_embed_proj_dim]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFOPTMainLayer(keras.layers.Layer): + config_class = OPTConfig + + def __init__(self, config: OPTConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.decoder = TFOPTDecoder(config, name="decoder") + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.decoder.set_input_embeddings(new_embeddings) + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + **kwargs, + ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.decoder( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return outputs + + return TFBaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings( + "The bare TF OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +@keras_serializable +class TFOPTModel(TFOPTPreTrainedModel): + config_class = OPTConfig + + def __init__(self, config: OPTConfig, **kwargs): + super().__init__(config, **kwargs) + self.config = config + self.model = TFOPTMainLayer(config, name="model") + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.model.set_input_embeddings(new_embeddings) + + @unpack_inputs + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + **kwargs, + ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return outputs + + return TFBaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFBaseModelOutputWithPast( + last_hidden_state=output.last_hidden_state, + past_key_values=pkv, + hidden_states=hs, + attentions=attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + + +@add_start_docstrings( + """ + The OPT Model transformer with a language modeling head on top. + """, + OPT_START_DOCSTRING, +) +@keras_serializable +class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss): + config_class = OPTConfig + + def __init__(self, config: OPTConfig, **kwargs): + super().__init__(config, **kwargs) + self.config = config + self.model = TFOPTMainLayer(config, name="model") + + def get_output_embeddings(self): + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs): + attention_mask = kwargs.get("attention_mask", None) + + # only last token for inputs_ids if past is defined in kwargs + if past_key_values: + inputs = tf.expand_dims(inputs[:, -1], -1) + + return { + "input_ids": inputs, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + @unpack_inputs + @replace_return_docstrings(output_type=TFCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFCausalLMOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_CAUSAL_LM_EXPECTED_OUTPUT, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + **kwargs, + ) -> Union[TFCausalLMOutputWithPast, Tuple[tf.Tensor]]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + logits = self.model.decoder.embed_tokens(outputs[0], mode="linear") + loss = None + if labels is not None: + # shift labels to the left and cut last logit token + shifted_logits = logits[:, :-1] + labels = labels[:, 1:] + loss = self.hf_compute_loss(labels, shifted_logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFCausalLMOutputWithPast( + past_key_values=pkv, + hidden_states=hs, + attentions=attns, + loss=output.loss, + logits=output.logits, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) diff --git a/model_sparse_generation.py b/model_sparse_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f0f707fb9f19722215036f89479d8938b47d66 --- /dev/null +++ b/model_sparse_generation.py @@ -0,0 +1,110 @@ +# python model_sparse_generation.py --model_index 11 --attn_ckpt_dir --attn_topk 0.5 +# python model_sparse_generation.py --model_index 5 --mlp_ckpt_dir --attn_ckpt_dir --batch_stats_dir configs/mlp_router/opt-6.7b --attn_topk 0.5 + +import torch +import argparse + +from HybridTensor.utils.utils import _get_device +from HybridTensor.utils.activations import MODELS +from HybridTensor.models.opt import SparseConfig, build_sparse_opt + +from HybridTensor.models.llama import build_sparse_llama +from HybridTensor.benchmarks.generation.gen_util import tokenize_dataset, get_random_batch +from HybridTensor.utils.activations import build_mlp_topk_lookup +from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv + +from datasets import load_dataset + +from transformers.models.opt import OPTConfig +from transformers import AutoTokenizer +from flash_attn.models.opt import opt_config_to_gpt2_config +from flash_attn.utils.generation import update_graph_cache + +def update_router_config(model, num_layers, mlp_topk_lookup, attn_topk): + for i in range(num_layers): + if mlp_topk_lookup is not None: + model.transformer.layers[i].mlp_topk = mlp_topk_lookup[i] + # model.transformer.layers[i].mlp_topk = 512 + model.transformer.layers[i].mha_router.topk = attn_topk + + # dense attention in layer 0 + model.transformer.layers[0].mha_router.topk = 1.0 + +def arg_parser(): + parser = argparse.ArgumentParser(description='Inference benchmarking') + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--model_index', type=int, default=5) + parser.add_argument('--print_results', type=bool, default=True) + parser.add_argument('--iterations', type=int, default=1) + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model') + parser.add_argument('--mlp_ckpt_dir', type=str, default='') + parser.add_argument('--attn_ckpt_dir', type=str, default='') + parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router/opt-6.7b') + parser.add_argument('--delta', type=int, default=256, help='Delta value for MLP topk calculation') + parser.add_argument('--use_cuda_graph', type=bool, default=False, help='Use CUDA graph for inference') + + return parser.parse_args() + +if __name__ == "__main__": + args = arg_parser() + model_name = MODELS[args.model_index-1] + print(f"Model name: {model_name}") + dtype = torch.float16 + device= _get_device(args.gpu) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + if "llama" in model_name.lower(): + tokenizer.pad_token = tokenizer.unk_token + else: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.pad_token_id + + if "llama" in model_name: + model = build_sparse_llama(args, model_name, + args.attn_ckpt_dir, + device = device, dtype=dtype) + update_router_config(model, model.config.n_layer, None, args.attn_topk) # this sets the router config for all layers using a single config + + else: + mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, args.batch_size) + model = build_sparse_opt(args, model_name, + args.mlp_ckpt_dir, + args.attn_ckpt_dir, + device = device, dtype=dtype) + update_router_config(model, model.config.n_layer, mlp_topk_lookup, args.attn_topk) # this sets the router config for all layers using a single config + + + model.eval() + print(model) + + # test input + input_texts = ["The future of AI is", "In a distant galaxy,"] + # input_ids = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=False).input_ids.to(device) + encoding = tokenizer( + input_texts, + return_tensors="pt", + padding=True, + truncation=False # Or True with a specified max_length if inputs can be long + ) + input_ids = encoding.input_ids.to(device) + attention_mask = encoding.attention_mask.to(device) # Crucial: get the attention mask + position_ids = None + eos_token_id = tokenizer.eos_token_id + return_dict_in_generate=True, + max_length = 50 + + # Generate output + with torch.no_grad(): + out = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + cg=False, + ) + + print(tokenizer.batch_decode(out.sequences.tolist())) diff --git a/run_sparse_attn.py b/run_sparse_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a5b446c67bdd4c58a8fd7ae774a619157c5bd9 --- /dev/null +++ b/run_sparse_attn.py @@ -0,0 +1,67 @@ +# python run_sparse_attn.py --in_features 8192 --batch_size 16 --seq_len 1920 --attn_topk 0.5 + +import math +import torch +from HybridTensor.modules.SelectiveMHA import SMHA, _update_kv_cache +from HybridTensor.utils.utils import arg_parser, generate_random_BH_index +from HybridTensor.utils.profiling import cuda_profiler +from HybridTensor.utils.generation import InferenceParams + +if __name__ == "__main__": + args = arg_parser() + + max_seqlen = args.seq_len + 128 + max_batch_size = args.batch_size + device = torch.device(f"cuda:{args.device}") + + # simulates SelectiveMHA inference generation stage + inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size) + nheads = args.in_features // 128 + softmax_scale = 1 / (128 ** 0.5) + rotary_emb_dim = 0 + + mha = SMHA( + embed_dim=args.in_features, + num_heads=nheads, + num_heads_kv=None, + causal=True, + layer_idx=0, + use_flash_attn=True, + softmax_scale=softmax_scale, + return_residual=False, + rotary_emb_dim=rotary_emb_dim, + device=device, + dtype=torch.float16, + ) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + with torch.no_grad(): + # prefill stage to generate kv cache for all batches + og_x = torch.randn(args.batch_size, args.seq_len, args.in_features, device=device, dtype=torch.float16, requires_grad=False) + + # simulate kv cache, bug in flash_attn for larger batches + kv = torch.randn(args.batch_size, args.seq_len, 2, nheads, 128, device=device, dtype=torch.float16, requires_grad=False) + _ = _update_kv_cache(kv, inference_params, 0) + + # increment the sequence length to move to the generation stage + inference_params.seqlen_offset += args.seq_len + + input_x = torch.randn(args.batch_size, 1, args.in_features, device=device, dtype=torch.float16, requires_grad=False) + selected_heads = math.ceil(nheads * args.head_density) + + # generate batch_head_idx for SelectiveMHA + # batch_head_index = generate_BH_index(args.batch_size, nheads, selected_heads, device=device) + batch_head_index = generate_random_BH_index(args.batch_size, nheads, selected_heads, device=device) + + # generatation stage Standard MHA, batch_head_idx=None uses dense attention + out, standard_time_ms = cuda_profiler(mha, input_x, inference_params=inference_params, batch_head_idx=None) + print(f"Standard MHA time: {standard_time_ms:.3f} ms") + + # generatation stage SelectiveMHA, pass batch_head_idx to use selective head attention + out, select_time_ms = cuda_profiler(mha, input_x, inference_params=inference_params, batch_head_idx=batch_head_index) + print(f"SelectMHA time: {select_time_ms:.3f} ms") + + speedup = standard_time_ms / select_time_ms + print(f"Speedup: {speedup:.3f}") + diff --git a/run_sparse_mlp.py b/run_sparse_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..fcaaaffaf9f3ade23158e148c553d7a2d05577d7 --- /dev/null +++ b/run_sparse_mlp.py @@ -0,0 +1,72 @@ +# python run_sparse_mlp.py --in_features 8192 --batch_size 16 --index_size 8192 + + +import torch +from HybridTensor.modules.SelectiveMLP import SelectiveMLP +from HybridTensor.utils.profiling import cuda_profiler +from HybridTensor.utils.utils import arg_parser, sparse_index +import torch.nn as nn +import torch.nn.functional as F + +# standard MLP implementation +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation=F.gelu, + bias1=True, + bias2=True, + return_residual=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features if out_features is not None else in_features + hidden_features = hidden_features if hidden_features is not None else in_features * 4 + self.return_residual = return_residual + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) + self.activation = activation + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) + + def forward(self, x): + y = self.fc1(x) + y = self.activation(y) + y = self.fc2(y) + return y if not self.return_residual else (y, x) + +if __name__ == "__main__": + args = arg_parser() + + bias = True if args.bias > 0 else False + + in_features = args.in_features + hidden_features = in_features * 4 + out_features = in_features + activation="relu" + device = torch.device("cuda") + + sparse_mlp = SelectiveMLP( + in_features, hidden_features, out_features, activation="relu", use_heuristic=False, bias1=bias, bias2=bias, device=device, dtype=torch.float16 + ) + + activation_fn = F.relu if activation == "relu" else F.gelu + dense_mlp = Mlp(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias1=bias, bias2=bias, activation=activation_fn, device=device, dtype=torch.float16) + + + # Create random input tensor + x = torch.randn(args.batch_size, args.in_features, device="cuda", dtype=torch.float16) + index_vec, _ = sparse_index(args.index_size, args.in_features*4) + + # dense mlp time + dense_mlp_out, dense_mlp_time = cuda_profiler(dense_mlp, x) + + # sparse mlp time + sparse_mlp_out, sparse_mlp_time = cuda_profiler(sparse_mlp, x, index_vec) + print(f"Dense MLP time: {dense_mlp_time:.4f} ms") + print(f"Sparse MLP time: {sparse_mlp_time:.4f} ms") + print(f"Speedup: {dense_mlp_time/sparse_mlp_time:.2f}x") + + \ No newline at end of file diff --git a/run_sparse_transformer_block.py b/run_sparse_transformer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..fed87cdf5c75ba13e67a4653ec3c3c520f93d0ef --- /dev/null +++ b/run_sparse_transformer_block.py @@ -0,0 +1,239 @@ +# python run_sparse_transformer_block.py --in_features 8192 --batch_size 32 --seq_len 1920 --index_size 8192 --attn_topk 0.5 + +import torch +import time +from HybridTensor.utils.utils import arg_parser, generate_random_BH_index +from HybridTensor.utils.profiling import cuda_profiler +from HybridTensor.utils.generation import InferenceParams +from HybridTensor.utils.utils import sparse_index +from HybridTensor.utils.utils import _get_device +from HybridTensor.models.create_sparse_model import create_block + +class Config: + def __init__(self, in_features=8192): + self.hidden_size = in_features + self.num_attention_heads = in_features // 128 + self.head_dim = self.hidden_size // self.num_attention_heads + self.scale_attn_weights = True + self.mup_scale_qk_dot_by_d = False + self.mup_attn_multiplier = 1.0 + self.scale_attn_by_inverse_layer_idx = False + self.attn_dwconv = False + self.qkv_proj_bias = True + self.out_proj_bias = True + self.rotary_emb_fraction = 0.0 + self.rotary_emb_base = 10000.0 + self.rotary_emb_scale_base = None + self.rotary_emb_interleaved = False + self.use_alibi = False + self.window_size = (-1, -1) + self.use_flash_attn = True + self.fused_bias_fc = True + self.mlp_sparse = True + self.att_sparse = True + self.attn_pdrop = 0.1 + self.n_inner = None # Can be overridden + self.activation_function = "relu" + self.fused_mlp = True + self.mlp_checkpoint_lvl = 0 + self.sequence_parallel = False + self.layer_norm_epsilon = 1e-5 + self.residual_in_fp32 = False + self.fused_dropout_add_ln = True + self.resid_pdrop = 0.1 + self.embd_pdrop = 0.1 + self.prenorm = True + self.parallel_block = False + +class SparseConfig: + def __init__(self): + self.mlp_low_rank_dim = 1024 + self.attn_low_rank_dim = 128 # not used + self.attn_topk = 0.5 + +if __name__ =="__main__": + # Instantiate sample configs + args = arg_parser() + + config = Config() + sp_config = SparseConfig() + sp_config.attn_topk = args.attn_topk + + config.hidden_size = args.in_features + config.num_attention_heads = args.in_features // 128 + config.use_heuristic = False # use pre-compiled heuristic or complie new one during runtime + + # Example device and dtype + device = _get_device(args.device) + dtype = torch.float16 + + # Test create_block + sparse_block = create_block(config, sp_config, layer_idx=0, process_group=None, device=device, dtype=dtype) + sparse_block.eval() + sparse_block.mlp_topk = args.index_size + sparse_block.mlp.use_heuristic = False + + regular_config = config + regular_config.att_sparse = False + regular_config.mlp_sparse = False + regular_block = create_block(regular_config, None, layer_idx=0, process_group=None, device=device, dtype=dtype) + regular_block.eval() + + # inference simulation with select block + max_seqlen = args.seq_len + 128 + max_batch_size = args.batch_size + in_features = args.in_features + head_dim = 128 + batch_size = args.batch_size + seq_len = args.seq_len + index_size = args.index_size + + inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size) + process_group = None + sequence_parallel = False + + # for testing and debugging + heads = config.num_attention_heads + selected_heads = heads // 2 + + # Create a static index vector (length equals total columns in B). + total_neurons = args.in_features * 4 + test_index_vec = torch.empty((total_neurons,), device='cuda', dtype=torch.int32) + active_indices = sparse_index(args.index_size, total_neurons)[0] + test_index_vec[:args.index_size] = active_indices + if args.index_size < total_neurons: + test_index_vec[args.index_size:] = 0 # Fill the rest with dummy values. + + # test_index_vec = sparse_index(args.in_features, args.in_features*4)[0].cuda() + test_bh_idx = generate_random_BH_index(args.batch_size, heads, selected_heads) + test_index_size = args.index_size + + mixer_kwargs = ( + {"seqlen": seq_len} + if process_group is not None and sequence_parallel + else {} + ) + if inference_params is not None: + mixer_kwargs["inference_params"] = inference_params + + with torch.no_grad(): + # prefill stage + original_seq = torch.randn(batch_size, seq_len, in_features, device='cuda', dtype=torch.float16) + + # Test prefill + # output_sparse = sparse_block(original_seq, mixer_kwargs=mixer_kwargs) + # output_regular = regular_block(original_seq, mixer_kwargs=mixer_kwargs) + + # simulate the kv cache + kv = torch.rand(batch_size, seq_len, 2, heads, head_dim, device='cuda', dtype=torch.float16) + + # need to update inference_params to reflect the new sequence length + sparse_block.mixer._update_kv_cache(kv, inference_params) + regular_block.mixer._update_kv_cache(kv, inference_params) + mixer_kwargs["inference_params"].seqlen_offset = seq_len + + # Decode stage + input_x = torch.randn(batch_size, 1, in_features, device='cuda', dtype=torch.float16) + + out_decode_sparse = sparse_block(input_x, mixer_kwargs=mixer_kwargs) + + mixer_kwargs["inference_params"].seqlen_offset = seq_len + + out_decode_regular = regular_block(input_x, mixer_kwargs=mixer_kwargs) + + # mesure decode stage time in ms + print("Without CUDA Graphs") + out_decode_regular, regular_time = cuda_profiler(regular_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2) + print(f"Regular time: {regular_time} ms") + + out_decode_sparse, sparse_time = cuda_profiler(sparse_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2) + print(f"Sparse time: {sparse_time} ms") + + speedup = regular_time / sparse_time + print(f"Speedup: {speedup}") + + # --- CUDA Graph Capture for Decode Stage --- + # Allocate static buffer for regular block (shape assumed fixed) + input_x_static = input_x.clone() + output_regular_static = torch.empty((batch_size, 1, in_features), device=device, dtype=dtype) + + # Capture regular block graph + _ = regular_block(input_x_static, mixer_kwargs=mixer_kwargs) + torch.cuda.synchronize() + graph_regular = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph_regular): + res = regular_block(input_x_static, mixer_kwargs=mixer_kwargs) + if isinstance(res, tuple): + res = res[0] + output_regular_static.copy_(res) + + # For the sparse block, run a dummy call to determine its output shape. + # Also, reset the inference parameter to ensure consistent behavior. + mixer_kwargs["inference_params"].seqlen_offset = seq_len + temp = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs) + if isinstance(temp, tuple): + temp = temp[0] + # print("Captured sparse block output shape:", temp.shape) + # Allocate static buffer matching the dummy run's shape. + output_sparse_static = torch.empty_like(temp) + # print("output_sparse_static shape:", output_sparse_static.shape) + torch.cuda.synchronize() + + mixer_kwargs["inference_params"].seqlen_offset = seq_len + graph_sparse = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph_sparse): + res = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs) + if isinstance(res, tuple): + res = res[0] + output_sparse_static.copy_(res) + + # Warmup CUDA Graph replays + for _ in range(5): + graph_regular.replay() + graph_sparse.replay() + torch.cuda.synchronize() + + # --- Measure CUDA Graph Replay Latency --- + num_replays = 10 + + start = time.time() + for _ in range(num_replays): + graph_regular.replay() + torch.cuda.synchronize() + regular_graph_time = (time.time() - start) * 1000 / num_replays + + start = time.time() + for _ in range(num_replays): + graph_sparse.replay() + torch.cuda.synchronize() + sparse_graph_time = (time.time() - start) * 1000 / num_replays + + print() + print("With CUDA Graphs") + print(f"Regular block time (CUDA Graphs): {regular_graph_time} ms") + print(f"Sparse block time (CUDA Graphs): {sparse_graph_time} ms") + print(f"Speedup (CUDA Graphs): {regular_graph_time/sparse_graph_time}") + + # Compare Outputs from Eager and CUDA Graph Versions + if args.check_results: + if isinstance(out_decode_regular, tuple): + out_decode_regular = out_decode_regular[0] + regular_match = torch.allclose(out_decode_regular, output_regular_static, rtol=1e-3, atol=1e-5) + reg_diff = (out_decode_regular - output_regular_static).abs().max() + # print both the outputs results + # print(f"out_decode_regular: {out_decode_regular}") + # print(f"output_regular_static: {output_regular_static}") + + print("\nComparison for Regular Block:") + print(f"Outputs match: {regular_match}") + print(f"Max difference: {reg_diff}") + + if isinstance(out_decode_sparse, tuple): + out_decode_sparse = out_decode_sparse[0] + sparse_match = torch.allclose(out_decode_sparse, output_sparse_static, rtol=1e-3, atol=1e-5) + spa_diff = (out_decode_sparse - output_sparse_static).abs().max() + print("\nComparison for Sparse Block:") + print(f"Outputs match: {sparse_match}") + print(f"Max difference: {spa_diff}") + +