# simbot-gpt-level1 / inference.py
# Uploaded by hranjan043 — commit 85fbfaf (verified):
# "Update model files and add inference script"
import torch
import json
from tokenizers import Tokenizer
from safetensors.torch import load_file
from model.simbot import SIMGPT
# -----------------------------
# Load tokenizer & config
# -----------------------------
# Both files are expected next to this script: tokenizer.json (a
# `tokenizers` fast-tokenizer dump) and config.json (model hyper-params).
tokenizer = Tokenizer.from_file("tokenizer.json")
# Explicit encoding so the config parses identically on every platform.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)
# -----------------------------
# Load model
# -----------------------------
# Instantiate the architecture from the hyper-parameters in config.json,
# then restore the trained weights from the safetensors checkpoint.
model = SIMGPT(
    vocab_size=cfg["vocab_size"],
    block_size=cfg["block_size"],
    n_layers=cfg["n_layers"],
    n_heads=cfg["n_heads"],
    d_model=cfg["d_model"],
)
model.load_state_dict(load_file("simbot.safetensors"))
model.eval()  # inference mode: disables dropout / training-only behavior
print("SimBot GPT ready. Type 'exit' to quit.\n")
# -----------------------------
# Interactive loop
# -----------------------------
# Greedy (argmax) decoding: deterministic, up to 80 new tokens per turn.
block_size = cfg["block_size"]  # context window the model was built with
while True:
    user_input = input("User: ").strip()
    if user_input.lower() in {"exit", "quit"}:
        break
    prompt = f"<bos>\nUser: {user_input}\nAssistant:"
    ids = tokenizer.encode(prompt).ids
    x = torch.tensor(ids).unsqueeze(0)  # add batch dim -> shape (1, seq_len)
    with torch.no_grad():
        for _ in range(80):
            # Crop the context to the model's block size so the positional
            # embeddings are never indexed past the trained context length
            # (long prompts or long generations would otherwise crash).
            logits = model(x[:, -block_size:])
            # Greedy pick: most probable token at the last position.
            next_id = torch.argmax(logits[:, -1, :], dim=-1).item()
            x = torch.cat([x, torch.tensor([[next_id]])], dim=1)
    output = tokenizer.decode(x[0].tolist())
    # Show only the text after the final "Assistant:" marker.
    print("\nAssistant:", output.split("Assistant:")[-1].strip(), "\n")