| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from searshorai.model import GPT, GPTConfig |
| from searshorai.tokenizer import TextTokenizer |
|
|
|
|
| |
| |
# Default instruction prompt. ``main`` fills ``{passage}`` via ``str.format``
# with the stripped ``--text``; the prompt ends right after "Summary:\n" so the
# model's continuation is the answer.
DEFAULT_PROMPT_TEMPLATE = (
    "Read the article and write a one-sentence summary."
    "\n\n"
    "Article:\n"
    "{passage}"
    "\n\n"
    "Summary:\n"
)
|
|
|
|
def strip_compile_prefix(state_dict):
    """
    Return a new mapping with a leading ``_orig_mod.`` removed from each key.

    Keys without that prefix are kept unchanged, and insertion order is
    preserved. The prefix is presumably the one ``torch.compile`` adds to saved
    parameter names; removing it lets the weights load into an uncompiled model.
    """
    marker = "_orig_mod."
    return {
        (name[len(marker):] if name.startswith(marker) else name): tensor
        for name, tensor in state_dict.items()
    }
|
|
|
|
def parse_args() -> argparse.Namespace:
    """
    Define and parse this script's command-line options from ``sys.argv``.

    Options cover the checkpoint and tokenizer paths, the passage and prompt
    template, the decoding controls forwarded to ``generate``, the device, and
    the RNG seed.
    """
    cli = argparse.ArgumentParser(description="Ask the paragraph-explainer model.")

    # Inputs: weights, tokenizer, passage and prompt template.
    cli.add_argument("--checkpoint", required=True, type=Path)
    cli.add_argument(
        "--tokenizer",
        default=Path("data/wikitext103/tokenizer.json"),
        type=Path,
    )
    cli.add_argument("--text", required=True, type=str, help="The passage to explain.")
    cli.add_argument("--prompt_template", default=DEFAULT_PROMPT_TEMPLATE, type=str)

    # Decoding controls as (flag, type, default, help) rows; help=None is
    # argparse's own default, so rows without help text add none.
    decoding_flags = (
        ("--max_new_tokens", int, 120, None),
        ("--temperature", float, 0.7, None),
        ("--top_k", int, 40, None),
        ("--top_p", float, 0.9, "Nucleus sampling cutoff. Set 1.0 to disable."),
        (
            "--repetition_penalty",
            float,
            1.3,
            "Penalty for re-emitting tokens already in the context. 1.0 = off.",
        ),
        (
            "--no_repeat_ngram_size",
            int,
            3,
            "Block any n-gram of this size from appearing twice. 0 = off.",
        ),
    )
    for flag, flag_type, default, help_text in decoding_flags:
        cli.add_argument(flag, type=flag_type, default=default, help=help_text)

    # Runtime placement and reproducibility.
    cli.add_argument("--device", choices=["auto", "cuda", "cpu"], default="auto", type=str)
    cli.add_argument("--seed", default=0, type=int)

    namespace = cli.parse_args()
    return namespace
|
|
|
|
def banned_tokens_from_ngrams(generated: list[int], n: int) -> set[int]:
    """
    For no-repeat-ngram blocking: return the token ids that would complete an
    n-gram already present in ``generated`` if emitted next.

    The last ``n - 1`` tokens form the current prefix. Every earlier position
    where that prefix occurs contributes the token that followed it.

    - ``n <= 0`` disables blocking and returns an empty set.
    - For ``n == 1`` the prefix is empty, so every token generated so far is
      banned.

    Args:
        generated: Token ids produced so far.
        n: N-gram size.

    Returns:
        The set of banned next-token ids (possibly empty).
    """
    if n <= 0 or len(generated) < n - 1:
        return set()
    prefix_len = n - 1
    # Slice from an explicit start index. ``generated[-0:]`` is the whole list,
    # which previously made the n == 1 prefix non-empty and banned nothing.
    prefix = tuple(generated[len(generated) - prefix_len:])
    return {
        generated[i + prefix_len]
        for i in range(len(generated) - prefix_len)
        if tuple(generated[i : i + prefix_len]) == prefix
    }
|
|
|
|
def _penalize_seen_tokens(
    logits: torch.Tensor, context: list[int], penalty: float, device: str
) -> torch.Tensor:
    """
    Rescale, in place, the logit of every distinct token id in ``context``.

    Positive logits are divided by ``penalty``; non-positive ones are
    multiplied by it. Returns the same tensor for chaining.
    """
    seen_ids = torch.tensor(sorted(set(context)), dtype=torch.long, device=device)
    picked = logits[seen_ids]
    logits[seen_ids] = torch.where(picked > 0, picked / penalty, picked * penalty)
    return logits


def _mask_below_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Set to -inf every logit strictly below the k-th largest one (ties are kept)."""
    k = min(top_k, logits.size(-1))
    kth_value = torch.topk(logits, k).values[-1]
    return logits.masked_fill(logits < kth_value, float("-inf"))


def _mask_outside_nucleus(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """
    Keep the highest-probability tokens until their cumulative mass exceeds
    ``top_p``, including the token that crosses it; set the rest to -inf.

    Expects a 1-D logits vector.
    """
    ranked, ranking = torch.sort(logits, descending=True)
    mass = torch.cumsum(F.softmax(ranked, dim=-1), dim=-1)
    over = mass > top_p
    # Shift right by one so the crossing token survives, and always keep the
    # most likely token.
    drop = torch.cat([over.new_zeros(1), over[:-1]])
    trimmed = ranked.masked_fill(drop, float("-inf"))
    return torch.full_like(logits, float("-inf")).scatter(0, ranking, trimmed)


def generate(
    model: GPT,
    prompt_ids: list[int],
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    eos_id: int | None,
    device: str,
) -> list[int]:
    """
    Sample up to ``max_new_tokens`` tokens that continue ``prompt_ids``.

    Each step feeds the context, cropped to its last ``model.config.block_size``
    tokens, through the model. The last-position logits are then shaped in this
    order:

    1. Repetition penalty over every token in the context, prompt included.
    2. No-repeat-ngram bans computed from the generated tokens only.
    3. Temperature scaling, with the divisor floored at 1e-5.
    4. Top-k filtering.
    5. Top-p (nucleus) filtering.

    A token is drawn with ``torch.multinomial``. If the resulting distribution
    is not finite, the argmax is taken instead. Sampling ``eos_id`` stops
    generation, and the EOS token is not returned.

    Returns:
        Only the newly generated token ids; the prompt is not included.
    """
    window_size = model.config.block_size
    context = list(prompt_ids)
    generated: list[int] = []
    neg_inf = float("-inf")

    while len(generated) < max_new_tokens:
        window = context[-window_size:] if len(context) > window_size else context
        inputs = torch.tensor(window, dtype=torch.long, device=device).unsqueeze(0)
        with torch.no_grad():
            out, _ = model(inputs)
        scores = out[0, -1].float()

        if repetition_penalty != 1.0 and context:
            scores = _penalize_seen_tokens(scores, context, repetition_penalty, device)

        if no_repeat_ngram_size > 0 and len(generated) >= no_repeat_ngram_size - 1:
            blocked = banned_tokens_from_ngrams(generated, no_repeat_ngram_size)
            for token_id in blocked:
                scores[token_id] = neg_inf

        scores = scores / max(temperature, 1e-5)
        if top_k is not None and top_k > 0:
            scores = _mask_below_top_k(scores, top_k)
        if top_p < 1.0:
            scores = _mask_outside_nucleus(scores, top_p)

        probs = F.softmax(scores, dim=-1)
        sampleable = bool(torch.isfinite(probs).all()) and bool(probs.sum() > 0)
        if sampleable:
            next_tok = int(torch.multinomial(probs, num_samples=1).item())
        else:
            next_tok = int(scores.argmax().item())

        if eos_id is not None and next_tok == eos_id:
            break
        context.append(next_tok)
        generated.append(next_tok)

    return generated
|
|
|
|
def main() -> None:
    """
    Command-line entry point for summarizing ``--text``.

    Steps: load the tokenizer and checkpoint, build the prompt from ``--text``,
    sample a continuation with ``generate``, and print the decoded answer to
    stdout.

    Raises:
        RuntimeError: if the tokenizer and model vocab sizes differ, or if
            ``--max_new_tokens`` leaves fewer than 16 positions of the model's
            block for the prompt.
    """
    import sys  # only used for the truncation warning on stderr

    args = parse_args()

    # A seed of 0 (the CLI default) is treated as "do not seed", so default runs
    # are not reproducible. Pass a non-zero --seed to fix the sampling RNG.
    if args.seed:
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    tok = TextTokenizer(args.tokenizer)

    # Deserialize on the CPU and move only the model to ``device``. Mapping
    # straight to ``device`` placed every tensor in the checkpoint there, and
    # the ``ckpt`` reference kept that duplicate alive for the whole run.
    # SECURITY: weights_only=False unpickles arbitrary objects, which can run
    # code; only load checkpoints from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    config = GPTConfig(**ckpt["config"])
    # Inference-time overrides (presumably turn off dropout and gradient
    # checkpointing inside GPT -- confirm against GPTConfig).
    config.dropout = 0.0
    config.gradient_checkpointing = False

    model = GPT(config)
    # Check compatibility before spending time loading and moving the weights.
    if tok.vocab_size != model.config.vocab_size:
        raise RuntimeError(
            f"Tokenizer vocab_size {tok.vocab_size} != model vocab_size {model.config.vocab_size}. "
            "Use the same tokenizer.json that was used for pretrain/SFT."
        )
    model.load_state_dict(strip_compile_prefix(ckpt["model"]), strict=True)
    del ckpt  # drop the CPU copy of the checkpoint before generating
    model.to(device)
    model.eval()

    prompt = args.prompt_template.format(passage=args.text.strip())
    prompt_ids = tok.encode(prompt, add_bos=True, add_eos=False)

    # Reserve room in the block for the generated tokens plus one spare position.
    block_size = model.config.block_size
    max_prompt_len = block_size - args.max_new_tokens - 1
    if max_prompt_len < 16:
        raise RuntimeError(
            f"max_new_tokens={args.max_new_tokens} is too large for block_size={block_size}."
        )
    if len(prompt_ids) > max_prompt_len:
        # Keep BOS (if present) plus the most recent tokens. This drops the
        # start of the prompt, which with the default template includes the
        # instruction, so report it instead of truncating silently.
        bos = [prompt_ids[0]] if prompt_ids and prompt_ids[0] == tok.bos_id else []
        tail = prompt_ids[-(max_prompt_len - len(bos)):]
        print(
            f"warning: prompt has {len(prompt_ids)} tokens but only {max_prompt_len} fit "
            f"(block_size={block_size}, max_new_tokens={args.max_new_tokens}); "
            "the beginning of the prompt was truncated.",
            file=sys.stderr,
        )
        prompt_ids = bos + tail

    new_ids = generate(
        model=model,
        prompt_ids=prompt_ids,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        eos_id=tok.eos_id,
        device=device,
    )

    answer = tok.decode(new_ids, skip_special_tokens=True).strip()
    print(answer)
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()