Really-Amazing commited on
Commit
10eadd6
·
verified ·
1 Parent(s): 5c428ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import gradio as gr
3
  from nanochat.engine import Engine
4
  from nanochat.tokenizer import get_tokenizer
5
- from nanochat.gpt import GPT # ← correct class
6
 
7
  MODEL_PATH = "model_000971.pt"
8
 
@@ -13,18 +13,20 @@ tokenizer = get_tokenizer()
13
 
14
  print("Creating GPT model skeleton (D12 fallback)...")
15
 
16
- # Create blank model use positional arguments (common in nanochat forks)
17
- # Order usually: vocab_size, n_layer, n_head, n_embd, block_size, dropout, ...
18
- model = GPT(
19
- vocab_size=50257, # GPT-2 base — most common
20
  n_layer=12,
21
  n_head=12,
22
  n_embd=768,
23
  block_size=1024,
24
  dropout=0.1,
25
- # If error about missing arg, add bias=True or other defaults here
26
  )
27
 
 
 
 
28
  print("Loading flat state_dict from checkpoint...")
29
  state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
30
 
@@ -34,7 +36,7 @@ for k in list(state_dict.keys()):
34
  if k.startswith(unwanted_prefix):
35
  state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
36
 
37
- # Load — strict=False ignores extra keys (value_embeds, lambdas, etc.)
38
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
39
  print(f"Load info: {len(missing)} missing keys, {len(unexpected)} unexpected keys")
40
 
@@ -46,7 +48,6 @@ print("Model ready!")
46
  engine = Engine(model=model, tokenizer=tokenizer)
47
 
48
  def chat_fn(message, history):
49
- # Use max_tokens as per your engine.py grep
50
  return engine.generate(message, max_tokens=512, temperature=0.85)
51
 
52
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
 
2
  import gradio as gr
3
  from nanochat.engine import Engine
4
  from nanochat.tokenizer import get_tokenizer
5
+ from nanochat.gpt import GPT, GPTConfig # ← Added GPTConfig here!
6
 
7
  MODEL_PATH = "model_000971.pt"
8
 
 
13
 
14
  print("Creating GPT model skeleton (D12 fallback)...")
15
 
16
+ # 1. Create the config object first
17
+ # (50304 is the standard padded vocab size in nanoGPT for efficiency)
18
+ config = GPTConfig(
19
+ vocab_size=50304,
20
  n_layer=12,
21
  n_head=12,
22
  n_embd=768,
23
  block_size=1024,
24
  dropout=0.1,
 
25
  )
26
 
27
+ # 2. Pass the config object into the GPT class
28
+ model = GPT(config)
29
+
30
  print("Loading flat state_dict from checkpoint...")
31
  state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
32
 
 
36
  if k.startswith(unwanted_prefix):
37
  state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
38
 
39
+ # Load — strict=False ignores extra keys
40
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
41
  print(f"Load info: {len(missing)} missing keys, {len(unexpected)} unexpected keys")
42
 
 
48
  engine = Engine(model=model, tokenizer=tokenizer)
49
 
50
  def chat_fn(message, history):
 
51
  return engine.generate(message, max_tokens=512, temperature=0.85)
52
 
53
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo: