| |
|
| |
|
| |
|
| |
|
| |
|
| | import itertools
|
| | import sys
|
| | import time
|
| | from pathlib import Path
|
| | from typing import Optional, Tuple
|
| |
|
| | import torch
|
| | import torch._dynamo.config
|
| | import torch._inductor.config
|
| |
|
def device_sync(device):
    """Wait for outstanding work on ``device`` where that is meaningful.

    Args:
        device: Device name as a string. It is matched by substring, so
            values such as ``"cuda:0"`` are treated as CUDA devices.

    For a CUDA device this calls ``torch.cuda.synchronize``. For CPU and MPS
    device names nothing is done. Any other name only produces a printed
    notice; no exception is raised.
    """
    if "cuda" in device:
        torch.cuda.synchronize(device)
        return
    if "cpu" in device or "mps" in device:
        # Nothing to synchronize for these devices here.
        return
    print(f"device={device} is not yet suppported")
|
| |
|
| |
|
| | torch._inductor.config.coordinate_descent_tuning = True
|
| | torch._inductor.config.triton.unique_kernel_names = True
|
| | torch._inductor.config.fx_graph_cache = True
|
| |
|
| | default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| |
|
| |
|
| | wd = Path(__file__).parent.parent.resolve()
|
| | sys.path.append(str(wd))
|
| |
|
| | from model import Transformer
|
| | from tokenizer import get_tokenizer
|
| |
|
def multinomial_sample_one_no_sync(probs_sort):
    """Pick one index along the last dimension of ``probs_sort`` at random.

    Each probability is divided by an independent Exponential(1) draw and the
    position of the largest quotient is returned, so an index is chosen with
    probability proportional to its entry in ``probs_sort``.

    Args:
        probs_sort: Tensor of non-negative weights; sampling is over the last
            dimension.

    Returns:
        Tensor of dtype ``torch.int`` holding the chosen index, with the last
        dimension kept at size 1.
    """
    noise = torch.empty_like(probs_sort)
    noise.exponential_(1)
    chosen = torch.argmax(probs_sort / noise, dim=-1, keepdim=True)
    return chosen.to(dtype=torch.int)
|
| |
|
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn logits into a probability distribution over the last dimension.

    Args:
        logits: Tensor of unnormalized scores.
        temperature: Divisor applied to the logits; values below ``1e-5`` are
            clamped to ``1e-5`` to avoid dividing by zero.
        top_k: If given, only the ``top_k`` largest logits (capped at the size
            of the last dimension) keep non-zero probability.

    Returns:
        Tensor of the same shape as ``logits`` containing softmax
        probabilities.
    """
    scaled = logits / max(temperature, 1e-5)

    if top_k is not None:
        k = min(top_k, scaled.size(-1))
        top_values, _ = torch.topk(scaled, k)
        # The smallest of the k largest values is the cut-off; anything
        # strictly below it is masked out before the softmax.
        cutoff = top_values.select(-1, -1).unsqueeze(-1)
        scaled = torch.where(scaled < cutoff, -float("Inf"), scaled)

    return torch.nn.functional.softmax(scaled, dim=-1)
|
| |
|
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample one token from the logits at index ``[0, -1]``.

    Args:
        logits: Model output; only ``logits[0, -1]`` is used (first entry
            along dim 0, last entry along dim 1).
        temperature: Forwarded to ``logits_to_probs``.
        top_k: Forwarded to ``logits_to_probs``.

    Returns:
        ``(idx_next, probs)``: the sampled index and the probability
        distribution it was drawn from.
    """
    last_step_logits = logits[0, -1]
    probs = logits_to_probs(last_step_logits, temperature, top_k)
    return multinomial_sample_one_no_sync(probs), probs
|
| |
|
def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    """Run ``model`` over the tokens ``x`` and sample the token that follows.

    Args:
        model: Callable model invoked as ``model(x, input_pos)``.
        x: Input token tensor passed straight to the model.
        input_pos: Positions passed straight to the model.
        **sampling_kwargs: Forwarded to ``sample`` (``temperature``, ``top_k``).

    Returns:
        The sampled token only; the probabilities returned by ``sample`` are
        discarded.
    """
    logits = model(x, input_pos)
    next_token, _ = sample(logits, **sampling_kwargs)
    return next_token
|
| |
|
def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run one decoding step and sample the next token.

    Args:
        model: Callable model invoked as ``model(x, input_pos)``.
        x: Current token tensor passed straight to the model.
        input_pos: Position tensor; its last dimension must have size 1.
        **sampling_kwargs: Forwarded to ``sample`` (``temperature``, ``top_k``).

    Returns:
        ``(next_token, probs)`` exactly as produced by ``sample``.
    """
    # Single-position decoding only: exactly one position per call.
    assert input_pos.shape[-1] == 1
    return sample(model(x, input_pos), **sampling_kwargs)
|
| |
|
def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
    """Decode ``num_new_tokens`` tokens one at a time.

    Args:
        model: Model handed to ``decode_one_token``.
        cur_token: Token fed into the first step; later steps feed the token
            sampled by the previous step, viewed as shape ``(1, -1)``.
        input_pos: Position tensor for the first step. It is advanced with an
            in-place ``+= 1`` after every step, so the tensor object supplied
            by the caller is modified.
        num_new_tokens: Number of decoding steps to run.
        callback: Called with a copy of each newly sampled token.
        **sampling_kwargs: Forwarded to ``decode_one_token``.

    Returns:
        ``(new_tokens, new_probs)``: two lists with one cloned tensor per
        step, holding the sampled tokens and their probability distributions.
    """
    tokens_out, probs_out = [], []
    for _ in range(num_new_tokens):
        # Only the math scaled-dot-product-attention backend is enabled here.
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            sampled, sampled_probs = decode_one_token(
                model, cur_token, input_pos, **sampling_kwargs
            )
            input_pos += 1
            # Clones are stored rather than the step outputs themselves;
            # presumably the outputs may be overwritten by the next step when
            # the step function is compiled -- TODO confirm.
            token_copy = sampled.clone()
            tokens_out.append(token_copy)
            callback(token_copy)
            probs_out.append(sampled_probs.clone())
            cur_token = sampled.view(1, -1)

    return tokens_out, probs_out
|
| |
|
| |
|
def model_forward(model, x, input_pos):
    """Call ``model(x, input_pos)`` and return its result unchanged.

    Kept as a separate module-level function so that it can be rebound as a
    whole (``main`` replaces it with a ``torch.compile``-wrapped version).
    """
    output = model(x, input_pos)
    return output
|
| |
|
def speculative_decode(
    model: Transformer,
    draft_model: Transformer,
    cur_token: torch.Tensor,
    input_pos: int,
    speculate_k: int,
    **sampling_kwargs
) -> torch.Tensor:
    """Run one round of speculative decoding.

    The draft model proposes ``speculate_k`` tokens sequentially, the target
    model scores the current token plus all proposals in a single forward
    pass, and each proposal is then accepted or rejected by comparing the two
    models' probabilities for it.

    Args:
        model: Target model.
        draft_model: Model used to propose tokens.
        cur_token: Most recent token; its device is used for all new tensors.
        input_pos: Position of ``cur_token`` as a Python int.
        speculate_k: Number of tokens the draft model proposes.
        **sampling_kwargs: Forwarded to the sampling helpers.

    Returns:
        1-D tensor of new tokens. If every proposal is accepted it holds the
        ``speculate_k`` draft tokens plus one token sampled from the target
        distribution at the last position. Otherwise it holds the proposals
        before the first rejection plus one replacement token.
    """
    device = cur_token.device
    orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)

    # Draft phase. A clone is passed because decode_n_tokens advances its
    # position tensor in place.
    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
    draft_tokens = torch.cat(draft_tokens)

    # Target phase: one forward pass over the current token and all proposals.
    target_input = torch.cat([cur_token.view(1), draft_tokens]).view(1, -1)
    target_positions = torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
    target_logits = model_forward(model, target_input, target_positions)
    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
    draft_probs = torch.stack(draft_probs)

    # Probability each model assigned to the token that was actually proposed.
    step_index = torch.arange(0, speculate_k, device=device)
    draft_token_prob = draft_probs[step_index, draft_tokens]
    target_token_prob = target_probs[step_index, draft_tokens]

    # A proposal is kept with probability min(1, target / draft).
    accept_draft_prob = torch.minimum(torch.ones(()), target_token_prob / draft_token_prob)
    rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()

    if rejected_locations.shape[0] == 0:
        # Nothing rejected: take one extra token from the target distribution
        # at the final position.
        last_token = multinomial_sample_one_no_sync(target_probs[-1])
        # Run the draft model on its last proposal at the following position;
        # the output is discarded (presumably done to update the draft model's
        # cache -- TODO confirm).
        model_forward(
            draft_model,
            draft_tokens[-1].view(1, -1),
            orig_input_pos + speculate_k,
        )
        return torch.cat([draft_tokens, last_token])

    # First rejected step: resample there from the normalized positive part of
    # (target - draft).
    accept_length = rejected_locations[0].item()
    residual = target_probs[accept_length] - draft_probs[accept_length]
    residual = torch.where(residual > 0, residual, 0.0)
    residual = residual / residual.sum()
    next_token = multinomial_sample_one_no_sync(residual)
    return torch.cat([draft_tokens[:accept_length], next_token])
|
| |
|
@torch.no_grad()
def generate(
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    *,
    interactive: bool,
    draft_model: Optional[Transformer],
    speculate_k: Optional[int] = 8,
    callback = lambda x: x,
    **sampling_kwargs
) -> Tuple[torch.Tensor, dict]:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Args:
        model: Target model; its caches are (re)initialized on every call.
        prompt: Token tensor whose dim 0 is the sequence; its device and dtype
            are used for the output buffer.
        max_new_tokens: Number of tokens to generate after the prompt.
        interactive: When true, the cache length is a fixed 350 positions
            instead of being derived from the requested length.
        draft_model: Draft model for speculative decoding, or ``None`` to
            decode one token at a time with ``model`` only.
        speculate_k: Number of tokens the draft model proposes per round. It
            is also used to size ``accept_counts``, so it must be an int even
            when ``draft_model`` is ``None``.
        callback: Called with each generated token after the first one.
        **sampling_kwargs: Forwarded to the sampling helpers.

    Returns:
        ``(seq, generate_stats)``: ``seq`` has length
        ``prompt.size(0) + max_new_tokens`` and holds the prompt followed by
        the generated tokens. ``generate_stats['accept_counts'][n]`` counts the
        speculative rounds that returned ``n + 1`` tokens (all zeros when not
        speculating).
    """

    is_speculative = draft_model is not None

    T = prompt.size(0)
    T_new = T + max_new_tokens
    if interactive:
        # NOTE(review): hard-coded cache length for interactive sessions;
        # the reason for 350 is not visible here.
        max_seq_length = 350
    else:
        max_seq_length = min(T_new, model.config.block_size)

    device, dtype = prompt.device, prompt.dtype
    # A speculative round can run up to speculate_k + 1 positions past the
    # last kept token, so leave that much extra room in the caches.
    max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
        if is_speculative and draft_model is not model:
            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # Output buffer: prompt first, generated tokens written in afterwards.
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
    if is_speculative:
        # Run the draft model over the prompt too; its sampled token is unused.
        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
    seq[T] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    accept_counts = [0] * (speculate_k + 1)

    if is_speculative:
        # The speculative loop tracks the position as a Python int.
        input_pos = input_pos.item()
        while input_pos < T_new - 1:
            cur_token = next_token.view(())

            next_tokens = speculative_decode(
                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
            )

            accept_counts[len(next_tokens) - 1] += 1
            # Do not write past the end of the output buffer.
            num_added = min(T_new - input_pos - 1, len(next_tokens))
            seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added]
            for i in next_tokens[: num_added,]:
                callback(i)
            input_pos = input_pos + num_added
            next_token = next_tokens[-1]
    else:
        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
        # With max_new_tokens == 1 the prefill token is the only new token and
        # the list is empty; torch.cat rejects an empty list.
        if generated_tokens:
            seq[T + 1:] = torch.cat(generated_tokens)

    generate_stats = {
        'accept_counts': accept_counts
    }
    return seq, generate_stats
|
| |
|
def encode_tokens(tokenizer, string, bos=True, device=default_device):
    """Encode ``string`` into a 1-D ``torch.int`` tensor of token ids.

    Args:
        tokenizer: Object providing ``encode(str)`` and ``bos_id()``.
        string: Text to encode.
        bos: When true, the tokenizer's BOS id is placed before the encoded
            tokens.
        device: Device on which the resulting tensor is created.

    Returns:
        Tensor of dtype ``torch.int`` on ``device``.
    """
    token_ids = tokenizer.encode(string)
    if bos:
        token_ids = [tokenizer.bos_id()] + token_ids
    return torch.tensor(token_ids, dtype=torch.int, device=device)
|
| |
|
def _load_model(checkpoint_path, device, precision, use_tp):
    """Build a ``Transformer`` and load its weights from ``checkpoint_path``.

    The model configuration is chosen from the name of the checkpoint's parent
    directory. Quantized runtime layers are swapped in when the path contains
    ``"int8"`` or ``"int4"``.

    Args:
        checkpoint_path: ``Path`` to the checkpoint file.
        device: Device name (string) the model is moved to.
        precision: dtype the model is cast to.
        use_tp: When true, tensor parallelism is applied via ``tp.apply_tp``.

    Returns:
        The loaded model in eval mode.
    """
    use_cuda = 'cuda' in device  # NOTE(review): currently unused in this function
    path_str = str(checkpoint_path)

    # Construct on the meta device; the weights are attached later by
    # load_state_dict(..., assign=True).
    with torch.device('meta'):
        model = Transformer.from_name(checkpoint_path.parent.name)

    if "int8" in path_str:
        print("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler
        int8_handler = WeightOnlyInt8QuantHandler(model)
        model = int8_handler.convert_for_runtime()

    if "int4" in path_str:
        print("Using int4 weight-only quantization!")
        # The group size is read from the second-to-last dot-separated part of
        # the file name, skipping that part's first character.
        name_parts = checkpoint_path.name.split(".")
        groupsize = int(name_parts[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler
        int4_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        model = int4_handler.convert_for_runtime()

    state = torch.load(path_str, mmap=True, weights_only=True)
    # Some checkpoints ("stories" in the path) nest the weights under "model".
    if "model" in state and "stories" in path_str:
        state = state["model"]
    model.load_state_dict(state, assign=True)

    if use_tp:
        from tp import apply_tp
        print("Applying tensor parallel to model ...")
        apply_tp(model)

    model = model.to(device=device, dtype=precision)
    return model.eval()
|
| |
|
| | def _get_model_size(model):
|
| | model_size = 0
|
| | for name, child in model.named_children():
|
| | if not isinstance(child, torch.nn.Embedding):
|
| | model_size += sum(
|
| | [
|
| | p.numel() * p.dtype.itemsize
|
| | for p in itertools.chain(child.parameters(), child.buffers())
|
| | ]
|
| | )
|
| | return model_size
|
| |
|
# Markers placed before and after the user's prompt when a chat checkpoint is
# used in interactive mode (see main).
B_INST = "[INST]"
E_INST = "[/INST]"
|
| |
|
def main(
    prompt: str = "Hello, my name is",
    interactive: bool = False,
    num_samples: int = 5,
    max_new_tokens: int = 100,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
    compile: bool = True,
    compile_prefill: bool = False,
    profile: Optional[Path] = None,
    draft_checkpoint_path: Optional[Path] = None,
    speculate_k: int = 5,
    device=default_device,
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.

    Args:
        prompt: Text to continue. In interactive mode it is replaced by user
            input before each sample.
        interactive: Read a prompt from stdin per sample and stream the output.
        num_samples: Number of timed generations to run.
        max_new_tokens: Tokens to generate per sample.
        top_k: Top-k cut-off used for sampling.
        temperature: Sampling temperature.
        checkpoint_path: Model checkpoint; ``tokenizer.model`` must sit next
            to it.
        compile: Wrap the decode step (and the speculative target forward)
            with ``torch.compile``; adds one untimed warm-up generation.
        compile_prefill: Also compile ``prefill`` (only when ``compile`` is
            set).
        profile: If set, the last sample is profiled and a Chrome trace is
            written with this path as the file-name prefix.
        draft_checkpoint_path: Draft model checkpoint; enables speculative
            decoding when given.
        speculate_k: Draft tokens proposed per speculative round.
        device: Device name to run on.
    """
    # Imported once here rather than on every loop iteration.
    import contextlib

    # These module-level names may be rebound below: `print` is silenced on
    # non-zero ranks, the others are replaced by compiled versions.
    global print, model_forward, decode_one_token, prefill

    assert checkpoint_path.is_file(), checkpoint_path

    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
    assert tokenizer_path.is_file(), str(tokenizer_path)

    from tp import maybe_init_dist
    rank = maybe_init_dist()
    use_tp = rank is not None
    if use_tp:
        if rank != 0:
            # Only rank 0 prints.
            print = lambda *args, **kwargs: None

    print(f"Using device={device}")
    precision = torch.bfloat16
    is_speculative = draft_checkpoint_path is not None
    is_chat = "chat" in str(checkpoint_path)

    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision, use_tp)

    if is_speculative:
        draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
    else:
        draft_model = None

    device_sync(device=device)
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
    prompt_length = encoded.size(0)

    torch.manual_seed(1234)
    model_size = _get_model_size(model)
    if compile:
        if is_speculative and use_tp:
            torch._inductor.config.triton.cudagraph_trees = False

        if is_speculative:
            model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)

        decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

        if compile_prefill:
            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

    aggregate_metrics = {
        'tokens_per_sec': [],
        'accept_counts': [],
    }
    # Iteration -1 is an untimed warm-up that triggers compilation.
    start = -1 if compile else 0

    for i in range(start, num_samples):
        device_sync(device=device)
        if i >= 0 and interactive:
            prompt = input("What is your prompt? ")
            if is_chat:
                prompt = f"{B_INST} {prompt.strip()} {E_INST}"
            encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
            # Keep the length in step with the new prompt so the tokens/sec
            # figure below counts only generated tokens.
            prompt_length = encoded.size(0)

        if interactive and i >= 0:
            buffer = []
            period_id = tokenizer.encode('.')[0]
            done_generating = False
            def callback(x):
                nonlocal done_generating
                if done_generating:
                    return
                # Decode with a leading '.' token and strip it again. The
                # token is flattened first because the speculative path passes
                # 0-dim tensors, for which tolist() returns a bare int.
                buffer.append(tokenizer.decode([period_id] + x.reshape(-1).tolist())[1:])
                if x.item() == tokenizer.eos_id():
                    done_generating = True
                # Flush every four tokens, and once more at end of sequence.
                if len(buffer) == 4 or done_generating:
                    print(''.join(buffer), end='', flush=True)
                    buffer.clear()

        else:
            callback = lambda x : x
        t0 = time.perf_counter()
        # Profile only the last sample, and only on rank 0 under tensor
        # parallelism.
        if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
            prof = contextlib.nullcontext()
        else:
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()
        with prof:
            y, metrics = generate(
                model,
                encoded,
                max_new_tokens,
                draft_model=draft_model,
                speculate_k=speculate_k,
                interactive=interactive,
                callback=callback,
                temperature=temperature,
                top_k=top_k,
            )
            aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
        if i == -1:
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue
        if hasattr(prof, "export_chrome_trace"):
            if use_tp:
                prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
            else:
                prof.export_chrome_trace(f"{profile}.json")
        device_sync(device=device)
        t = time.perf_counter() - t0

        if not interactive:
            print(tokenizer.decode(y.tolist()))
        else:
            print()
        tokens_generated = y.size(0) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
    print("==========")
    if is_speculative:
        counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
        total_rounds = sum(counts_aggregated)
        # No speculative round may have run (e.g. max_new_tokens == 1); skip
        # the statistics instead of dividing by zero.
        if total_rounds > 0:
            acceptance_probs = [count / total_rounds for count in counts_aggregated]
            print(f"Acceptance probs: {acceptance_probs}")
            print(f"Mean Accepted: {sum(idx * count for idx, count in enumerate(counts_aggregated)) / total_rounds}")

    print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
|
| |
|
if __name__ == '__main__':
    # Command-line entry point: every option maps one-to-one onto a parameter
    # of main().
    import argparse

    parser = argparse.ArgumentParser(description='Your CLI description.')

    # Prompting and sampling
    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
    parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
    parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
    parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
    parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')

    # Model, compilation and profiling
    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
    parser.add_argument('--profile', type=Path, default=None, help='Profile path.')

    # Speculative decoding and device
    parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
    parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
    parser.add_argument('--device', type=str, default=default_device, help='Device to use')

    args = parser.parse_args()
    main(
        prompt=args.prompt,
        interactive=args.interactive,
        num_samples=args.num_samples,
        max_new_tokens=args.max_new_tokens,
        top_k=args.top_k,
        temperature=args.temperature,
        checkpoint_path=args.checkpoint_path,
        compile=args.compile,
        compile_prefill=args.compile_prefill,
        profile=args.profile,
        draft_checkpoint_path=args.draft_checkpoint_path,
        speculate_k=args.speculate_k,
        device=args.device,
    )
|
| |
|