hranjan043 commited on
Commit
85fbfaf
·
verified ·
1 Parent(s): 125c8f2

Update model files and add inference script

Browse files
Files changed (2) hide show
  1. README.md +10 -0
  2. inference.py +51 -0
README.md CHANGED
@@ -77,4 +77,14 @@ model = SIMGPT(
77
  state_dict = load_file("simbot.safetensors")
78
  model.load_state_dict(state_dict)
79
  model.eval()
 
 
 
 
 
 
 
 
 
 
80
 
 
77
  state_dict = load_file("simbot.safetensors")
78
  model.load_state_dict(state_dict)
79
  model.eval()
80
+ ```
81
+
82
+ ## Prompting the Model
83
+
84
+ This model is a custom PyTorch implementation and does not support the Hugging Face inference widget.
85
+
86
+ ### Interactive Usage (Recommended)
87
+
88
+ ```bash
89
+ python inference.py
90
 
inference.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import json
from tokenizers import Tokenizer
from safetensors.torch import load_file
from model.simbot import SIMGPT


def _load_model_and_tokenizer():
    """Load the tokenizer, config, and SIMGPT weights from the repo root.

    Returns:
        (tokenizer, model, block_size) — model is in eval mode on CPU.
    """
    tokenizer = Tokenizer.from_file("tokenizer.json")

    with open("config.json") as f:
        cfg = json.load(f)

    model = SIMGPT(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["block_size"],
        n_layers=cfg["n_layers"],
        n_heads=cfg["n_heads"],
        d_model=cfg["d_model"],
    )
    model.load_state_dict(load_file("simbot.safetensors"))
    model.eval()
    return tokenizer, model, cfg["block_size"]


@torch.no_grad()
def _generate(model, ids, block_size, max_new_tokens=80):
    """Greedy-decode up to ``max_new_tokens`` tokens after ``ids``.

    Args:
        model: SIMGPT returning logits of shape (batch, seq, vocab)
            (assumed from the original indexing — TODO confirm).
        ids: list[int] prompt token ids.
        block_size: the model's trained context length.
        max_new_tokens: number of tokens to append (default 80, as before).

    Returns:
        list[int]: prompt ids followed by the generated ids.
    """
    x = torch.tensor(ids, dtype=torch.long).unsqueeze(0)
    for _ in range(max_new_tokens):
        # Bug fix: crop the conditioning window to block_size. The original
        # fed the ever-growing sequence straight into the model, which
        # overruns the positional embedding table once the conversation
        # exceeds the trained context length.
        logits = model(x[:, -block_size:])
        # Greedy decoding: take the argmax over the last position's logits.
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        x = torch.cat([x, next_id], dim=1)
    return x[0].tolist()


def main():
    """Interactive REPL: read a user line, generate a reply, print it."""
    tokenizer, model, block_size = _load_model_and_tokenizer()
    print("SimBot GPT ready. Type 'exit' to quit.\n")

    while True:
        user_input = input("User: ").strip()
        if user_input.lower() in {"exit", "quit"}:
            break

        prompt = f"<bos>\nUser: {user_input}\nAssistant:"
        ids = tokenizer.encode(prompt).ids
        out_ids = _generate(model, ids, block_size)

        output = tokenizer.decode(out_ids)
        # Everything after the last "Assistant:" marker is the new reply.
        print("\nAssistant:", output.split("Assistant:")[-1].strip(), "\n")


if __name__ == "__main__":
    main()