""" Inference script for the 1B Transformer — Single GPU. Usage: python inference.py # auto-finds latest checkpoint python inference.py /path/to/checkpoint.pt # specific checkpoint """ import sys import os import glob import time import torch import torch.nn.functional as F sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from model.config import ModelConfig from model.transformer import Transformer from model.data import get_tokenizer def find_latest_checkpoint(checkpoint_dir="/jfs/deepak-kumar/checkpoints"): files = glob.glob(os.path.join(checkpoint_dir, "step_*.pt")) if not files: final = os.path.join(checkpoint_dir, "final.pt") return final if os.path.exists(final) else None return max(files, key=lambda f: int(os.path.basename(f).split("_")[1].split(".")[0])) def load_model(checkpoint_path, device="cuda:0"): config = ModelConfig() model = Transformer(config) print(f"Loading checkpoint: {checkpoint_path}") ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) model.load_state_dict(ckpt["model"]) model = model.to(device).bfloat16().eval() step = ckpt.get("step", "?") loss = ckpt.get("loss", "?") print(f" Step: {step} | Loss: {loss}") print(f" Params: {sum(p.numel() for p in model.parameters()):,}") print(f" Device: {device}") del ckpt torch.cuda.empty_cache() return model, config @torch.no_grad() def generate(model, tokenizer, prompt, max_new_tokens=200, temperature=0.8, top_k=50, top_p=0.9, device="cuda:0"): input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device) t0 = time.time() for i in range(max_new_tokens): if input_ids.shape[1] >= model.config.max_seq_len: break with torch.autocast(device_type="cuda", dtype=torch.bfloat16): logits, _ = model(input_ids) logits = logits[:, -1, :] / temperature if top_k > 0: topk_vals, _ = torch.topk(logits, top_k) logits[logits < topk_vals[:, -1:]] = float("-inf") if top_p < 1.0: sorted_logits, sorted_idx = torch.sort(logits, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) mask = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p sorted_logits[mask] = float("-inf") logits = sorted_logits.scatter(1, sorted_idx, sorted_logits) probs = F.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) if next_token.item() == tokenizer.eos_token_id: break input_ids = torch.cat([input_ids, next_token], dim=1) elapsed = time.time() - t0 gen_tokens = input_ids.shape[1] - len(tokenizer.encode(prompt)) tok_per_sec = gen_tokens / max(elapsed, 1e-9) text = tokenizer.decode(input_ids[0], skip_special_tokens=True) return text, gen_tokens, tok_per_sec def main(): device = "cuda:0" if len(sys.argv) > 1: checkpoint = sys.argv[1] else: checkpoint = find_latest_checkpoint() if checkpoint is None: print("No checkpoint found!") sys.exit(1) model, config = load_model(checkpoint, device) tokenizer = get_tokenizer() prompts = [ "The meaning of life is", "In machine learning, a neural network", "The capital of France is", "Once upon a time, there was a", "To solve a quadratic equation, you need to", "The theory of relativity explains that", "Python is a programming language that", "The sun rises in the east and", ] print("\n" + "=" * 70) print(" INFERENCE — 1B Transformer (Single GPU)") print("=" * 70) for prompt in prompts: print(f"\n{'─' * 60}") print(f"PROMPT: {prompt}") print(f"{'─' * 60}") text, n_tok, tps = generate(model, tokenizer, prompt, max_new_tokens=150, temperature=0.8, top_k=50, device=device) generated = text[len(prompt):] print(f"OUTPUT:{generated}") print(f" 
[{n_tok} tokens, {tps:.1f} tok/s]") print("\n" + "=" * 70) if __name__ == "__main__": main()
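
# Sketch of interactive use: the functions above can also be imported and
# called from a Python shell for a single custom prompt, without running the
# full demo loop. The checkpoint path below is hypothetical; substitute a real
# step_*.pt or final.pt file from your checkpoint directory.
#
#   >>> from inference import load_model, generate
#   >>> from model.data import get_tokenizer
#   >>> model, _ = load_model("/jfs/deepak-kumar/checkpoints/step_1000.pt")  # hypothetical path
#   >>> tok = get_tokenizer()
#   >>> text, n_tok, tps = generate(model, tok, "Hello, world", max_new_tokens=32)
#   >>> print(text)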