Really-Amazing commited on
Commit
5c428ee
·
verified ·
1 Parent(s): de657c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -23
app.py CHANGED
@@ -2,47 +2,61 @@ import torch
2
  import gradio as gr
3
  from nanochat.engine import Engine
4
  from nanochat.tokenizer import get_tokenizer
 
5
 
6
  MODEL_PATH = "model_000971.pt"
7
 
8
  print("Waking up the toddler (NanoChat-ClimbMix-D12)...")
9
 
 
10
  tokenizer = get_tokenizer()
11
 
12
- print("Loading checkpoint directly...")
13
- checkpoint = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
14
-
15
- # Your checkpoint is a flat state_dict with 'transformer.' prefix
16
- # So we need the model class instance first
17
-
18
- # Option 1: If nanochat has a from_checkpoint or load method
19
- # (most likely in checkpoint_manager or engine)
20
- try:
21
- from nanochat.checkpoint_manager import load_model
22
- model, _ = load_model(".", checkpoint_name="model_000971.pt", device="cpu")
23
- except Exception as e:
24
- print(f"checkpoint_manager failed: {e}")
25
- # Option 2: Direct load if checkpoint is state_dict
26
- state_dict = checkpoint
27
-
28
- # We need a pre-initialized model to load into
29
- # Since we can't build GPT without args, assume Engine can help or fallback
30
- # For now, raise to see
31
- raise ValueError("Cannot reconstruct model — checkpoint is flat state_dict. Need model skeleton or load method")
 
 
 
 
 
 
32
 
33
  model.to("cpu")
34
  model.eval()
35
 
36
- print("Model loaded!")
37
 
38
  engine = Engine(model=model, tokenizer=tokenizer)
39
 
40
  def chat_fn(message, history):
 
41
  return engine.generate(message, max_tokens=512, temperature=0.85)
42
 
43
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
44
- gr.Markdown("# 🧸 NanoChat-ClimbMix-D12")
45
- gr.ChatInterface(fn=chat_fn)
 
 
 
 
 
46
 
47
  if __name__ == "__main__":
48
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
2
  import gradio as gr
3
  from nanochat.engine import Engine
4
  from nanochat.tokenizer import get_tokenizer
5
+ from nanochat.gpt import GPT # ← correct class
6
 
7
  MODEL_PATH = "model_000971.pt"
8
 
9
  print("Waking up the toddler (NanoChat-ClimbMix-D12)...")
10
 
11
+ # Tokenizer (Docker fix already placed tokenizer.pkl correctly)
12
  tokenizer = get_tokenizer()
13
 
14
+ print("Creating GPT model skeleton (D12 fallback)...")
15
+
16
+ # Create blank model — use positional arguments (common in nanochat forks)
17
+ # Order usually: vocab_size, n_layer, n_head, n_embd, block_size, dropout, ...
18
+ model = GPT(
19
+ vocab_size=50257, # GPT-2 base — most common
20
+ n_layer=12,
21
+ n_head=12,
22
+ n_embd=768,
23
+ block_size=1024,
24
+ dropout=0.1,
25
+ # If error about missing arg, add bias=True or other defaults here
26
+ )
27
+
28
+ print("Loading flat state_dict from checkpoint...")
29
+ state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
30
+
31
+ # Clean torch.compile prefix if present
32
+ unwanted_prefix = '_orig_mod.'
33
+ for k in list(state_dict.keys()):
34
+ if k.startswith(unwanted_prefix):
35
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
36
+
37
+ # Load — strict=False ignores extra keys (value_embeds, lambdas, etc.)
38
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
39
+ print(f"Load info: {len(missing)} missing keys, {len(unexpected)} unexpected keys")
40
 
41
  model.to("cpu")
42
  model.eval()
43
 
44
+ print("Model ready!")
45
 
46
  engine = Engine(model=model, tokenizer=tokenizer)
47
 
48
  def chat_fn(message, history):
49
+ # Use max_tokens as per your engine.py grep
50
  return engine.generate(message, max_tokens=512, temperature=0.85)
51
 
52
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
53
+ gr.Markdown("# 🧸 NanoChat-ClimbMix-D12 – Toddler Phase")
54
+ gr.Markdown("Confident, funny, wildly inaccurate. Maturing fast → D14/D16/D18 soon!")
55
+ gr.ChatInterface(
56
+ fn=chat_fn,
57
+ examples=["Why is the sky blue?", "What is UPI?", "Write hello world Python"],
58
+ title="Chat with the Toddler"
59
+ )
60
 
61
  if __name__ == "__main__":
62
  demo.launch(server_name="0.0.0.0", server_port=7860)