File size: 1,386 Bytes
85fbfaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import json
from tokenizers import Tokenizer
from safetensors.torch import load_file
from model.simbot import SIMGPT

# -----------------------------
# Tokenizer + hyperparameter config
# -----------------------------
tokenizer = Tokenizer.from_file("tokenizer.json")

with open("config.json") as cfg_file:
    cfg = json.load(cfg_file)

# -----------------------------
# Build the model and restore trained weights
# -----------------------------
model = SIMGPT(
    vocab_size=cfg["vocab_size"],
    block_size=cfg["block_size"],
    n_heads=cfg["n_heads"],
    n_layers=cfg["n_layers"],
    d_model=cfg["d_model"],
)
model.load_state_dict(load_file("simbot.safetensors"))
model.eval()  # inference mode: disables dropout / training-only behavior

print("SimBot GPT ready. Type 'exit' to quit.\n")

# -----------------------------
# Interactive loop
# -----------------------------
# Maximum context length the model was trained with; feeding longer
# sequences would index positional embeddings that do not exist.
block_size = cfg["block_size"]

while True:
    try:
        user_input = input("User: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C ends the session cleanly instead of a traceback.
        break
    if user_input.lower() in {"exit", "quit"}:
        break

    prompt = f"<bos>\nUser: {user_input}\nAssistant:"
    ids = tokenizer.encode(prompt).ids
    x = torch.tensor(ids).unsqueeze(0)  # (1, seq_len) batch of token ids

    with torch.no_grad():
        for _ in range(80):  # generate up to 80 tokens
            # Crop the conditioning context to the trained window so the
            # model never sees more than block_size positions.
            x_cond = x if x.size(1) <= block_size else x[:, -block_size:]
            logits = model(x_cond)
            # Greedy decoding: always take the highest-probability token.
            next_id = torch.argmax(logits[:, -1, :], dim=-1).item()
            x = torch.cat([x, torch.tensor([[next_id]])], dim=1)

    output = tokenizer.decode(x[0].tolist())
    # Show only the text after the final "Assistant:" marker.
    print("\nAssistant:", output.split("Assistant:")[-1].strip(), "\n")