| |
| import argparse, sys |
| from pathlib import Path |
| import torch |
| from tokenizers import ByteLevelBPETokenizer |
|
|
| |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
| from src.tinyllm.config import TinyConfig |
| from src.tinyllm.model import TinyLlamaForCausalLM |
|
|
|
|
def make_prompt(user_prompt: str, system: str) -> str:
    """Assemble the chat-template string fed to the model.

    Wraps the system and user messages in the special role markers the
    model was trained on, ending with an open assistant turn for the
    model to complete.
    """
    pieces = [
        "<|system|>\n", system, "\n<|end|>\n",
        "<|user|>\n", user_prompt, "\n<|end|>\n",
        "<|assistant|>\n",
    ]
    return "".join(pieces)
|
|
|
|
def sample_next(logits, temperature: float, top_k: int):
    """Pick the next token id from a 1-D logits tensor.

    temperature <= 0 selects greedily (argmax).  Otherwise the logits are
    temperature-scaled and sampled from the softmax distribution,
    optionally restricted to the top_k highest-scoring tokens
    (top_k <= 0 disables the filter).
    """
    scores = logits.float()
    # Greedy decoding when sampling is disabled.
    if temperature <= 0:
        return int(torch.argmax(scores))
    scores = scores / temperature
    if top_k and top_k > 0:
        # Restrict the candidate set to the k best logits, then sample
        # from the renormalized distribution over just those.
        k = min(top_k, scores.numel())
        top_vals, top_idx = torch.topk(scores, k)
        choice = torch.multinomial(torch.softmax(top_vals, dim=-1), 1)
        return int(top_idx[choice])
    return int(torch.multinomial(torch.softmax(scores, dim=-1), 1))
|
|
|
|
def main():
    """CLI entry point: load checkpoint + tokenizer, sample a reply, print it."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', default='final.pt')
    parser.add_argument('--config', default='configs/model_75m.yaml')
    parser.add_argument('--tokenizer-dir', default='tokenizer')
    parser.add_argument('--prompt', required=True)
    parser.add_argument('--system', default='You are a helpful, concise assistant.')
    parser.add_argument('--max-new-tokens', type=int, default=80)
    parser.add_argument('--temperature', type=float, default=0.6)
    parser.add_argument('--top-k', type=int, default=40)
    args = parser.parse_args()

    # Tokenizer is loaded from raw BPE files (vocab + merges).
    tok_dir = Path(args.tokenizer_dir)
    tok = ByteLevelBPETokenizer(str(tok_dir / 'vocab.json'), str(tok_dir / 'merges.txt'))

    # Build the model from the YAML config and load checkpoint weights.
    # Checkpoints may be either a bare state dict or a training dict
    # holding the weights under a 'model' key.
    cfg = TinyConfig.from_yaml(args.config)
    model = TinyLlamaForCausalLM(cfg)
    ckpt = torch.load(args.checkpoint, map_location='cpu')
    if isinstance(ckpt, dict) and 'model' in ckpt:
        state = ckpt['model']
    else:
        state = ckpt
    model.load_state_dict(state, strict=False)  # tolerate missing/extra keys
    model.eval()

    ids = tok.encode(make_prompt(args.prompt, args.system)).ids
    prompt_len = len(ids)
    end_id = tok.token_to_id('<|end|>')  # None if the token is not in the vocab

    # Autoregressive sampling loop; context is clipped to the model window.
    for _ in range(args.max_new_tokens):
        window = ids[-cfg.max_position_embeddings:]
        x = torch.tensor([window], dtype=torch.long)
        with torch.no_grad():
            logits = model(x)['logits'][0, -1]
        token = sample_next(logits, args.temperature, args.top_k)
        ids.append(token)
        if end_id is not None and token == end_id:
            break

    # Decode only the generated continuation and strip any role markers
    # the model may have emitted before we cut it off.
    text = tok.decode(ids[prompt_len:])
    for marker in ['<|end|>', '<|user|>', '<|assistant|>', '<|system|>']:
        text = text.split(marker)[0]
    print(text.strip())
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|