from __future__ import annotations

import argparse
from pathlib import Path

import torch
import torch.nn.functional as F

from searshorai.model import GPT, GPTConfig
from searshorai.tokenizer import TextTokenizer

# Must match the prompts used in make_xsum_sft.py / make_paragraph_sft.py.
# Using the first template is the canonical choice at inference time.
DEFAULT_PROMPT_TEMPLATE = (
    "Read the article and write a one-sentence summary.\n\n"
    "Article:\n{passage}\n\nSummary:\n"
)


def strip_compile_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``_orig_mod.`` removed from each key.

    Keys without the prefix are kept unchanged. This lets a checkpoint whose
    keys carry that prefix load into a plain ``GPT`` with ``strict=True``.
    """
    prefix = "_orig_mod."
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


def parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options for a single generation run."""
    parser = argparse.ArgumentParser(description="Ask the paragraph-explainer model.")
    parser.add_argument("--checkpoint", type=Path, required=True)
    parser.add_argument("--tokenizer", type=Path, default=Path("data/wikitext103/tokenizer.json"))
    parser.add_argument("--text", type=str, required=True, help="The passage to explain.")
    parser.add_argument("--prompt_template", type=str, default=DEFAULT_PROMPT_TEMPLATE)
    parser.add_argument("--max_new_tokens", type=int, default=120)
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature. A value <= 0 selects greedy (argmax) decoding.",
    )
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument(
        "--top_p", type=float, default=0.9, help="Nucleus sampling cutoff. Set 1.0 to disable."
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.3,
        help="Penalty for re-emitting tokens already in the context. 1.0 = off.",
    )
    parser.add_argument(
        "--no_repeat_ngram_size",
        type=int,
        default=3,
        help="Block any n-gram of this size from appearing twice. 0 = off.",
    )
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducible sampling (0 is a valid seed). Omit for a non-deterministic run.",
    )
    args = parser.parse_args()
    # A penalty of 0 divides by zero and a negative one flips logit signs;
    # both silently corrupt the sampling distribution.
    if args.repetition_penalty <= 0:
        parser.error("--repetition_penalty must be > 0 (use 1.0 to disable).")
    return args


def banned_tokens_from_ngrams(generated: list[int], n: int) -> set[int]:
    """
    For no-repeat-ngram blocking: given the tokens generated so far, return the
    set of token ids that would close a previously-seen n-gram if emitted next.

    ``n <= 0`` disables blocking (empty set). ``n == 1`` bans every token that
    has already been generated.
    """
    if n <= 0 or len(generated) < n - 1:
        return set()
    prefix_len = n - 1
    # Use an explicit start index: ``generated[-0:]`` would be the WHOLE list,
    # which previously made n == 1 ban nothing.
    prefix = tuple(generated[len(generated) - prefix_len:])
    return {
        generated[i + prefix_len]
        for i in range(len(generated) - prefix_len)
        if tuple(generated[i : i + prefix_len]) == prefix
    }


def _apply_repetition_penalty(logits: torch.Tensor, context: list[int], penalty: float) -> None:
    """In place, penalise the logit of every distinct token id present in *context*.

    Positive logits are divided by *penalty* and non-positive ones multiplied
    by it, so for ``penalty > 1`` each seen token becomes less likely.
    """
    seen = torch.tensor(sorted(set(context)), dtype=torch.long, device=logits.device)
    scores = logits[seen]
    logits[seen] = torch.where(scores > 0, scores / penalty, scores * penalty)


def _filter_top_k(logits: torch.Tensor, top_k: int | None) -> torch.Tensor:
    """Set to -inf every logit strictly below the k-th largest (ties at the cutoff survive).

    ``top_k`` of ``None`` or ``<= 0`` disables the filter.
    """
    if top_k is None or top_k <= 0:
        return logits
    k = min(top_k, logits.size(-1))
    cutoff = torch.topk(logits, k).values[-1]
    return logits.masked_fill(logits < cutoff, -float("inf"))


def _filter_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus filter over a 1-D logit vector; ``top_p >= 1.0`` disables it.

    Tokens are taken in descending-probability order up to and including the
    first one at which the cumulative probability exceeds *top_p*; all others
    are set to -inf. The most likely token is always kept.
    """
    if top_p >= 1.0:
        return logits
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Shift right by one so the token that crosses the threshold is kept.
    remove = cumulative > top_p
    remove[1:] = remove[:-1].clone()
    remove[0] = False
    sorted_logits = sorted_logits.masked_fill(remove, -float("inf"))
    filtered = torch.full_like(logits, -float("inf"))
    filtered.scatter_(0, sorted_idx, sorted_logits)
    return filtered


def generate(
    model: GPT,
    prompt_ids: list[int],
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    eos_id: int | None,
    device: str,
) -> list[int]:
    """
    Custom sampling loop with repetition penalty, top-k, top-p (nucleus), and
    no-repeat-ngram blocking.

    Each step re-runs the model on (at most) the last ``block_size`` tokens of
    prompt + generated text. Processing order per step: repetition penalty over
    the full context, n-gram banning over generated tokens only, then either
    greedy argmax (``temperature <= 0``) or temperature scaling, top-k, top-p
    and multinomial sampling. Generation stops early when ``eos_id`` is sampled
    (the EOS id itself is not returned).

    Returns the list of newly generated token ids (does not include the prompt).
    """
    block_size = model.config.block_size
    context = list(prompt_ids)
    generated: list[int] = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # The model cannot see further back than block_size tokens.
            x = torch.tensor([context[-block_size:]], dtype=torch.long, device=device)
            logits, _ = model(x)
            logits = logits[:, -1, :].squeeze(0).float()

            if repetition_penalty != 1.0 and context:
                _apply_repetition_penalty(logits, context, repetition_penalty)

            banned = banned_tokens_from_ngrams(generated, no_repeat_ngram_size)
            if banned:
                banned_idx = torch.tensor(sorted(banned), dtype=torch.long, device=logits.device)
                logits[banned_idx] = -float("inf")

            if temperature <= 0:
                # Explicit greedy decoding instead of sampling at a tiny temperature.
                next_tok = int(logits.argmax().item())
            else:
                # Clamp so very small positive temperatures cannot overflow to inf.
                logits = logits / max(temperature, 1e-5)
                logits = _filter_top_p(_filter_top_k(logits, top_k), top_p)
                probs = F.softmax(logits, dim=-1)
                if not torch.isfinite(probs).all() or probs.sum() <= 0:
                    # Degenerate distribution (e.g. every token masked): fall back to argmax.
                    next_tok = int(logits.argmax().item())
                else:
                    next_tok = int(torch.multinomial(probs, num_samples=1).item())

            if eos_id is not None and next_tok == eos_id:
                break
            context.append(next_tok)
            generated.append(next_tok)
    return generated


def _resolve_device(requested: str) -> str:
    """Map ``auto`` to ``cuda`` when CUDA is available (else ``cpu``); return other values unchanged."""
    if requested == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return requested


def _load_model(checkpoint: Path, device: str) -> GPT:
    """Build a GPT from the checkpoint's ``config``/``model`` entries, moved to *device* in eval mode."""
    # SECURITY: weights_only=False unpickles arbitrary Python objects from the
    # file, which can execute code. Only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint, map_location=device, weights_only=False)
    config = GPTConfig(**ckpt["config"])
    # Inference-time overrides; presumably these only matter during training.
    config.dropout = 0.0
    config.gradient_checkpointing = False
    model = GPT(config)
    model.load_state_dict(strip_compile_prefix(ckpt["model"]), strict=True)
    model.to(device)
    model.eval()
    return model


def _build_prompt_ids(
    tok: TextTokenizer,
    template: str,
    passage: str,
    block_size: int,
    max_new_tokens: int,
) -> list[int]:
    """Format and encode the prompt, left-truncating it so prompt + generation fit in *block_size*.

    Raises:
        RuntimeError: if ``max_new_tokens`` leaves fewer than 16 tokens for the prompt.
    """
    prompt = template.format(passage=passage.strip())
    prompt_ids = tok.encode(prompt, add_bos=True, add_eos=False)
    max_prompt_len = block_size - max_new_tokens - 1
    if max_prompt_len < 16:
        raise RuntimeError(
            f"max_new_tokens={max_new_tokens} is too large for block_size={block_size}."
        )
    if len(prompt_ids) > max_prompt_len:
        # Keep BOS (if present) plus the TAIL of the prompt; with the default
        # template the tail carries the trailing "Summary:" cue.
        bos = [prompt_ids[0]] if prompt_ids and prompt_ids[0] == tok.bos_id else []
        prompt_ids = bos + prompt_ids[-(max_prompt_len - len(bos)):]
    return prompt_ids


def main() -> None:
    """CLI entry point: load tokenizer and model, build the prompt, generate, and print the answer."""
    args = parse_args()
    # `is not None` so that an explicit `--seed 0` is honoured.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)
    device = _resolve_device(args.device)

    tok = TextTokenizer(args.tokenizer)
    model = _load_model(args.checkpoint, device)
    if tok.vocab_size != model.config.vocab_size:
        raise RuntimeError(
            f"Tokenizer vocab_size {tok.vocab_size} != model vocab_size {model.config.vocab_size}. "
            "Use the same tokenizer.json that was used for pretrain/SFT."
        )

    prompt_ids = _build_prompt_ids(
        tok,
        args.prompt_template,
        args.text,
        model.config.block_size,
        args.max_new_tokens,
    )
    new_ids = generate(
        model=model,
        prompt_ids=prompt_ids,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        eos_id=tok.eos_id,
        device=device,
    )
    answer = tok.decode(new_ids, skip_special_tokens=True).strip()
    print(answer)


if __name__ == "__main__":
    main()