"""
Inference script for the 1B Transformer — Single GPU.

Usage:
    python inference.py                          # auto-finds latest checkpoint
    python inference.py /path/to/checkpoint.pt   # specific checkpoint
"""
import sys
import os
import glob
import time
import torch
import torch.nn.functional as F
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer
def find_latest_checkpoint(checkpoint_dir="/jfs/deepak-kumar/checkpoints"):
    """Return the path of the newest checkpoint in *checkpoint_dir*.

    Files named ``step_<N>.pt`` are ranked by their integer step ``N`` and the
    highest one wins. Files that match the ``step_*.pt`` pattern but carry a
    non-numeric step (e.g. ``step_best.pt``) are ignored instead of aborting
    the search. When no numbered checkpoint is usable, ``final.pt`` is
    returned if it exists.

    Args:
        checkpoint_dir: Directory to search.

    Returns:
        The checkpoint path as a string, or ``None`` when nothing was found.
    """
    numbered = []
    for path in glob.glob(os.path.join(checkpoint_dir, "step_*.pt")):
        # "step_1200.pt" -> "1200"; the glob guarantees the "step_" prefix.
        token = os.path.basename(path).split("_")[1].split(".")[0]
        try:
            numbered.append((int(token), path))
        except ValueError:
            # Not a step-numbered checkpoint; skip it rather than crash.
            continue

    if numbered:
        return max(numbered, key=lambda pair: pair[0])[1]

    final = os.path.join(checkpoint_dir, "final.pt")
    return final if os.path.exists(final) else None
def load_model(checkpoint_path, device="cuda:0"):
    """Instantiate the Transformer and restore its weights for inference.

    The checkpoint is read onto the CPU, its ``"model"`` entry is loaded into
    a freshly built ``Transformer(ModelConfig())``, and the network is then
    moved to *device*, cast to bfloat16 and switched to eval mode. A short
    summary (step, loss, parameter count, device) is printed.

    Args:
        checkpoint_path: Path of the ``.pt`` checkpoint file.
        device: Target device string for the model.

    Returns:
        A ``(model, config)`` tuple.
    """
    cfg = ModelConfig()
    net = Transformer(cfg)

    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints from a trusted source.
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    net.load_state_dict(state["model"])

    net = net.to(device)
    net = net.bfloat16()
    net = net.eval()

    # Missing metadata is reported as "?" instead of failing.
    step = state.get("step", "?")
    loss = state.get("loss", "?")
    param_count = sum(p.numel() for p in net.parameters())
    print(f" Step: {step} | Loss: {loss}")
    print(f" Params: {param_count:,}")
    print(f" Device: {device}")

    # Drop the checkpoint dict before handing the model back.
    del state
    torch.cuda.empty_cache()
    return net, cfg
def _sample_next_token(logits, top_k, top_p):
    """Sample one token id per row from already temperature-scaled logits.

    Applies top-k filtering (when ``top_k > 0``) and then nucleus / top-p
    filtering (when ``top_p < 1.0``) before drawing from the remaining
    distribution. Returns a tensor of shape ``(rows, 1)``.
    """
    if top_k > 0:
        # torch.topk raises if k exceeds the vocabulary size, so clamp it.
        k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, k).values[:, -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))

    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass strictly before each token; a token is dropped once
        # the tokens ranked above it already cover top_p.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        drop = mass_before >= top_p
        # Always keep the most likely token so top_p <= 0 cannot empty the
        # distribution and produce NaN probabilities.
        drop[:, 0] = False
        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        # Undo the sort: put each filtered logit back at its vocabulary index.
        logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=200,
             temperature=0.8, top_k=50, top_p=0.9, device="cuda:0"):
    """Autoregressively extend *prompt* and report generation throughput.

    Every step feeds the whole sequence generated so far back through
    *model* and picks the next token from the last position's logits.
    Generation stops after *max_new_tokens* tokens, when the sequence reaches
    ``model.config.max_seq_len``, or when the end-of-sequence token is drawn
    (that token is not appended).

    Args:
        model: Callable returning ``(logits, _)`` for a batch of token ids;
            must expose ``config.max_seq_len``.
        tokenizer: Provides ``encode``, ``decode`` and ``eos_token_id``.
        prompt: Text to continue. A single prompt is processed per call.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Softmax temperature; values ``<= 0`` select greedy
            (argmax) decoding instead of dividing by zero.
        top_k: Keep only the *top_k* highest logits; ``0`` disables the
            filter. Values above the vocabulary size are clamped.
        top_p: Nucleus sampling threshold; ``>= 1.0`` disables the filter.
        device: Device string the input ids are moved to; autocast targets
            the same device type.

    Returns:
        ``(text, gen_tokens, tok_per_sec)``: the decoded prompt plus
        continuation, the number of generated tokens, and tokens per second.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]
    # Autocast must target the backend the tensors actually live on.
    device_type = torch.device(device).type
    greedy = temperature <= 0

    t0 = time.time()
    for _ in range(max_new_tokens):
        if input_ids.shape[1] >= model.config.max_seq_len:
            break
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, _aux = model(input_ids)
        # Filter and sample in float32 so the top-p cumulative sum is not
        # computed in reduced precision.
        logits = logits[:, -1, :].float()

        if greedy:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            next_token = _sample_next_token(logits / temperature, top_k, top_p)

        if next_token.item() == tokenizer.eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token], dim=1)
    elapsed = time.time() - t0

    gen_tokens = input_ids.shape[1] - prompt_len
    tok_per_sec = gen_tokens / max(elapsed, 1e-9)
    text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return text, gen_tokens, tok_per_sec
def main():
    """Load a checkpoint and print sampled continuations for fixed prompts.

    The checkpoint path is taken from the first command-line argument when
    given; otherwise the latest checkpoint is looked up with
    ``find_latest_checkpoint()``. Exits with status 1 when no checkpoint is
    available. For each prompt the continuation, token count and throughput
    are printed.
    """
    device = "cuda:0"

    checkpoint = sys.argv[1] if len(sys.argv) > 1 else find_latest_checkpoint()
    if checkpoint is None:
        print("No checkpoint found!")
        sys.exit(1)

    # The config is returned alongside the model but is not needed here.
    model, _config = load_model(checkpoint, device)
    tokenizer = get_tokenizer()

    prompts = [
        "The meaning of life is",
        "In machine learning, a neural network",
        "The capital of France is",
        "Once upon a time, there was a",
        "To solve a quadratic equation, you need to",
        "The theory of relativity explains that",
        "Python is a programming language that",
        "The sun rises in the east and",
    ]

    print("\n" + "=" * 70)
    print(" INFERENCE — 1B Transformer (Single GPU)")
    print("=" * 70)

    for prompt in prompts:
        print(f"\n{'─' * 60}")
        print(f"PROMPT: {prompt}")
        print(f"{'─' * 60}")
        text, n_tok, tps = generate(model, tokenizer, prompt,
                                    max_new_tokens=150, temperature=0.8,
                                    top_k=50, device=device)
        # The decoded text is not guaranteed to reproduce the prompt verbatim,
        # so only strip the prompt when it really is a prefix; otherwise show
        # the full decoded text instead of cutting at the wrong offset.
        generated = text[len(prompt):] if text.startswith(prompt) else text
        print(f"OUTPUT:{generated}")
        print(f" [{n_tok} tokens, {tps:.1f} tok/s]")

    print("\n" + "=" * 70)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()