suraj-self commited on
Commit
c430d50
·
1 Parent(s): c424ad1

update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -94
app.py CHANGED
@@ -1,104 +1,93 @@
1
- import os
 
2
  import torch
3
  import gradio as gr
 
4
  from nanochat.gpt import GPT, GPTConfig
5
- from nanochat.tokenizer import RustBPETokenizer
6
-
7
- # Configuration
8
- MODEL_PATH = "model_000971.pt"
9
- CACHE_DIR = os.path.expanduser("~/.cache/nanochat/tokenizer/")
10
- TOKENIZER_DIR = CACHE_DIR if os.path.exists(CACHE_DIR) else "."
11
-
12
- print(f"--- Waking up the Toddler ---")
13
- print(f"Loading tokenizer from: {TOKENIZER_DIR}")
14
-
15
- # 1. Load Tokenizer & Map Special Tokens
16
- tokenizer = RustBPETokenizer.from_directory(TOKENIZER_DIR)
17
-
18
- tokenizer.bos_token_id = tokenizer.enc.encode_single_token("<|bos|>")
19
- tokenizer.user_start_id = tokenizer.enc.encode_single_token("<|user_start|>")
20
- tokenizer.user_end_id = tokenizer.enc.encode_single_token("<|user_end|>")
21
- tokenizer.assistant_start_id = tokenizer.enc.encode_single_token("<|assistant_start|>")
22
- tokenizer.assistant_end_id = tokenizer.enc.encode_single_token("<|assistant_end|>")
23
-
24
- # 2. Build Model Architecture
25
- config = GPTConfig(
26
- vocab_size=32768,
27
- n_layer=12,
28
- n_head=6,
29
- n_kv_head=6,
30
- n_embd=768,
31
- sequence_len=2048,
32
- )
33
 
34
  model = GPT(config)
35
 
36
- # 3. Load Weights
37
  print("Loading weights...")
38
- state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
 
39
  state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
 
40
  model.load_state_dict(state_dict, strict=False)
41
- model.to("cpu")
42
  model.eval()
43
- print("Toddler is awake and ready!")
44
-
45
- def chat_fn(message, history):
46
- try:
47
- # 1. Build Token List
48
- tokens = [tokenizer.bos_token_id]
49
- for user_msg, assistant_msg in history:
50
- if user_msg:
51
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(user_msg) + [tokenizer.user_end_id])
52
- if assistant_msg:
53
- tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(assistant_msg) + [tokenizer.assistant_end_id])
54
-
55
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(message) + [tokenizer.user_end_id])
56
- tokens.append(tokenizer.assistant_start_id)
57
-
58
- input_ids = torch.tensor([tokens], dtype=torch.long)
59
-
60
- # 2. Generate (Non-streaming for stability)
61
- with torch.no_grad():
62
- # In nanochat, generate usually returns the full sequence tensor
63
- output_ids = model.generate(
64
- input_ids,
65
- max_tokens=512,
66
- temperature=0.8,
67
- top_k=40
68
- )
69
-
70
- # 3. Process Output
71
- if isinstance(output_ids, torch.Tensor):
72
- # Slicing to get only new tokens
73
- new_tokens = output_ids[0][input_ids.shape[1]:]
74
- response = tokenizer.decode(new_tokens.tolist())
75
- else:
76
- # If it's a generator, collect it all into one string
77
- response = "".join([tokenizer.decode([t]) for t in output_ids])
78
-
79
- # 4. Clean up tags
80
- for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>", "<|bos|>"]:
81
- response = response.split(tag)[0]
82
-
83
- final_text = response.strip()
84
- return final_text if final_text else "..."
85
-
86
- except Exception as e:
87
- print(f"CRITICAL ERROR: {e}")
88
- return f"Toddler tantrum: {str(e)}"
89
-
90
- # 5. Launch UI (Cleaned for Gradio 6.0 compatibility)
91
- with gr.Blocks() as demo:
92
- gr.Markdown("# 🧸 NanoChat-ClimbMix-D12")
93
- gr.ChatInterface(
94
- fn=chat_fn,
95
- examples=["Hi Toddler!", "Explain UPI.", "Tell me a joke."]
96
- )
97
-
98
- if __name__ == "__main__":
99
- # Theme moved here to resolve UserWarning
100
- demo.launch(
101
- server_name="0.0.0.0",
102
- server_port=7860,
103
- theme=gr.themes.Soft(primary_hue="orange")
104
- )
 
1
+ import json
2
+ import pickle
3
  import torch
4
  import gradio as gr
5
+
6
  from nanochat.gpt import GPT, GPTConfig
7
+
8
+ print("🚀 Loading NanoChat...")
9
+
10
+ # -----------------------
11
+ # Load tokenizer
12
+ # -----------------------
13
+
14
+ with open("tokenizer.pkl", "rb") as f:
15
+ tokenizer = pickle.load(f)
16
+
17
+ print("Tokenizer loaded")
18
+
19
+ # -----------------------
20
+ # Load model config
21
+ # -----------------------
22
+
23
+ with open("meta_000971.json") as f:
24
+ meta = json.load(f)
25
+
26
+ config = GPTConfig(**meta)
27
+
28
+ # -----------------------
29
+ # Build model
30
+ # -----------------------
 
 
 
 
31
 
32
  model = GPT(config)
33
 
 
34
  print("Loading weights...")
35
+
36
+ state_dict = torch.load("model_000971.pt", map_location="cpu")
37
  state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
38
+
39
  model.load_state_dict(state_dict, strict=False)
 
40
  model.eval()
41
+
42
+ print("✅ NanoChat ready")
43
+
44
+ # -----------------------
45
+ # Chat function
46
+ # -----------------------
47
+
48
+ def generate_reply(message, history):
49
+
50
+ tokens = [tokenizer.bos_token_id]
51
+
52
+ for user, assistant in history:
53
+ tokens += [tokenizer.user_start_id] + tokenizer.encode(user) + [tokenizer.user_end_id]
54
+ tokens += [tokenizer.assistant_start_id] + tokenizer.encode(assistant) + [tokenizer.assistant_end_id]
55
+
56
+ tokens += [tokenizer.user_start_id] + tokenizer.encode(message) + [tokenizer.user_end_id]
57
+ tokens.append(tokenizer.assistant_start_id)
58
+
59
+ input_ids = torch.tensor([tokens])
60
+
61
+ with torch.no_grad():
62
+ output = model.generate(
63
+ input_ids,
64
+ max_tokens=256,
65
+ temperature=0.8,
66
+ top_k=40
67
+ )
68
+
69
+ new_tokens = output[0][input_ids.shape[1]:]
70
+ text = tokenizer.decode(new_tokens.tolist())
71
+
72
+ for tag in ["<|assistant_end|>", "<|end|>"]:
73
+ text = text.split(tag)[0]
74
+
75
+ return text.strip()
76
+
77
+
78
+ # -----------------------
79
+ # UI
80
+ # -----------------------
81
+
82
+ demo = gr.ChatInterface(
83
+ fn=generate_reply,
84
+ title="🧸 NanoChat ClimbMix D12",
85
+ description="Small locally-trained NanoChat model running on HuggingFace Spaces",
86
+ examples=[
87
+ "Hi!",
88
+ "Explain UPI",
89
+ "Tell me a joke"
90
+ ],
91
+ )
92
+
93
+ demo.launch(server_name="0.0.0.0", server_port=7860)