# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import os
import readline  # type: ignore # noqa
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union

import torch
from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)

import model as fast
import mp_utils
import sample_utils
from stats import Stats
from tokenizer import Tokenizer
@dataclass
class GenArgs:
    """Generation hyper-parameters for :class:`FastGen`.

    Declared as a dataclass so callers can override any field by
    keyword (e.g. ``GenArgs(gen_length=200)``); ``dataclass`` was
    imported but previously never applied.
    """

    # Number of decoding iterations (tokens generated per prompt).
    gen_length: int = 1000

    # If True, sample with temperature/top-p; otherwise greedy argmax.
    use_sampling: bool = True
    # Softmax temperature used when sampling.
    temperature: float = 0.6
    # Nucleus (top-p) sampling threshold.
    top_p: float = 0.9
class FastGen:
    """Batched Llama/Code Llama generation with optional CUDA graphs.

    Build an instance with :meth:`build`, then call
    :meth:`generate_all` with a batch of tokenized prompts.
    """

    # Number of eager decode iterations run before capturing the
    # CUDA graph (iteration 0 processes whole prompts and is kept
    # out of the capture).
    GRAPH_WARMUPS: int = 3
    tokenizer: Tokenizer

    @staticmethod
    def build(
        ckpt_dir: str,
        gen_args: GenArgs,
        device: Union[torch.device, str],
        tokenizer_path: Optional[str] = None,
    ) -> "FastGen":
        """
        Load a Llama or Code Llama checkpoint and return a new
        generator for this model.

        Args:
            ckpt_dir: directory containing ``*.pth`` shards (one per
                model-parallel rank) and ``params.json``.
            gen_args: generation hyper-parameters.
            device: device the model weights are materialized on.
            tokenizer_path: explicit tokenizer model path; when None,
                ``tokenizer.model`` is searched in ``ckpt_dir`` then in
                its parent directory.

        Raises:
            RuntimeError: if no tokenizer model can be found.
        """
        start_time = time.time()
        world_size = mp_utils.get_world_size()

        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files in {ckpt_dir}"
        assert world_size == len(checkpoints), (
            f"checkpoint for model parallelism {len(checkpoints)}"
            f" but world size is {world_size}"
        )
        # Each model-parallel rank loads its own shard.
        ckpt_path = checkpoints[mp_utils.get_rank()]

        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())
        model_args = fast.ModelArgs(**params)

        if tokenizer_path is None:
            # Look next to the checkpoints first, then one level up.
            tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")
            if not os.path.isfile(tokenizer_path):
                tokenizer_path = str(Path(ckpt_dir) / ".." / "tokenizer.model")
            if not os.path.isfile(tokenizer_path):
                raise RuntimeError("could not find the tokenizer model")
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model_args.vocab_size = tokenizer.n_words

        # New tensors default to bf16 on the target device so the
        # Transformer is materialized directly where it will run.
        torch.set_default_device(device)
        torch.set_default_dtype(torch.bfloat16)
        model = fast.Transformer(model_args)
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        # strict=False: shard may legitimately omit keys owned by
        # other model-parallel ranks.
        model.load_state_dict(checkpoint, strict=False)
        print(f"loaded model in {time.time() - start_time:.2f} seconds")

        return FastGen(gen_args, model_args, model, tokenizer)

    def __init__(
        self,
        args: GenArgs,
        model_args: fast.ModelArgs,
        model: fast.Transformer,
        tokenizer: Tokenizer,
    ):
        self.gen_args = args
        self.model_args = model_args
        self.model = model
        self.tokenizer = tokenizer

    def generate_all(
        self, prompts: list[list[int]], use_cuda_graphs: bool
    ) -> Tuple[Stats, list[list[int]]]:
        """Generate completions for a batch of tokenized prompts.

        Args:
            prompts: one list of token ids per batch element.
            use_cuda_graphs: capture the steady-state decode step in a
                CUDA graph after ``GRAPH_WARMUPS`` eager iterations.

        Returns:
            ``(stats, answers)`` where ``answers[i]`` is the generated
            token list for ``prompts[i]``, trimmed at the first eos.
        """
        bs = len(prompts)
        prompt_lens = [len(p) for p in prompts]
        max_prompt_length = max(prompt_lens)
        gen_length = self.gen_args.gen_length
        max_seq_length = max_prompt_length + gen_length

        # KV cache sized for the worst case: every sequence grows to
        # max_seq_length.
        cache = fast.make_cache(
            args=self.model_args,
            length=bs * max_seq_length,
        )

        bias = AttnBias.from_seqlens(
            q_seqlen=prompt_lens,
            kv_seqlen=prompt_lens,
            kv_padding=max_seq_length,
        )
        bias.q_seqinfo.to("cuda")
        bias.k_seqinfo.to("cuda")

        graph = torch.cuda.CUDAGraph()

        # Input tensors to the cuda graph; these buffers are updated
        # in place between replays.
        q_seqstart = bias.q_seqinfo.seqstart
        kv_seqlen = bias.k_seqinfo.seqlen
        tokens = torch.IntTensor(sum(prompts, [])).cuda()

        # Row niter holds the token sampled at decode step niter for
        # each batch element.
        out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int)

        stats = Stats()
        stats.phase("warmup" if use_cuda_graphs else "total")

        for niter in range(gen_length):
            if niter <= self.GRAPH_WARMUPS or not use_cuda_graphs:
                # Keep the first iteration out of the
                # warmup, it processes prompts while all
                # other iterations process sequences of 0
                # or 1 token only
                output = self.model.forward_with_attn_bias(
                    token_values=tokens,
                    attn_bias=bias,
                    cache=cache,
                )
            elif niter == self.GRAPH_WARMUPS + 1:
                recording_kwargs = {}
                if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
                    # In PyTorch 2.1+ and nightlies from late Aug 2023,
                    # we can do this to maybe avoid watchdog-related crashes
                    recording_kwargs["capture_error_mode"] = "thread_local"
                with torch.cuda.graph(graph, **recording_kwargs):
                    output = self.model.forward_with_attn_bias(
                        token_values=tokens,
                        attn_bias=bias,
                        cache=cache,
                    )
                graph.replay()
                # synchronize to get accurate timings
                torch.cuda.synchronize()
                stats.phase("graph", tokens=(niter + 1) * bs)
            else:
                graph.replay()

            # output: (sum(token_lengths), vocab_size)
            logits = output.view(bs, self.model_args.vocab_size)

            if self.gen_args.use_sampling:
                temp = self.gen_args.temperature
                top_p = self.gen_args.top_p
                probs = torch.softmax(logits / temp, dim=-1)
                next_token = sample_utils.top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(bs)

            out_tokens[niter, :] = next_token

            # Update attention bias state for decoding rounds: after
            # the prompt pass, every sequence feeds exactly one query
            # token per step.
            if niter == 0:
                q_seqstart.copy_(torch.arange(bs + 1, dtype=torch.int))
                bias.q_seqinfo.min_seqlen = 1
                bias.q_seqinfo.max_seqlen = 1
                bias.q_seqinfo.seqstart_py = q_seqstart.tolist()
                tokens = tokens[:bs]

            # Grow each sequence's kv length, saturating at the cache
            # capacity.
            kv_seqlen.add_(kv_seqlen < max_seq_length)

            tokens.copy_(next_token)

        stats.end_phase(tokens=gen_length * bs)

        def trim_answer(prompt, tokens):
            """Trim the answer to end it on an eos token."""
            # NOTE(review): out_tokens has max_seq_length rows but only
            # the first gen_length are written; for prompts shorter than
            # the longest one this slice can include trailing zeros when
            # no eos is produced — confirm whether that is intended.
            tokens = tokens[: max_seq_length - len(prompt)]
            eos_id = self.tokenizer.eos_id
            if eos_id in tokens:
                return tokens[: tokens.index(eos_id) + 1]
            else:
                return tokens

        answers = [
            trim_answer(prompt, answer)
            for prompt, answer in zip(prompts, out_tokens.t().tolist())
        ]

        return stats, answers
def get_prompts(interactive: bool) -> Iterable[list[str]]:
    """Yield batches of prompt strings.

    In non-interactive mode a single canned batch is produced; in
    interactive mode, batches are read from stdin forever (EOF exits
    the process).
    """
    if not interactive:
        yield [
            "abc",
            "can you write a hello world program in C#",
            "peux tu resoudre le probleme des tours de Hanoi en ocaml",
        ]
        return

    while True:
        try:
            line = input("enter prompt: ")
        except EOFError:
            print("exiting")
            sys.exit(0)
        yield line.split("\n")
def main(ckpt_dir: str, interactive: bool, add_instruction_tags: bool):
    """Build a generator from *ckpt_dir* and print completions.

    Model-parallel geometry is taken from ``WORLD_SIZE``/``LOCAL_RANK``
    when set (torchrun), otherwise a single process is assumed. Set
    ``NO_CUDA_GRAPHS`` in the environment to disable graph capture.
    """
    # Default to single-process execution unless launched by torchrun.
    mp_size, local_rank = 1, 0
    if "WORLD_SIZE" in os.environ:
        mp_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
    device = mp_utils.initialize(mp_size, local_rank)

    generator = FastGen.build(ckpt_dir, GenArgs(), device)

    for prompts in get_prompts(interactive):
        if add_instruction_tags:
            prompts = [f"[INST]{prompt}[/INST]" for prompt in prompts]

        encoded = [generator.tokenizer.encode(x) for x in prompts]
        stats, out_tokens = generator.generate_all(
            encoded, use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ
        )

        # Only rank 0 prints; other ranks run silently.
        if mp_utils.get_rank() != 0:
            continue

        for i, prompt in enumerate(prompts):
            print(f"> {prompt}")
            answer = generator.tokenizer.decode(out_tokens[i])
            print(answer)
            print("---------------")
        for phase_stats in stats.phases:
            print(phase_stats.show())
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser("Llama inference") | |
| parser.add_argument("ckpt_dir") | |
| parser.add_argument( | |
| "-i", "--interactive", action="store_true", help="ask for prompts" | |
| ) | |
| parser.add_argument( | |
| "--no-instruction-tags", action="store_true", help="do not add instruction tags" | |
| ) | |
| args = parser.parse_args() | |
| main( | |
| ckpt_dir=args.ckpt_dir, | |
| interactive=args.interactive, | |
| add_instruction_tags=not args.no_instruction_tags, | |
| ) | |