#!/usr/bin/env python3 """ inference_ar.py — Autoregressive left-to-right sampling from a trained ARModel. Companion to train_ar.py. Standard next-token sampling: given a prompt (or just BOS), run the model on the current prefix and sample the next token until we hit max_new_tokens or EOS. Supports: - greedy (temperature=0) - temperature - top-k - top-p (nucleus) Usage: python scripts/inference_ar.py \\ --config configs/ar_owt.yaml \\ --checkpoint outputs/ar_baseline/latest.pt \\ --num_samples 4 \\ --max_new_tokens 256 # conditional (prompt from the training stream) python scripts/inference_ar.py --config ... --checkpoint ... \\ --mode conditional --prompt_len 32 --num_samples 4 """ from __future__ import annotations import argparse import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] # sad/ from typing import Optional import torch import torch.nn.functional as F import yaml sys.path.insert(0, str(ROOT)) from src.models.ar_model import ARModel from src.data import build_owt_dataloader # ───────────────────────────────────────────────────────────────────────────── # Sampling helpers # ───────────────────────────────────────────────────────────────────────────── def _apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor: """Zero out everything below the top-k largest logits (per row).""" if top_k <= 0 or top_k >= logits.size(-1): return logits topk_vals, _ = logits.topk(top_k, dim=-1) threshold = topk_vals[..., -1:].expand_as(logits) return torch.where(logits < threshold, torch.full_like(logits, float("-inf")), logits) def _apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor: """Keep the smallest prefix of sorted probabilities whose sum ≥ top_p.""" if top_p >= 1.0 or top_p <= 0.0: return logits sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True) probs = sorted_logits.softmax(dim=-1) cum = probs.cumsum(dim=-1) # Tokens whose cumulative prob is already past top_p (excluding the first # crossing token itself) are removed. remove = cum - probs > top_p sorted_logits = sorted_logits.masked_fill(remove, float("-inf")) # Scatter back to original vocab order. out = torch.empty_like(sorted_logits) out.scatter_(-1, sorted_idx, sorted_logits) return out def _sample_next( logits: torch.Tensor, temperature: float, top_k: int, top_p: float, ) -> torch.Tensor: """Sample one token per row from the last-step logits [B, V].""" if temperature <= 0.0: return logits.argmax(dim=-1) logits = logits / temperature logits = _apply_top_k(logits, top_k) logits = _apply_top_p(logits, top_p) probs = logits.softmax(dim=-1) return torch.multinomial(probs, num_samples=1).squeeze(-1) # ───────────────────────────────────────────────────────────────────────────── # Sampler # ───────────────────────────────────────────────────────────────────────────── class ARSampler: """ Plain left-to-right AR sampler with KV cache. Pass 1 (prompt): one `forward_cached` call over the full prompt builds the initial KV cache of length P. Pass 2..: single-token `forward_cached` calls that append one (k, v) slice per layer per step. Total length capped at `max_seq_len` (512). """ def __init__( self, model: ARModel, tokenizer, device: torch.device, dtype: torch.dtype = torch.bfloat16, ): self.model = model self.tokenizer = tokenizer self.device = device self.dtype = dtype self.max_seq_len: int = model.max_seq_len self.vocab_size: int = model.vocab_size @torch.no_grad() def generate( self, prompt_ids: torch.Tensor, max_new_tokens: int, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0, eos_token_id: Optional[int] = None, stop_on_eos: bool = True, ) -> torch.Tensor: """ Args: prompt_ids: [B, P] int64 max_new_tokens: number of new tokens to append Returns: full sequence [B, P + k] where k ≤ max_new_tokens. """ device = self.device seq = prompt_ids.to(device=device, dtype=torch.long) B, P = seq.shape assert P < self.max_seq_len, f"prompt length {P} >= max_seq_len {self.max_seq_len}" autocast_device = "cuda" if device.type == "cuda" else "cpu" done = torch.zeros(B, dtype=torch.bool, device=device) # ── Pass 1: consume the prompt, build the initial KV cache ───────── with torch.autocast(device_type=autocast_device, dtype=self.dtype): logits, past_kv = self.model.forward_cached(seq, past_kv_list=None) next_logits = logits[:, -1, :].float() # [B, V] next_tok = _sample_next(next_logits, temperature, top_k, top_p) # [B] if eos_token_id is not None: done = done | (next_tok == eos_token_id) seq = torch.cat([seq, next_tok.unsqueeze(-1)], dim=1) # ── Pass 2..: single-token appends using the growing KV cache ────── for _ in range(max_new_tokens - 1): if stop_on_eos and eos_token_id is not None and done.all(): break if seq.size(1) >= self.max_seq_len: break with torch.autocast(device_type=autocast_device, dtype=self.dtype): logits, past_kv = self.model.forward_cached( seq[:, -1:], past_kv_list=past_kv, ) next_logits = logits[:, -1, :].float() next_tok = _sample_next(next_logits, temperature, top_k, top_p) # Frozen rows keep emitting EOS so the batch stays rectangular. if eos_token_id is not None: next_tok = torch.where( done, torch.full_like(next_tok, eos_token_id), next_tok, ) done = done | (next_tok == eos_token_id) seq = torch.cat([seq, next_tok.unsqueeze(-1)], dim=1) return seq.cpu() # ───────────────────────────────────────────────────────────────────────────── # Plumbing # ───────────────────────────────────────────────────────────────────────────── def _unwrap(model): while True: if hasattr(model, "_orig_mod"): model = model._orig_mod elif hasattr(model, "module"): model = model.module else: return model def load_config(path: str) -> dict: with open(path) as f: return yaml.safe_load(f) def build_tokenizer(config: dict): from transformers import AutoTokenizer tok = AutoTokenizer.from_pretrained( ROOT / "tokenizers" / "gpt2", local_files_only=True, ) if tok.eos_token is None: tok.add_special_tokens({"eos_token": "<|endoftext|>"}) if tok.bos_token is None: tok.bos_token = tok.eos_token if tok.pad_token is None: tok.pad_token = tok.eos_token config["model"]["vocab_size"] = len(tok) return tok def build_model(config: dict, device: torch.device) -> ARModel: mc = config["model"] return ARModel( vocab_size=mc["vocab_size"], hidden_size=mc["hidden_size"], n_blocks=mc["n_blocks"], n_heads=mc["n_heads"], max_seq_len=mc["max_seq_len"], dropout=mc.get("dropout", 0.0), ).to(device) def parse_args(): p = argparse.ArgumentParser() p.add_argument("--checkpoint", type=str, required=True) p.add_argument("--config", type=str, default="configs/ar_owt.yaml") p.add_argument("--num_samples", type=int, default=1) p.add_argument("--max_new_tokens", type=int, default=256) p.add_argument("--temperature", type=float, default=1.0) p.add_argument("--top_k", type=int, default=0, help="0 = disabled") p.add_argument("--top_p", type=float, default=1.0, help="1.0 = disabled") p.add_argument("--seed", type=int, default=42) p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) p.add_argument("--mode", type=str, default="unconditional", choices=["unconditional", "conditional"], help="unconditional: start from BOS only. " "conditional: take a prefix from the training stream.") p.add_argument("--prompt_len", type=int, default=32, help="(conditional) number of leading tokens drawn from data.") p.add_argument("--prompt_text", type=str, default=None, help="(optional) override prompt with a user-provided string. " "Encoded with the GPT-2 tokenizer.") p.add_argument("--data_seed", type=int, default=0, help="(conditional) seed for shuffling the training split.") p.add_argument("--no_stop_on_eos", action="store_true", help="Disable early-stop on EOS; always emit max_new_tokens.") return p.parse_args() def resolve_dtype(name: str) -> torch.dtype: return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[name] def _build_prompt_ids(args, config, tokenizer, device) -> torch.Tensor: """Returns [B, P] int64 prompt tensor on `device`.""" bos = tokenizer.bos_token_id assert bos is not None, "tokenizer has no bos_token_id" if args.prompt_text is not None: ids = tokenizer(args.prompt_text, return_tensors="pt")["input_ids"][0].tolist() if not ids or ids[0] != bos: ids = [bos] + ids prompt = torch.tensor(ids, dtype=torch.long, device=device) return prompt.unsqueeze(0).expand(args.num_samples, -1).contiguous() if args.mode == "unconditional": return torch.full((args.num_samples, 1), bos, dtype=torch.long, device=device) # conditional: pull a batch from the OWT train split, take the first # `prompt_len` tokens (already [BOS]-prefixed by the dataloader). data_cfg = config.get("data", {}) seq_len = config["model"]["max_seq_len"] cache_dir = data_cfg.get("cache_dir", None) if cache_dir is not None and not Path(cache_dir).is_absolute(): repo_root = ROOT candidate = repo_root / cache_dir if candidate.exists(): cache_dir = str(candidate) loader = build_owt_dataloader( tokenizer, split="train[:-100000]", seq_len=seq_len, batch_size=args.num_samples, num_workers=0, cache_dir=cache_dir, seed=args.data_seed, mode=data_cfg.get("mode", "subsample"), shard_across_ranks=False, ) batch = next(iter(loader)) return batch["input_ids"][:args.num_samples, :args.prompt_len].to(device) def main(): args = parse_args() torch.manual_seed(args.seed) device = torch.device(args.device) dtype = resolve_dtype(args.dtype) config = load_config(args.config) tokenizer = build_tokenizer(config) model = build_model(config, device).to(dtype) ckpt = torch.load(args.checkpoint, map_location=device) raw_state = ckpt.get("model", ckpt) _unwrap(model).load_state_dict(raw_state, strict=False) model.eval() print(f"Loaded checkpoint: {args.checkpoint} (step={ckpt.get('step', '?')})") prompt_ids = _build_prompt_ids(args, config, tokenizer, device) P = prompt_ids.size(1) print(f"Sampling {args.num_samples} sequences ({args.mode}) " f"prompt_len={P} max_new_tokens={args.max_new_tokens} " f"T={args.temperature} top_k={args.top_k} top_p={args.top_p}") sampler = ARSampler( model=_unwrap(model), tokenizer=tokenizer, device=device, dtype=dtype, ) out_ids = sampler.generate( prompt_ids=prompt_ids, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, eos_token_id=tokenizer.eos_token_id, stop_on_eos=not args.no_stop_on_eos, ) # [B, P + k] print("\n" + "=" * 72) for i, ids in enumerate(out_ids): ids_list = ids.tolist() print(f"[Sample {i + 1}]") prompt_text = tokenizer.decode(ids_list[:P], skip_special_tokens=True) gen_text = tokenizer.decode(ids_list[P:], skip_special_tokens=True) if P > 1 or args.prompt_text is not None: print(f" {prompt_text}") print(f" {gen_text}") else: print(tokenizer.decode(ids_list, skip_special_tokens=True)) print() if __name__ == "__main__": main()