"""Interactive REPL for a SimBot GPT checkpoint: greedy decoding, 80 tokens per turn."""

import torch
import json
from tokenizers import Tokenizer
from safetensors.torch import load_file
from model.simbot import SIMGPT

# -----------------------------
# Load tokenizer & config
# -----------------------------
tokenizer = Tokenizer.from_file("tokenizer.json")
with open("config.json") as f:
    cfg = json.load(f)

# -----------------------------
# Load model
# -----------------------------
model = SIMGPT(
    vocab_size=cfg["vocab_size"],
    block_size=cfg["block_size"],
    n_layers=cfg["n_layers"],
    n_heads=cfg["n_heads"],
    d_model=cfg["d_model"],
)
state_dict = load_file("simbot.safetensors")
model.load_state_dict(state_dict)
model.eval()  # disable dropout etc. for inference

print("SimBot GPT ready. Type 'exit' to quit.\n")

# Maximum context the model's positional embeddings can handle.
block_size = cfg["block_size"]

# -----------------------------
# Interactive loop
# -----------------------------
while True:
    user_input = input("User: ").strip()
    if user_input.lower() in {"exit", "quit"}:
        break

    prompt = f"\nUser: {user_input}\nAssistant:"
    ids = tokenizer.encode(prompt).ids
    x = torch.tensor(ids).unsqueeze(0)  # shape (1, T)

    with torch.no_grad():
        for _ in range(80):
            # FIX: crop the conditioning window to block_size. The original
            # fed the ever-growing sequence straight to the model, which
            # overruns a fixed-block GPT's positional embeddings once
            # len(prompt) + generated tokens exceeds block_size.
            x_cond = x if x.size(1) <= block_size else x[:, -block_size:]
            # model output is indexed as (batch, seq, vocab) logits —
            # presumably SIMGPT returns a single tensor; confirmed by
            # the original's logits[:, -1, :] usage.
            logits = model(x_cond)
            # Greedy decoding: always take the most likely next token.
            next_id = torch.argmax(logits[:, -1, :], dim=-1).item()
            x = torch.cat([x, torch.tensor([[next_id]])], dim=1)

    output = tokenizer.decode(x[0].tolist())
    # Print only the text after the final "Assistant:" marker.
    print("\nAssistant:", output.split("Assistant:")[-1].strip(), "\n")