Really-Amazing committed on
Commit
de657c0
·
verified ·
1 Parent(s): 00d2698

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -32
app.py CHANGED
@@ -2,7 +2,6 @@ import torch
2
  import gradio as gr
3
  from nanochat.engine import Engine
4
  from nanochat.tokenizer import get_tokenizer
5
- from nanochat.gpt import GPT
6
 
7
  MODEL_PATH = "model_000971.pt"
8
 
@@ -10,37 +9,31 @@ print("Waking up the toddler (NanoChat-ClimbMix-D12)...")
10
 
11
  tokenizer = get_tokenizer()
12
 
13
- print("Building GPT model skeleton (D12 fallback)...")
14
- config = {
15
- "n_layer": 12,
16
- "n_head": 12,
17
- "n_embd": 768,
18
- "block_size": 1024,
19
- "vocab_size": 50257, # GPT-2 standard — safer bet
20
- "dropout": 0.1,
21
- "bias": True,
22
- }
23
- model = GPT(**config)
24
-
25
- print("Loading weights from checkpoint...")
26
  checkpoint = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
27
 
28
- state_dict = checkpoint if not isinstance(checkpoint, dict) else (
29
- checkpoint.get('model') or checkpoint.get('state_dict') or checkpoint
30
- )
31
 
32
- unwanted_prefix = '_orig_mod.'
33
- for k in list(state_dict.keys()):
34
- if k.startswith(unwanted_prefix):
35
- state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
 
 
 
 
 
36
 
37
- missing, unexpected = model.load_state_dict(state_dict, strict=False)
38
- if missing or unexpected:
39
- print(f"Warning: Missing keys: {len(missing)} | Unexpected: {len(unexpected)}")
 
40
 
41
  model.to("cpu")
42
  model.eval()
43
- print("Model loaded successfully!")
 
44
 
45
  engine = Engine(model=model, tokenizer=tokenizer)
46
 
@@ -48,13 +41,8 @@ def chat_fn(message, history):
48
  return engine.generate(message, max_tokens=512, temperature=0.85)
49
 
50
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
51
- gr.Markdown("# 🧸 NanoChat-ClimbMix-D12 – The Confident Toddler LLM")
52
- gr.Markdown("Karpathy nanochat fork. Preschool phase: bold, funny, often wrong. 😂\nRoadmap: D14 → D16 → D18 → D20+")
53
- gr.ChatInterface(
54
- fn=chat_fn,
55
- examples=["Why is the sky blue?", "What is UPI?", "Write hello world Python code"],
56
- title="Chat with the Toddler"
57
- )
58
 
59
  if __name__ == "__main__":
60
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
2
  import gradio as gr
3
  from nanochat.engine import Engine
4
  from nanochat.tokenizer import get_tokenizer
 
5
 
6
  MODEL_PATH = "model_000971.pt"
7
 
 
9
 
10
  tokenizer = get_tokenizer()
11
 
12
+ print("Loading checkpoint directly...")
 
 
 
 
 
 
 
 
 
 
 
 
13
  checkpoint = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
14
 
15
+ # Your checkpoint is a flat state_dict with 'transformer.' prefix
16
+ # So we need the model class instance first
 
17
 
18
+ # Option 1: If nanochat has a from_checkpoint or load method
19
+ # (most likely in checkpoint_manager or engine)
20
+ try:
21
+ from nanochat.checkpoint_manager import load_model
22
+ model, _ = load_model(".", checkpoint_name="model_000971.pt", device="cpu")
23
+ except Exception as e:
24
+ print(f"checkpoint_manager failed: {e}")
25
+ # Option 2: Direct load if checkpoint is state_dict
26
+ state_dict = checkpoint
27
 
28
+ # We need a pre-initialized model to load into
29
+ # Since we can't build GPT without args, assume Engine can help or fallback
30
+ # For now, raise to see
31
+ raise ValueError("Cannot reconstruct model — checkpoint is flat state_dict. Need model skeleton or load method")
32
 
33
  model.to("cpu")
34
  model.eval()
35
+
36
+ print("Model loaded!")
37
 
38
  engine = Engine(model=model, tokenizer=tokenizer)
39
 
 
41
  return engine.generate(message, max_tokens=512, temperature=0.85)
42
 
43
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
44
+ gr.Markdown("# 🧸 NanoChat-ClimbMix-D12")
45
+ gr.ChatInterface(fn=chat_fn)
 
 
 
 
 
46
 
47
  if __name__ == "__main__":
48
  demo.launch(server_name="0.0.0.0", server_port=7860)