| |
| """ |
| inference_ar.py — Autoregressive left-to-right sampling from a trained ARModel. |
| |
| Companion to train_ar.py. Standard next-token sampling: given a prompt (or just |
| BOS), run the model on the current prefix and sample the next token until we |
| hit max_new_tokens or EOS. |
| |
| Supports: |
| - greedy (temperature=0) |
| - temperature |
| - top-k |
| - top-p (nucleus) |
| |
| Usage: |
| python scripts/inference_ar.py \\ |
| --config configs/ar_owt.yaml \\ |
| --checkpoint outputs/ar_baseline/latest.pt \\ |
| --num_samples 4 \\ |
| --max_new_tokens 256 |
| |
| # conditional (prompt from the training stream) |
| python scripts/inference_ar.py --config ... --checkpoint ... \\ |
| --mode conditional --prompt_len 32 --num_samples 4 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| ROOT = Path(__file__).resolve().parents[1] |
| from typing import Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| import yaml |
|
|
| sys.path.insert(0, str(ROOT)) |
|
|
| from src.models.ar_model import ARModel |
| from src.data import build_owt_dataloader |
|
|
|
|
| |
| |
| |
|
|
| def _apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor: |
| """Zero out everything below the top-k largest logits (per row).""" |
| if top_k <= 0 or top_k >= logits.size(-1): |
| return logits |
| topk_vals, _ = logits.topk(top_k, dim=-1) |
| threshold = topk_vals[..., -1:].expand_as(logits) |
| return torch.where(logits < threshold, torch.full_like(logits, float("-inf")), logits) |
|
|
|
|
| def _apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor: |
| """Keep the smallest prefix of sorted probabilities whose sum ≥ top_p.""" |
| if top_p >= 1.0 or top_p <= 0.0: |
| return logits |
| sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True) |
| probs = sorted_logits.softmax(dim=-1) |
| cum = probs.cumsum(dim=-1) |
| |
| |
| remove = cum - probs > top_p |
| sorted_logits = sorted_logits.masked_fill(remove, float("-inf")) |
| |
| out = torch.empty_like(sorted_logits) |
| out.scatter_(-1, sorted_idx, sorted_logits) |
| return out |
|
|
|
|
| def _sample_next( |
| logits: torch.Tensor, |
| temperature: float, |
| top_k: int, |
| top_p: float, |
| ) -> torch.Tensor: |
| """Sample one token per row from the last-step logits [B, V].""" |
| if temperature <= 0.0: |
| return logits.argmax(dim=-1) |
| logits = logits / temperature |
| logits = _apply_top_k(logits, top_k) |
| logits = _apply_top_p(logits, top_p) |
| probs = logits.softmax(dim=-1) |
| return torch.multinomial(probs, num_samples=1).squeeze(-1) |
|
|
|
|
| |
| |
| |
|
|
class ARSampler:
    """
    Plain left-to-right AR sampler with KV cache.

    Step 1 (prompt): one `forward_cached` call over the full prompt, with no
        previous cache, returns the initial cache.
    Step 2..: single-token `forward_cached` calls that feed only the most
        recently sampled token together with the cache returned by the
        previous call.

    The total length (prompt + generated tokens) never exceeds the model's
    `max_seq_len`.
    """

    def __init__(
        self,
        model: ARModel,
        tokenizer,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype  # autocast dtype used for the forward passes
        self.max_seq_len: int = model.max_seq_len
        self.vocab_size: int = model.vocab_size

    def _step(self, tokens, past_kv, temperature, top_k, top_p):
        """Run one cached forward pass over `tokens` and sample the next token.

        Returns `(next_tok, past_kv)`: the sampled ids of shape [B] and the
        cache object returned by the model for use in the following step.
        """
        autocast_device = "cuda" if self.device.type == "cuda" else "cpu"
        with torch.autocast(device_type=autocast_device, dtype=self.dtype):
            logits, past_kv = self.model.forward_cached(tokens, past_kv_list=past_kv)
            # Sample from fp32 logits of the last position only.
            next_logits = logits[:, -1, :].float()
        next_tok = _sample_next(next_logits, temperature, top_k, top_p)
        return next_tok, past_kv

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        eos_token_id: Optional[int] = None,
        stop_on_eos: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            prompt_ids:     [B, P] token ids (cast to int64 and moved to the
                            sampler's device).
            max_new_tokens: maximum number of new tokens to append; values
                            <= 0 return the prompt unchanged.
            temperature:    <= 0 means greedy decoding.
            top_k:          0 disables top-k filtering.
            top_p:          1.0 disables nucleus filtering.
            eos_token_id:   if given, a row that has produced EOS keeps
                            emitting EOS on every later step.
            stop_on_eos:    if True (and `eos_token_id` is given), stop as
                            soon as every row has produced EOS.
        Returns:
            full sequence [B, P + k] on CPU, where k ≤ max_new_tokens and
            P + k ≤ max_seq_len.
        """
        device = self.device
        seq = prompt_ids.to(device=device, dtype=torch.long)
        B, P = seq.shape
        assert P < self.max_seq_len, f"prompt length {P} >= max_seq_len {self.max_seq_len}"

        # Nothing requested: return the prompt as-is instead of sampling once.
        if max_new_tokens <= 0:
            return seq.cpu()

        done = torch.zeros(B, dtype=torch.bool, device=device)
        step_input = seq  # the first pass consumes the whole prompt
        past_kv = None

        for _ in range(max_new_tokens):
            if seq.size(1) >= self.max_seq_len:
                break

            next_tok, past_kv = self._step(step_input, past_kv, temperature, top_k, top_p)

            if eos_token_id is not None:
                # Rows that already finished are pinned to EOS.
                next_tok = torch.where(
                    done, torch.full_like(next_tok, eos_token_id), next_tok,
                )
                done = done | (next_tok == eos_token_id)

            seq = torch.cat([seq, next_tok.unsqueeze(-1)], dim=1)
            # Later passes feed only the newest token; the cache holds the rest.
            step_input = seq[:, -1:]

            if stop_on_eos and eos_token_id is not None and done.all():
                break

        return seq.cpu()
|
|
|
|
| |
| |
| |
|
|
| def _unwrap(model): |
| while True: |
| if hasattr(model, "_orig_mod"): |
| model = model._orig_mod |
| elif hasattr(model, "module"): |
| model = model.module |
| else: |
| return model |
|
|
|
|
def load_config(path: str) -> dict:
    """Load the YAML config file at `path` and return it as a dict.

    Raises:
        ValueError: if the top level of the document is not a mapping
            (for example, when the file is empty).
    """
    # Read as UTF-8 explicitly instead of relying on the platform locale.
    with open(path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"config {path!r} must contain a YAML mapping at the top level, "
            f"got {type(config).__name__}"
        )
    return config
|
|
|
|
def build_tokenizer(config: dict):
    """Load the GPT-2 tokenizer stored under ROOT/tokenizers/gpt2.

    Fills in missing special tokens (EOS becomes "<|endoftext|>"; BOS and PAD
    fall back to EOS) and writes the tokenizer's size into
    ``config["model"]["vocab_size"]`` as a side effect.
    """
    from transformers import AutoTokenizer

    tokenizer_dir = ROOT / "tokenizers" / "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)

    if tokenizer.eos_token is None:
        tokenizer.add_special_tokens({"eos_token": "<|endoftext|>"})
    # BOS first, then PAD: each reuses EOS when it is not set.
    for role in ("bos_token", "pad_token"):
        if getattr(tokenizer, role) is None:
            setattr(tokenizer, role, tokenizer.eos_token)

    config["model"]["vocab_size"] = len(tokenizer)
    return tokenizer
|
|
|
|
def build_model(config: dict, device: torch.device) -> ARModel:
    """Construct an ARModel from ``config["model"]`` and move it to `device`.

    All architecture fields are required keys; ``dropout`` defaults to 0.0
    when absent.
    """
    model_cfg = config["model"]
    required = ("vocab_size", "hidden_size", "n_blocks", "n_heads", "max_seq_len")
    kwargs = {key: model_cfg[key] for key in required}
    kwargs["dropout"] = model_cfg.get("dropout", 0.0)
    model = ARModel(**kwargs)
    return model.to(device)
|
|
|
|
def parse_args():
    """Parse the sampling script's command-line options and return the namespace."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Checkpoint and config.
    add("--checkpoint", type=str, required=True)
    add("--config", type=str, default="configs/ar_owt.yaml")

    # Sampling controls.
    add("--num_samples", type=int, default=1)
    add("--max_new_tokens", type=int, default=256)
    add("--temperature", type=float, default=1.0)
    add("--top_k", type=int, default=0, help="0 = disabled")
    add("--top_p", type=float, default=1.0, help="1.0 = disabled")
    add("--seed", type=int, default=42)

    # Runtime.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    add("--device", type=str, default=default_device)
    add("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])

    # Prompt source.
    add(
        "--mode",
        type=str,
        default="unconditional",
        choices=["unconditional", "conditional"],
        help="unconditional: start from BOS only. "
             "conditional: take a prefix from the training stream.",
    )
    add(
        "--prompt_len",
        type=int,
        default=32,
        help="(conditional) number of leading tokens drawn from data.",
    )
    add(
        "--prompt_text",
        type=str,
        default=None,
        help="(optional) override prompt with a user-provided string. "
             "Encoded with the GPT-2 tokenizer.",
    )
    add(
        "--data_seed",
        type=int,
        default=0,
        help="(conditional) seed for shuffling the training split.",
    )
    add(
        "--no_stop_on_eos",
        action="store_true",
        help="Disable early-stop on EOS; always emit max_new_tokens.",
    )
    return parser.parse_args()
|
|
|
|
def resolve_dtype(name: str) -> torch.dtype:
    """Translate a CLI dtype name ("bf16", "fp16", "fp32") into a torch dtype.

    Raises KeyError for any other name.
    """
    torch_attr = {"bf16": "bfloat16", "fp16": "float16", "fp32": "float32"}[name]
    return getattr(torch, torch_attr)
|
|
|
|
| def _build_prompt_ids(args, config, tokenizer, device) -> torch.Tensor: |
| """Returns [B, P] int64 prompt tensor on `device`.""" |
| bos = tokenizer.bos_token_id |
| assert bos is not None, "tokenizer has no bos_token_id" |
|
|
| if args.prompt_text is not None: |
| ids = tokenizer(args.prompt_text, return_tensors="pt")["input_ids"][0].tolist() |
| if not ids or ids[0] != bos: |
| ids = [bos] + ids |
| prompt = torch.tensor(ids, dtype=torch.long, device=device) |
| return prompt.unsqueeze(0).expand(args.num_samples, -1).contiguous() |
|
|
| if args.mode == "unconditional": |
| return torch.full((args.num_samples, 1), bos, dtype=torch.long, device=device) |
|
|
| |
| |
| data_cfg = config.get("data", {}) |
| seq_len = config["model"]["max_seq_len"] |
| cache_dir = data_cfg.get("cache_dir", None) |
| if cache_dir is not None and not Path(cache_dir).is_absolute(): |
| repo_root = ROOT |
| candidate = repo_root / cache_dir |
| if candidate.exists(): |
| cache_dir = str(candidate) |
| loader = build_owt_dataloader( |
| tokenizer, |
| split="train[:-100000]", |
| seq_len=seq_len, |
| batch_size=args.num_samples, |
| num_workers=0, |
| cache_dir=cache_dir, |
| seed=args.data_seed, |
| mode=data_cfg.get("mode", "subsample"), |
| shard_across_ranks=False, |
| ) |
| batch = next(iter(loader)) |
| return batch["input_ids"][:args.num_samples, :args.prompt_len].to(device) |
|
|
|
|
def _strip_wrapper_prefixes(state_dict, expected_keys):
    """Rename checkpoint keys that carry `_orig_mod.` / `module.` wrapper prefixes.

    A key is only renamed when stripping leading wrapper prefixes turns it
    into a name present in `expected_keys`; every other key is kept verbatim.
    The original mapping is returned untouched when nothing needs renaming.
    """
    prefixes = ("_orig_mod.", "module.")
    renamed = {}
    changed = False
    for key, value in state_dict.items():
        name = key
        while name not in expected_keys:
            prefix = next((p for p in prefixes if name.startswith(p)), None)
            if prefix is None:
                break
            name = name[len(prefix):]
        if name not in expected_keys:
            name = key
        changed = changed or name != key
        renamed[name] = value
    return renamed if changed else state_dict


def main():
    """CLI entry point: load the model and checkpoint, sample, and print the text."""
    args = parse_args()
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    dtype = resolve_dtype(args.dtype)

    config = load_config(args.config)
    # build_tokenizer also records the vocabulary size in config["model"].
    tokenizer = build_tokenizer(config)

    model = build_model(config, device).to(dtype)
    # NOTE(security): torch.load may unpickle arbitrary objects; only load
    # checkpoints that come from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location=device)
    raw_state = ckpt.get("model", ckpt)

    target = _unwrap(model)
    # Checkpoints saved from a wrapped model prefix every key; without this
    # step strict=False would silently load nothing.
    state = _strip_wrapper_prefixes(raw_state, set(target.state_dict()))
    load_result = target.load_state_dict(state, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"WARNING: checkpoint/model key mismatch: "
            f"missing={len(load_result.missing_keys)} {load_result.missing_keys[:5]} "
            f"unexpected={len(load_result.unexpected_keys)} "
            f"{load_result.unexpected_keys[:5]}",
            file=sys.stderr,
        )
    model.eval()
    print(f"Loaded checkpoint: {args.checkpoint} (step={ckpt.get('step', '?')})")

    prompt_ids = _build_prompt_ids(args, config, tokenizer, device)
    P = prompt_ids.size(1)
    print(f"Sampling {args.num_samples} sequences ({args.mode}) "
          f"prompt_len={P} max_new_tokens={args.max_new_tokens} "
          f"T={args.temperature} top_k={args.top_k} top_p={args.top_p}")

    sampler = ARSampler(
        model=target,
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
    )
    out_ids = sampler.generate(
        prompt_ids=prompt_ids,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        eos_token_id=tokenizer.eos_token_id,
        stop_on_eos=not args.no_stop_on_eos,
    )

    print("\n" + "=" * 72)
    for i, ids in enumerate(out_ids):
        ids_list = ids.tolist()
        print(f"[Sample {i + 1}]")
        prompt_text = tokenizer.decode(ids_list[:P], skip_special_tokens=True)
        gen_text = tokenizer.decode(ids_list[P:], skip_special_tokens=True)
        # Show prompt and continuation separately unless the prompt is BOS only.
        if P > 1 or args.prompt_text is not None:
            print(f"<prompt ({P} tok)> {prompt_text}")
            print(f"<generated> {gen_text}")
        else:
            print(tokenizer.decode(ids_list, skip_special_tokens=True))
        print()
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|