Really-Amazing commited on
Commit
5d24f6a
·
verified ·
1 Parent(s): 6fd1e97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -52
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import torch
2
  import gradio as gr
3
- import json # ← ONLY NEW IMPORT
4
  from nanochat.engine import Engine
5
  from nanochat.tokenizer import get_tokenizer
6
  from nanochat.gpt import GPT, GPTConfig
@@ -8,78 +7,98 @@ from nanochat.gpt import GPT, GPTConfig
8
  MODEL_PATH = "model_000971.pt"
9
 
10
  print("Waking up the toddler (NanoChat-ClimbMix-D12)...")
11
-
12
  tokenizer = get_tokenizer()
13
 
14
- print("Creating GPT model skeleton from meta_000971.json...")
15
-
16
- # === ONLY CHANGE: Load exact config from meta file (same as working space) ===
17
- with open("meta_000971.json", "r", encoding="utf-8") as f:
18
- meta_data = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- config = GPTConfig(**meta_data["model_config"])
21
  model = GPT(config)
22
- # =====================================================================
23
 
24
  print("Loading weights...")
25
  state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
26
-
27
- unwanted_prefix = '_orig_mod.'
28
- for k in list(state_dict.keys()):
29
- if k.startswith(unwanted_prefix):
30
- state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
31
 
32
  model.load_state_dict(state_dict, strict=False)
33
  model.to("cpu")
34
  model.eval()
35
  print("Model ready!")
36
 
37
- engine = Engine(model=model, tokenizer=tokenizer)
38
-
39
- # Your existing chat_fn (kept 100% unchanged)
40
  def chat_fn(message, history):
41
  try:
42
- prompt_tokens = []
 
43
  for user_msg, assistant_msg in history:
44
- prompt_tokens.extend(list(tokenizer.encode(f"<|user|>{user_msg}<|end|>")))
45
  if assistant_msg:
46
- prompt_tokens.extend(list(tokenizer.encode(f"<|assistant|>{assistant_msg}<|end|>")))
47
-
48
- prompt_tokens.extend(list(tokenizer.encode(f"<|user|>{message}<|end|><|assistant|>")))
49
-
50
- new_tokens = engine.generate(
51
- prompt_tokens,
52
- max_tokens=512,
53
- temperature=0.8,
54
- top_k=50,
55
- )
56
-
57
- if isinstance(new_tokens, tuple):
58
- new_tokens = new_tokens[0]
59
- if hasattr(new_tokens, 'tolist'):
60
- new_tokens = new_tokens.tolist()
61
-
62
- response = tokenizer.decode(new_tokens).strip()
63
-
64
- for end_tag in ["<|end|>", "<|assistant_end|>", "<|EOS|>"]:
65
- if end_tag in response:
66
- response = response.split(end_tag)[0].strip()
67
- break
68
-
69
- return response or "Toddler says: ... 😅"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  except Exception as e:
72
  return f"Toddler tantrum: {str(e)}"
73
 
74
- # Rest of your UI (unchanged)
75
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
76
- gr.Markdown("# 🧸 NanoChat-ClimbMix-D12 – Confident Toddler")
77
- gr.Markdown("Using exact config from meta_000971.json (same as working space)")
78
- gr.ChatInterface(
79
- fn=chat_fn,
80
- examples=["Tell me a joke", "What is UPI?", "Write hello world Python"],
81
- title="Chat with the Toddler"
82
- )
83
 
84
  if __name__ == "__main__":
85
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import torch
2
  import gradio as gr
 
3
  from nanochat.engine import Engine
4
  from nanochat.tokenizer import get_tokenizer
5
  from nanochat.gpt import GPT, GPTConfig
 
7
  MODEL_PATH = "model_000971.pt"
8
 
9
  print("Waking up the toddler (NanoChat-ClimbMix-D12)...")
 
10
  tokenizer = get_tokenizer()
11
 
12
+ # SET SPECIAL TOKENS (Aligned with Saint Iberis working space)
13
+ # We use .get() or try/except to handle different tokenizer versions
14
+ try:
15
+ bos_id = tokenizer.encode("<|bos|>")[0]
16
+ user_start_id = tokenizer.encode("<|user_start|>")[0]
17
+ user_end_id = tokenizer.encode("<|user_end|>")[0]
18
+ assistant_start_id = tokenizer.encode("<|assistant_start|>")[0]
19
+ assistant_end_id = tokenizer.encode("<|assistant_end|>")[0]
20
+ except:
21
+ # Fallback to standard tags if the specific Saint Iberis ones aren't in your vocab
22
+ bos_id = tokenizer.encode("<|endoftext|>")[0]
23
+ user_start_id = tokenizer.encode("<|user|>")[0]
24
+ user_end_id = tokenizer.encode("<|end|>")[0]
25
+ assistant_start_id = tokenizer.encode("<|assistant|>")[0]
26
+ assistant_end_id = tokenizer.encode("<|end|>")[0]
27
+
28
+ print("Creating GPT model skeleton (6-head, 2048 seq)...")
29
+ config = GPTConfig(
30
+ vocab_size=32768,
31
+ n_layer=12,
32
+ n_head=6,
33
+ n_kv_head=6,
34
+ n_embd=768,
35
+ sequence_len=2048,
36
+ )
37
 
 
38
  model = GPT(config)
 
39
 
40
  print("Loading weights...")
41
  state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
42
+ state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
 
 
 
 
43
 
44
  model.load_state_dict(state_dict, strict=False)
45
  model.to("cpu")
46
  model.eval()
47
  print("Model ready!")
48
 
49
+ # We use the model directly to avoid 'Engine' type-hinting issues
 
 
50
  def chat_fn(message, history):
51
  try:
52
+ # 1. Build token list
53
+ tokens = [bos_id]
54
  for user_msg, assistant_msg in history:
55
+ tokens.extend([user_start_id] + list(tokenizer.encode(user_msg)) + [user_end_id])
56
  if assistant_msg:
57
+ tokens.extend([assistant_start_id] + list(tokenizer.encode(assistant_msg)) + [assistant_end_id])
58
+
59
+ # Current turn
60
+ tokens.extend([user_start_id] + list(tokenizer.encode(message)) + [user_end_id])
61
+ tokens.append(assistant_start_id)
62
+
63
+ # 2. THE FIX: Convert to Tensor before generating
64
+ input_ids = torch.tensor([tokens], dtype=torch.long).to("cpu")
65
+
66
+ # 3. Generate
67
+ # If your model.generate is a generator (streaming), we'll take the result
68
+ with torch.no_grad():
69
+ output_ids = model.generate(
70
+ input_ids,
71
+ max_new_tokens=512,
72
+ temperature=0.8,
73
+ top_k=50,
74
+ )
75
+
76
+ # 4. Decode (handling both streaming and blocking outputs)
77
+ # If output_ids is a generator, we collect it; if it's a tensor, we decode it.
78
+ if isinstance(output_ids, torch.Tensor):
79
+ # Take only the newly generated tokens
80
+ new_tokens = output_ids[0][input_ids.shape[1]:]
81
+ response = tokenizer.decode(new_tokens.tolist()).strip()
82
+ else:
83
+ # It's a generator (streaming)
84
+ full_response = ""
85
+ for token in output_ids:
86
+ full_response += tokenizer.decode([token])
87
+ response = full_response.strip()
88
+
89
+ # Clean up stop tags
90
+ for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>"]:
91
+ if tag in response:
92
+ response = response.split(tag)[0].strip()
93
+
94
+ return response or "Toddler is thinking... 😅"
95
 
96
  except Exception as e:
97
  return f"Toddler tantrum: {str(e)}"
98
 
99
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
100
+ gr.Markdown("# 🧸 NanoChat-ClimbMix-D12")
101
+ gr.ChatInterface(fn=chat_fn, title="Toddler Chat")
 
 
 
 
 
 
102
 
103
  if __name__ == "__main__":
104
  demo.launch(server_name="0.0.0.0", server_port=7860)