Really-Amazing commited on
Commit
6fd1e97
·
verified ·
1 Parent(s): e116dd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -35
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  import gradio as gr
 
3
  from nanochat.engine import Engine
4
  from nanochat.tokenizer import get_tokenizer
5
  from nanochat.gpt import GPT, GPTConfig
@@ -7,22 +8,22 @@ from nanochat.gpt import GPT, GPTConfig
7
  MODEL_PATH = "model_000971.pt"
8
 
9
  print("Waking up the toddler (NanoChat-ClimbMix-D12)...")
 
10
  tokenizer = get_tokenizer()
11
 
12
- # EXACT values from your local scripts.chat_web output
13
- config = GPTConfig(
14
- vocab_size=32768,
15
- n_layer=12,
16
- n_head=6,
17
- n_kv_head=6,
18
- n_embd=768,
19
- sequence_len=2048,
20
- )
21
 
 
22
  model = GPT(config)
 
23
 
24
  print("Loading weights...")
25
  state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
 
26
  unwanted_prefix = '_orig_mod.'
27
  for k in list(state_dict.keys()):
28
  if k.startswith(unwanted_prefix):
@@ -35,14 +36,7 @@ print("Model ready!")
35
 
36
  engine = Engine(model=model, tokenizer=tokenizer)
37
 
38
- def safe_encode(text):
39
- """Helper to ensure we only get the list of token IDs."""
40
- encoded = tokenizer.encode(text)
41
- # If it's a tuple (ids, mask), just take the ids
42
- if isinstance(encoded, tuple):
43
- return list(encoded[0])
44
- return list(encoded)
45
-
46
  def chat_fn(message, history):
47
  try:
48
  prompt_tokens = []
@@ -53,30 +47,20 @@ def chat_fn(message, history):
53
 
54
  prompt_tokens.extend(list(tokenizer.encode(f"<|user|>{message}<|end|><|assistant|>")))
55
 
56
- # Generate and handle possible tuple return
57
- gen_output = engine.generate(
58
  prompt_tokens,
59
  max_tokens=512,
60
  temperature=0.8,
61
  top_k=50,
62
  )
63
 
64
- # Unwrap if it's a tuple (common in batched/speculative forks)
65
- if isinstance(gen_output, tuple):
66
- new_tokens = gen_output[0] # usually first is tokens
67
- print("Unwrapped tuple from generate:", type(new_tokens))
68
- else:
69
- new_tokens = gen_output
70
-
71
- # Convert to list if tensor
72
  if hasattr(new_tokens, 'tolist'):
73
  new_tokens = new_tokens.tolist()
74
- elif not isinstance(new_tokens, list):
75
- new_tokens = list(new_tokens)
76
 
77
  response = tokenizer.decode(new_tokens).strip()
78
 
79
- # Clean end tag
80
  for end_tag in ["<|end|>", "<|assistant_end|>", "<|EOS|>"]:
81
  if end_tag in response:
82
  response = response.split(end_tag)[0].strip()
@@ -85,16 +69,15 @@ def chat_fn(message, history):
85
  return response or "Toddler says: ... 😅"
86
 
87
  except Exception as e:
88
- import traceback
89
- print(traceback.format_exc()) # log full stack for debug
90
  return f"Toddler tantrum: {str(e)}"
91
 
 
92
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
93
- gr.Markdown("# 🧸 NanoChat-ClimbMix-D12")
94
- gr.Markdown("Architecture verified. Tuple/Generator issues handled.")
95
  gr.ChatInterface(
96
  fn=chat_fn,
97
- examples=["Tell me a joke", "What is UPI?"],
98
  title="Chat with the Toddler"
99
  )
100
 
 
1
  import torch
2
  import gradio as gr
3
+ import json # ← ONLY NEW IMPORT
4
  from nanochat.engine import Engine
5
  from nanochat.tokenizer import get_tokenizer
6
  from nanochat.gpt import GPT, GPTConfig
 
8
  MODEL_PATH = "model_000971.pt"
9
 
10
  print("Waking up the toddler (NanoChat-ClimbMix-D12)...")
11
+
12
  tokenizer = get_tokenizer()
13
 
14
+ print("Creating GPT model skeleton from meta_000971.json...")
15
+
16
+ # === ONLY CHANGE: Load exact config from meta file (same as working space) ===
17
+ with open("meta_000971.json", "r", encoding="utf-8") as f:
18
+ meta_data = json.load(f)
 
 
 
 
19
 
20
+ config = GPTConfig(**meta_data["model_config"])
21
  model = GPT(config)
22
+ # =====================================================================
23
 
24
  print("Loading weights...")
25
  state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
26
+
27
  unwanted_prefix = '_orig_mod.'
28
  for k in list(state_dict.keys()):
29
  if k.startswith(unwanted_prefix):
 
36
 
37
  engine = Engine(model=model, tokenizer=tokenizer)
38
 
39
+ # Your existing chat_fn (kept 100% unchanged)
 
 
 
 
 
 
 
40
  def chat_fn(message, history):
41
  try:
42
  prompt_tokens = []
 
47
 
48
  prompt_tokens.extend(list(tokenizer.encode(f"<|user|>{message}<|end|><|assistant|>")))
49
 
50
+ new_tokens = engine.generate(
 
51
  prompt_tokens,
52
  max_tokens=512,
53
  temperature=0.8,
54
  top_k=50,
55
  )
56
 
57
+ if isinstance(new_tokens, tuple):
58
+ new_tokens = new_tokens[0]
 
 
 
 
 
 
59
  if hasattr(new_tokens, 'tolist'):
60
  new_tokens = new_tokens.tolist()
 
 
61
 
62
  response = tokenizer.decode(new_tokens).strip()
63
 
 
64
  for end_tag in ["<|end|>", "<|assistant_end|>", "<|EOS|>"]:
65
  if end_tag in response:
66
  response = response.split(end_tag)[0].strip()
 
69
  return response or "Toddler says: ... 😅"
70
 
71
  except Exception as e:
 
 
72
  return f"Toddler tantrum: {str(e)}"
73
 
74
+ # Rest of your UI (unchanged)
75
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
76
+ gr.Markdown("# 🧸 NanoChat-ClimbMix-D12 – Confident Toddler")
77
+ gr.Markdown("Using exact config from meta_000971.json (same as working space)")
78
  gr.ChatInterface(
79
  fn=chat_fn,
80
+ examples=["Tell me a joke", "What is UPI?", "Write hello world Python"],
81
  title="Chat with the Toddler"
82
  )
83