sad / scripts /inference_ar.py
haochengsama's picture
Add files using upload-large-folder tool
8b0aeb2 verified
Raw
History Blame Contribute Delete
13.8 kB
#!/usr/bin/env python3
"""
inference_ar.py — Autoregressive left-to-right sampling from a trained ARModel.
Companion to train_ar.py. Standard next-token sampling: given a prompt (or just
BOS), run the model on the current prefix and sample the next token until we
hit max_new_tokens or EOS.
Supports:
- greedy (temperature=0)
- temperature
- top-k
- top-p (nucleus)
Usage:
python scripts/inference_ar.py \\
--config configs/ar_owt.yaml \\
--checkpoint outputs/ar_baseline/latest.pt \\
--num_samples 4 \\
--max_new_tokens 256
# conditional (prompt from the training stream)
python scripts/inference_ar.py --config ... --checkpoint ... \\
--mode conditional --prompt_len 32 --num_samples 4
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1] # sad/
from typing import Optional
import torch
import torch.nn.functional as F
import yaml
sys.path.insert(0, str(ROOT))
from src.models.ar_model import ARModel
from src.data import build_owt_dataloader
# ─────────────────────────────────────────────────────────────────────────────
# Sampling helpers
# ─────────────────────────────────────────────────────────────────────────────
def _apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Restrict each row of ``logits`` to its ``top_k`` largest entries.

    Entries strictly below the row's k-th largest value are replaced with
    ``-inf`` (so they get zero probability after a softmax). Entries tied with
    the k-th largest value are kept, so more than ``top_k`` entries can
    survive. The filter is a no-op (the input tensor itself is returned) when
    ``top_k <= 0`` or ``top_k`` is at least the size of the last dimension.
    """
    vocab = logits.size(-1)
    filter_is_off = top_k <= 0 or top_k >= vocab
    if filter_is_off:
        return logits
    # Smallest value that is still inside the top-k, one per row.
    kth_largest = logits.topk(top_k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_largest, float("-inf"))
def _apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus (top-p) filter over the last dimension of ``logits``.

    Per row, tokens are ranked by descending logit and the shortest leading
    run whose cumulative probability reaches ``top_p`` is kept; every other
    token is set to ``-inf``. The token that crosses the threshold is kept.
    ``top_p <= 0`` or ``top_p >= 1`` disables the filter and returns the input
    tensor itself.
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return logits

    ranked, order = torch.sort(logits, dim=-1, descending=True)
    ranked_probs = torch.softmax(ranked, dim=-1)
    # Probability mass that strictly precedes each token in the ranking; a
    # token is dropped only if the mass before it already exceeds top_p.
    mass_before = ranked_probs.cumsum(dim=-1) - ranked_probs
    ranked = ranked.masked_fill(mass_before > top_p, float("-inf"))

    # Undo the sort so the result is indexed by vocabulary id again.
    return torch.empty_like(ranked).scatter_(-1, order, ranked)
def _sample_next(
    logits: torch.Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
) -> torch.Tensor:
    """Pick one token id per row from last-step logits of shape [B, V].

    ``temperature <= 0`` means greedy decoding (row-wise argmax; ``top_k`` and
    ``top_p`` are ignored). Otherwise the logits are divided by
    ``temperature``, passed through the top-k filter and then the top-p
    filter, and one id per row is drawn from the resulting softmax.

    Returns:
        Tensor of shape [B] holding the chosen token ids.
    """
    greedy = temperature <= 0.0
    if greedy:
        return torch.argmax(logits, dim=-1)

    scaled = logits / temperature
    filtered = _apply_top_p(_apply_top_k(scaled, top_k), top_p)
    weights = torch.softmax(filtered, dim=-1)
    draws = torch.multinomial(weights, num_samples=1)  # [B, 1]
    return draws.squeeze(-1)
# ─────────────────────────────────────────────────────────────────────────────
# Sampler
# ─────────────────────────────────────────────────────────────────────────────
class ARSampler:
    """
    Plain left-to-right AR sampler with KV cache.

    Step 0 (prompt): one `forward_cached` call over the full prompt builds the
    initial KV cache.
    Steps 1..: single-token `forward_cached` calls that feed only the most
    recently sampled token together with the cache returned by the previous
    call. Total sequence length is capped at `model.max_seq_len`.
    """

    def __init__(
        self,
        model: "ARModel",
        tokenizer,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Store the model and sampling context.

        Args:
            model: object exposing `max_seq_len`, `vocab_size` and
                `forward_cached(ids, past_kv_list=...)`.
            tokenizer: kept on the instance; not used by `generate`.
            device: device the prompt is moved to before sampling.
            dtype: autocast dtype used around the model calls.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        self.max_seq_len: int = model.max_seq_len
        self.vocab_size: int = model.vocab_size

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        eos_token_id: Optional[int] = None,
        stop_on_eos: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            prompt_ids: [B, P] integer token ids.
            max_new_tokens: maximum number of new tokens to append; values
                <= 0 return the prompt unchanged.
            temperature: <= 0 selects greedy decoding.
            top_k: 0 disables top-k filtering.
            top_p: 1.0 disables nucleus filtering.
            eos_token_id: when given, a row that has emitted this id keeps
                emitting it on every later step so the batch stays
                rectangular.
            stop_on_eos: when True (and `eos_token_id` is given), stop as soon
                as every row has emitted EOS.
        Returns:
            full sequence [B, P + k] on CPU, where k <= max_new_tokens.
        Raises:
            ValueError: if the prompt already fills `max_seq_len`.
        """
        device = self.device
        seq = prompt_ids.to(device=device, dtype=torch.long)
        B, P = seq.shape
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if P >= self.max_seq_len:
            raise ValueError(
                f"prompt length {P} >= max_seq_len {self.max_seq_len}"
            )
        # Nothing requested: do not run the model or append a token.
        if max_new_tokens <= 0:
            return seq.cpu()

        autocast_device = "cuda" if device.type == "cuda" else "cpu"
        # float32 is not a reduced-precision autocast dtype; switch autocast
        # off explicitly in that case instead of relying on the fallback.
        autocast_on = self.dtype != torch.float32
        done = torch.zeros(B, dtype=torch.bool, device=device)
        past_kv = None

        for step in range(max_new_tokens):
            if step > 0:
                if stop_on_eos and eos_token_id is not None and done.all():
                    break
                if seq.size(1) >= self.max_seq_len:
                    break
            # Step 0 consumes the whole prompt and builds the cache; later
            # steps feed only the newest token and reuse the cache.
            step_input = seq if step == 0 else seq[:, -1:]
            with torch.autocast(
                device_type=autocast_device,
                dtype=self.dtype,
                enabled=autocast_on,
            ):
                logits, past_kv = self.model.forward_cached(
                    step_input, past_kv_list=past_kv,
                )
            next_logits = logits[:, -1, :].float()  # [B, V]
            next_tok = _sample_next(next_logits, temperature, top_k, top_p)  # [B]
            if eos_token_id is not None:
                # Frozen rows keep emitting EOS so the batch stays rectangular.
                next_tok = torch.where(
                    done, torch.full_like(next_tok, eos_token_id), next_tok,
                )
                done = done | (next_tok == eos_token_id)
            seq = torch.cat([seq, next_tok.unsqueeze(-1)], dim=1)

        return seq.cpu()
# ─────────────────────────────────────────────────────────────────────────────
# Plumbing
# ─────────────────────────────────────────────────────────────────────────────
def _unwrap(model):
    """Return the innermost object after peeling wrapper attributes.

    Repeatedly follows `_orig_mod` (checked first) or `module` for as long as
    the current object has one of them, and returns the first object that has
    neither.
    """
    current = model
    while True:
        for wrapper_attr in ("_orig_mod", "module"):
            if hasattr(current, wrapper_attr):
                current = getattr(current, wrapper_attr)
                break
        else:
            # Neither wrapper attribute is present: fully unwrapped.
            return current
def load_config(path: str) -> dict:
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return it.

    The file is opened explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
def build_tokenizer(config: dict):
    """Load the local GPT-2 tokenizer and make sure EOS/BOS/PAD are set.

    The tokenizer is read from ``ROOT / "tokenizers" / "gpt2"`` with
    ``local_files_only=True``. A missing EOS token is added as
    ``<|endoftext|>``; missing BOS and PAD tokens are aliased to EOS.

    Side effect: ``config["model"]["vocab_size"]`` is overwritten with
    ``len(tok)``.

    Returns:
        The configured tokenizer.
    """
    from transformers import AutoTokenizer

    tokenizer_dir = ROOT / "tokenizers" / "gpt2"
    tok = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)

    if tok.eos_token is None:
        tok.add_special_tokens({"eos_token": "<|endoftext|>"})
    # BOS and PAD fall back to EOS, so this must run after EOS is guaranteed.
    for role in ("bos_token", "pad_token"):
        if getattr(tok, role) is None:
            setattr(tok, role, tok.eos_token)

    config["model"]["vocab_size"] = len(tok)
    return tok
def build_model(config: dict, device: torch.device) -> ARModel:
    """Instantiate an ``ARModel`` from ``config["model"]`` and move it to ``device``.

    ``vocab_size``, ``hidden_size``, ``n_blocks``, ``n_heads`` and
    ``max_seq_len`` are required keys; ``dropout`` defaults to 0.0 when absent.
    """
    model_cfg = config["model"]
    required = ("vocab_size", "hidden_size", "n_blocks", "n_heads", "max_seq_len")
    kwargs = {key: model_cfg[key] for key in required}
    kwargs["dropout"] = model_cfg.get("dropout", 0.0)
    net = ARModel(**kwargs)
    return net.to(device)
def parse_args(argv: Optional[list[str]] = None):
    """Parse the sampling script's command-line options.

    Args:
        argv: argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before; passing a list allows the parser
            to be driven programmatically.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()

    # Checkpoint / config locations.
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--config", type=str, default="configs/ar_owt.yaml")

    # Sampling controls.
    p.add_argument("--num_samples", type=int, default=1)
    p.add_argument("--max_new_tokens", type=int, default=256)
    p.add_argument("--temperature", type=float, default=1.0)
    p.add_argument("--top_k", type=int, default=0,
                   help="0 = disabled")
    p.add_argument("--top_p", type=float, default=1.0,
                   help="1.0 = disabled")
    p.add_argument("--seed", type=int, default=42)

    # Runtime.
    p.add_argument("--device", type=str,
                   default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", type=str, default="bf16",
                   choices=["bf16", "fp16", "fp32"])

    # Prompt construction.
    p.add_argument("--mode", type=str, default="unconditional",
                   choices=["unconditional", "conditional"],
                   help="unconditional: start from BOS only. "
                        "conditional: take a prefix from the training stream.")
    p.add_argument("--prompt_len", type=int, default=32,
                   help="(conditional) number of leading tokens drawn from data.")
    p.add_argument("--prompt_text", type=str, default=None,
                   help="(optional) override prompt with a user-provided string. "
                        "Encoded with the GPT-2 tokenizer.")
    p.add_argument("--data_seed", type=int, default=0,
                   help="(conditional) seed for shuffling the training split.")
    p.add_argument("--no_stop_on_eos", action="store_true",
                   help="Disable early-stop on EOS; always emit max_new_tokens.")

    return p.parse_args(argv)
def resolve_dtype(name: str) -> torch.dtype:
    """Translate a ``--dtype`` flag value into the matching torch dtype.

    Accepts "bf16", "fp16" and "fp32"; any other name raises ``KeyError``.
    """
    dtype_by_flag = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    return dtype_by_flag[name]
def _build_prompt_ids(args, config, tokenizer, device) -> torch.Tensor:
    """Build the [B, P] int64 prompt tensor on ``device`` (B = ``args.num_samples``).

    Three sources, in priority order:
      1. ``args.prompt_text``: encoded with ``tokenizer``, BOS-prefixed if it
         does not already start with BOS, and repeated for every sample.
      2. ``args.mode == "unconditional"``: a single BOS token per sample.
      3. otherwise (conditional): the first ``args.prompt_len`` tokens of the
         first batch yielded by ``build_owt_dataloader``.
    """
    bos = tokenizer.bos_token_id
    assert bos is not None, "tokenizer has no bos_token_id"

    n_rows = args.num_samples

    # 1) Explicit user prompt overrides the mode.
    if args.prompt_text is not None:
        encoded = tokenizer(args.prompt_text, return_tensors="pt")
        ids = encoded["input_ids"][0].tolist()
        starts_with_bos = bool(ids) and ids[0] == bos
        if not starts_with_bos:
            ids = [bos] + ids
        row = torch.tensor(ids, dtype=torch.long, device=device)
        return row.unsqueeze(0).expand(n_rows, -1).contiguous()

    # 2) Unconditional: BOS only.
    if args.mode == "unconditional":
        return torch.full((n_rows, 1), bos, dtype=torch.long, device=device)

    # 3) Conditional: pull a batch from the OWT train split and keep the first
    # `prompt_len` tokens (already [BOS]-prefixed by the dataloader).
    data_cfg = config.get("data", {})
    seq_len = config["model"]["max_seq_len"]
    cache_dir = data_cfg.get("cache_dir", None)
    if cache_dir is not None:
        if not Path(cache_dir).is_absolute():
            # Prefer the repo-relative location when it exists; otherwise the
            # configured relative path is passed through unchanged.
            in_repo = ROOT / cache_dir
            if in_repo.exists():
                cache_dir = str(in_repo)

    loader = build_owt_dataloader(
        tokenizer,
        split="train[:-100000]",
        seq_len=seq_len,
        batch_size=n_rows,
        num_workers=0,
        cache_dir=cache_dir,
        seed=args.data_seed,
        mode=data_cfg.get("mode", "subsample"),
        shard_across_ranks=False,
    )
    first_batch = next(iter(loader))
    prompts = first_batch["input_ids"][:n_rows, :args.prompt_len]
    return prompts.to(device)
def main():
    """Script entry point: load a checkpoint, sample, and print the results.

    Steps: parse CLI args, seed torch, build tokenizer and model from the YAML
    config, load the checkpoint weights, build the prompt batch, run
    ``ARSampler.generate`` and print each decoded sample to stdout.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    dtype = resolve_dtype(args.dtype)

    config = load_config(args.config)
    # build_tokenizer also writes the tokenizer size into config["model"].
    tokenizer = build_tokenizer(config)
    model = build_model(config, device).to(dtype)

    # NOTE(review): torch.load unpickles the file — only load checkpoints from
    # a trusted source (consider `weights_only=True` if the checkpoint format
    # allows it).
    ckpt = torch.load(args.checkpoint, map_location=device)
    # Accept either a training checkpoint ({"model": state_dict, ...}) or a
    # bare state dict.
    raw_state = ckpt.get("model", ckpt)
    load_result = _unwrap(model).load_state_dict(raw_state, strict=False)
    # strict=False tolerates key mismatches; surface them so a wrong or
    # prefix-mangled checkpoint does not silently sample from random weights.
    missing = list(load_result.missing_keys)
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        print(
            f"WARNING: checkpoint does not match the model exactly: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            file=sys.stderr,
        )
    model.eval()
    print(f"Loaded checkpoint: {args.checkpoint} (step={ckpt.get('step', '?')})")

    prompt_ids = _build_prompt_ids(args, config, tokenizer, device)
    P = prompt_ids.size(1)
    print(f"Sampling {args.num_samples} sequences ({args.mode}) "
          f"prompt_len={P} max_new_tokens={args.max_new_tokens} "
          f"T={args.temperature} top_k={args.top_k} top_p={args.top_p}")

    sampler = ARSampler(
        model=_unwrap(model),
        tokenizer=tokenizer,
        device=device,
        dtype=dtype,
    )
    out_ids = sampler.generate(
        prompt_ids=prompt_ids,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        eos_token_id=tokenizer.eos_token_id,
        stop_on_eos=not args.no_stop_on_eos,
    )  # [B, P + k]

    print("\n" + "=" * 72)
    for i, ids in enumerate(out_ids):
        ids_list = ids.tolist()
        print(f"[Sample {i + 1}]")
        prompt_text = tokenizer.decode(ids_list[:P], skip_special_tokens=True)
        gen_text = tokenizer.decode(ids_list[P:], skip_special_tokens=True)
        # Show prompt and continuation separately only when there is a real
        # prompt (more than the lone BOS, or user-supplied text).
        if P > 1 or args.prompt_text is not None:
            print(f"<prompt ({P} tok)> {prompt_text}")
            print(f"<generated> {gen_text}")
        else:
            print(tokenizer.decode(ids_list, skip_special_tokens=True))
        print()
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()