| | """ |
| | Inference script for the 1B Transformer — Single GPU. |
| | |
| | Usage: |
| | python inference.py # auto-finds latest checkpoint |
| | python inference.py /path/to/checkpoint.pt # specific checkpoint |
| | """ |
| |
|
| | import sys |
| | import os |
| | import glob |
| | import time |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| | from model.config import ModelConfig |
| | from model.transformer import Transformer |
| | from model.data import get_tokenizer |
| |
|
| |
|
def find_latest_checkpoint(checkpoint_dir="/jfs/deepak-kumar/checkpoints"):
    """Return the path of the newest checkpoint in ``checkpoint_dir``.

    Numbered checkpoints named ``step_<N>.pt`` (or ``step_<N>_<tag>.pt``) are
    preferred; the one with the largest integer ``N`` wins. Files matching
    ``step_*.pt`` whose step is not an integer (e.g. ``step_best.pt``) are
    ignored instead of aborting the scan. When no numbered checkpoint exists,
    ``final.pt`` is returned if present.

    Args:
        checkpoint_dir: Directory to search. It is escaped before globbing, so
            names containing ``[``, ``*`` or ``?`` are handled literally.

    Returns:
        The checkpoint path, or ``None`` if nothing usable was found.
    """
    import re  # local import keeps this helper self-contained

    # Same split rule as before: the step number runs from "step_" up to the
    # next "_" or ".", but non-integer steps are now skipped.
    step_pattern = re.compile(r"step_(\d+)[_.]")
    pattern = os.path.join(glob.escape(checkpoint_dir), "step_*.pt")

    numbered = []
    for path in glob.glob(pattern):
        match = step_pattern.match(os.path.basename(path))
        if match:
            numbered.append((int(match.group(1)), path))

    if numbered:
        return max(numbered, key=lambda item: item[0])[1]

    final = os.path.join(checkpoint_dir, "final.pt")
    return final if os.path.exists(final) else None
| |
|
| |
|
def load_model(checkpoint_path, device="cuda:0"):
    """Construct a Transformer from the default ModelConfig and load weights.

    The checkpoint is read onto the CPU first. Its ``"model"`` entry is loaded
    as the state dict, and the model is then moved to ``device``, cast to
    bfloat16 and put in eval mode. A short summary is printed (step, loss,
    parameter count and device). ``"?"`` is shown when the checkpoint lacks
    ``step``/``loss``.

    Args:
        checkpoint_path: Path to a file produced by ``torch.save``.
        device: Target device for the model.

    Returns:
        ``(model, config)``.
    """
    config = ModelConfig()
    net = Transformer(config)

    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects. Only point this at checkpoints you trust.
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    net.load_state_dict(state["model"])
    net = net.to(device).bfloat16().eval()

    step_label, loss_label = state.get("step", "?"), state.get("loss", "?")
    print(f" Step: {step_label} | Loss: {loss_label}")
    n_params = sum(param.numel() for param in net.parameters())
    print(f" Params: {n_params:,}")
    print(f" Device: {device}")

    # Drop the CPU-side checkpoint copy before returning, then release cached
    # CUDA blocks.
    del state
    torch.cuda.empty_cache()
    return net, config
| |
|
| |
|
def _autocast_context(device_type):
    """Return a fresh bf16 autocast context for ``device_type``.

    Autocast is only requested for "cuda" and "cpu". Any other device type gets
    a no-op context rather than risking an unsupported-device error.
    """
    import contextlib  # local import keeps this block self-contained

    if device_type in ("cuda", "cpu"):
        return torch.autocast(device_type=device_type, dtype=torch.bfloat16)
    return contextlib.nullcontext()


def _sample_next_token(logits, temperature, top_k, top_p):
    """Choose the next token id for each row of ``logits``.

    ``logits`` is the last-position slice with shape ``(batch, vocab)``. The
    result is a ``(batch, 1)`` tensor of token ids.
    """
    if temperature <= 0:
        # Greedy decoding. Dividing by a zero temperature would produce inf/nan.
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / temperature

    if top_k > 0:
        # torch.topk raises when k exceeds the vocabulary size.
        k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, k).values[:, -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))

    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Exclusive cumulative mass: drop a token once the tokens ranked above
        # it already reach top_p.
        drop = torch.cumsum(sorted_probs, dim=-1) - sorted_probs >= top_p
        # Always keep the most likely token, so top_p <= 0 cannot mask
        # everything and turn the distribution into NaNs.
        drop[:, 0] = False
        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        # Undo the sort: vocab position sorted_idx[b, j] gets sorted_logits[b, j].
        logits = torch.empty_like(logits).scatter_(1, sorted_idx, sorted_logits)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=200,
             temperature=0.8, top_k=50, top_p=0.9, device="cuda:0"):
    """Autoregressively extend ``prompt`` with sampled tokens.

    The full sequence is fed through ``model`` at every step (no KV cache is
    used here). Generation stops at ``max_new_tokens``, at the tokenizer's EOS
    token, or when the context reaches ``model.config.max_seq_len``.

    Args:
        model: Callable returning a tuple whose first item is the logits for a
            ``(1, seq)`` id tensor; must expose ``config.max_seq_len``.
        tokenizer: Object providing ``encode``, ``decode`` and ``eos_token_id``.
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Softmax temperature; ``<= 0`` selects greedy decoding.
        top_k: Keep only the ``top_k`` largest logits (``<= 0`` disables);
            clamped to the vocabulary size.
        top_p: Nucleus-sampling mass; ``>= 1.0`` disables it.
        device: Device the prompt ids are moved to.

    Returns:
        ``(text, gen_tokens, tok_per_sec)``. ``text`` is the decoded prompt
        plus continuation, with special tokens skipped.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]
    max_seq_len = model.config.max_seq_len
    # Follow the tensor's real device instead of assuming CUDA.
    amp_device = input_ids.device.type
    t0 = time.perf_counter()

    for _step in range(max_new_tokens):
        # No sliding window: stop once the context is full.
        if input_ids.shape[1] >= max_seq_len:
            break

        with _autocast_context(amp_device):
            logits = model(input_ids)[0]

        # Filter and sample in float32; bf16 softmax/cumsum is too coarse for
        # reliable top-p cut-offs.
        last_logits = logits[:, -1, :].float()
        next_token = _sample_next_token(last_logits, temperature, top_k, top_p)

        if next_token.item() == tokenizer.eos_token_id:
            break

        input_ids = torch.cat([input_ids, next_token], dim=1)

    elapsed = time.perf_counter() - t0
    gen_tokens = input_ids.shape[1] - prompt_len
    tok_per_sec = gen_tokens / max(elapsed, 1e-9)

    text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return text, gen_tokens, tok_per_sec
| |
|
| |
|
def main():
    """Load a checkpoint and print sampled continuations for fixed prompts.

    The checkpoint comes from ``sys.argv[1]`` when given, otherwise from
    ``find_latest_checkpoint()``. Exits with status 1 if no usable checkpoint
    is found. Runs on ``cuda:0`` when CUDA is available and falls back to CPU
    otherwise.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"
        print("CUDA not available — falling back to CPU (this will be slow).")

    if len(sys.argv) > 1:
        checkpoint = sys.argv[1]
        # Fail with a clear message instead of a torch.load traceback.
        if not os.path.isfile(checkpoint):
            print(f"Checkpoint not found: {checkpoint}")
            sys.exit(1)
    else:
        checkpoint = find_latest_checkpoint()
        if checkpoint is None:
            print("No checkpoint found!")
            sys.exit(1)

    model, _ = load_model(checkpoint, device)
    tokenizer = get_tokenizer()

    prompts = [
        "The meaning of life is",
        "In machine learning, a neural network",
        "The capital of France is",
        "Once upon a time, there was a",
        "To solve a quadratic equation, you need to",
        "The theory of relativity explains that",
        "Python is a programming language that",
        "The sun rises in the east and",
    ]

    print("\n" + "=" * 70)
    print(" INFERENCE — 1B Transformer (Single GPU)")
    print("=" * 70)

    for prompt in prompts:
        print(f"\n{'─' * 60}")
        print(f"PROMPT: {prompt}")
        print(f"{'─' * 60}")
        text, n_tok, tps = generate(model, tokenizer, prompt,
                                    max_new_tokens=150, temperature=0.8,
                                    top_k=50, device=device)
        # decode() need not reproduce the prompt byte-for-byte (tokenizer
        # normalisation), so strip it only when it really is a prefix.
        generated = text[len(prompt):] if text.startswith(prompt) else text
        print(f"OUTPUT:{generated}")
        print(f" [{n_tok} tokens, {tps:.1f} tok/s]")

    print("\n" + "=" * 70)
| |
|
| |
|
# Script entry point: run the inference demo only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
| |
|