Really-Amazing committed on
Commit
72950d2
·
verified ·
1 Parent(s): 57899cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -21
app.py CHANGED
@@ -6,9 +6,7 @@ from nanochat.tokenizer import RustBPETokenizer
6
 
7
  # Configuration
8
  MODEL_PATH = "model_000971.pt"
9
- # The Dockerfile moves files to this specific cache location
10
  CACHE_DIR = os.path.expanduser("~/.cache/nanochat/tokenizer/")
11
- # Fallback to current directory if cache doesn't exist (local testing)
12
  TOKENIZER_DIR = CACHE_DIR if os.path.exists(CACHE_DIR) else "."
13
 
14
  print(f"--- Waking up the Toddler ---")
@@ -17,7 +15,6 @@ print(f"Loading tokenizer from: {TOKENIZER_DIR}")
17
  # 1. Load Tokenizer & Map Special Tokens
18
  tokenizer = RustBPETokenizer.from_directory(TOKENIZER_DIR)
19
 
20
- # These must match your training vocab
21
  tokenizer.bos_token_id = tokenizer.enc.encode_single_token("<|bos|>")
22
  tokenizer.user_start_id = tokenizer.enc.encode_single_token("<|user_start|>")
23
  tokenizer.user_end_id = tokenizer.enc.encode_single_token("<|user_end|>")
@@ -36,7 +33,7 @@ config = GPTConfig(
36
 
37
  model = GPT(config)
38
 
39
- # 3. Load Weights (with _orig_mod. cleaning)
40
  print("Loading weights...")
41
  state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
42
  state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
@@ -47,10 +44,11 @@ print("Toddler is awake and ready!")
47
 
48
  def chat_fn(message, history):
49
  try:
50
- # Build Chat History
51
  tokens = [tokenizer.bos_token_id]
52
  for user_msg, assistant_msg in history:
53
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(user_msg) + [tokenizer.user_end_id])
 
54
  if assistant_msg:
55
  tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(assistant_msg) + [tokenizer.assistant_end_id])
56
 
@@ -60,9 +58,7 @@ def chat_fn(message, history):
60
 
61
  input_ids = torch.tensor([tokens], dtype=torch.long)
62
 
63
- # 4. Generate with Streaming Logic
64
- # Note: In nanochat.gpt, generate is typically an autoregressive loop.
65
- # If your version returns a generator, we iterate. If a tensor, we slice.
66
  with torch.no_grad():
67
  output_ids = model.generate(
68
  input_ids,
@@ -71,41 +67,42 @@ def chat_fn(message, history):
71
  top_k=40
72
  )
73
 
74
- # Handle Tensor vs Generator output
75
  if isinstance(output_ids, torch.Tensor):
76
- # Just take the new parts
77
  new_tokens = output_ids[0][input_ids.shape[1]:]
78
  response = tokenizer.decode(new_tokens.tolist())
79
  else:
80
- # It's a generator yielding token by token
81
  response = ""
82
  for token in output_ids:
83
  decoded = tokenizer.decode([token])
84
  if "<|assistant_end|>" in decoded:
85
  break
86
  response += decoded
87
- yield response # Yield for streaming UI effect
88
 
89
- # Final cleanup for non-streaming return
90
  for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>"]:
91
  response = response.split(tag)[0]
92
 
93
  return response.strip()
94
 
95
  except Exception as e:
96
- # Crucial for QA: see the actual error in Space logs
97
  print(f"ERROR: {e}")
98
  return f"Toddler tantrum: {str(e)}"
99
 
100
- # 5. Launch UI
101
- with gr.Blocks(theme=gr.themes.Default(primary_hue="orange")) as demo:
102
  gr.Markdown("# 🧸 NanoChat-ClimbMix-D12")
103
- gr.Markdown("A custom-trained small language model running on your CPU.")
104
  gr.ChatInterface(
105
  fn=chat_fn,
106
- type="messages", # Updated for latest Gradio versions
107
- examples=["Hi Toddler!", "How does UPI work?", "Tell me a story."]
108
  )
109
 
110
  if __name__ == "__main__":
111
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
6
 
7
  # Configuration
8
  MODEL_PATH = "model_000971.pt"
 
9
  CACHE_DIR = os.path.expanduser("~/.cache/nanochat/tokenizer/")
 
10
  TOKENIZER_DIR = CACHE_DIR if os.path.exists(CACHE_DIR) else "."
11
 
12
  print(f"--- Waking up the Toddler ---")
 
15
  # 1. Load Tokenizer & Map Special Tokens
16
  tokenizer = RustBPETokenizer.from_directory(TOKENIZER_DIR)
17
 
 
18
  tokenizer.bos_token_id = tokenizer.enc.encode_single_token("<|bos|>")
19
  tokenizer.user_start_id = tokenizer.enc.encode_single_token("<|user_start|>")
20
  tokenizer.user_end_id = tokenizer.enc.encode_single_token("<|user_end|>")
 
33
 
34
  model = GPT(config)
35
 
36
+ # 3. Load Weights
37
  print("Loading weights...")
38
  state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
39
  state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
 
44
 
45
  def chat_fn(message, history):
46
  try:
47
+ # Build Chat History (Handling standard Gradio list-of-lists format)
48
  tokens = [tokenizer.bos_token_id]
49
  for user_msg, assistant_msg in history:
50
+ if user_msg:
51
+ tokens.extend([tokenizer.user_start_id] + tokenizer.encode(user_msg) + [tokenizer.user_end_id])
52
  if assistant_msg:
53
  tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(assistant_msg) + [tokenizer.assistant_end_id])
54
 
 
58
 
59
  input_ids = torch.tensor([tokens], dtype=torch.long)
60
 
61
+ # 4. Generate
 
 
62
  with torch.no_grad():
63
  output_ids = model.generate(
64
  input_ids,
 
67
  top_k=40
68
  )
69
 
70
+ # Handle output
71
  if isinstance(output_ids, torch.Tensor):
 
72
  new_tokens = output_ids[0][input_ids.shape[1]:]
73
  response = tokenizer.decode(new_tokens.tolist())
74
  else:
75
+ # Generator logic
76
  response = ""
77
  for token in output_ids:
78
  decoded = tokenizer.decode([token])
79
  if "<|assistant_end|>" in decoded:
80
  break
81
  response += decoded
82
+ yield response
83
 
84
+ # Final cleanup
85
  for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>"]:
86
  response = response.split(tag)[0]
87
 
88
  return response.strip()
89
 
90
  except Exception as e:
 
91
  print(f"ERROR: {e}")
92
  return f"Toddler tantrum: {str(e)}"
93
 
94
+ # 5. Launch UI (Cleaned for Gradio 6.0 compatibility)
95
+ with gr.Blocks() as demo:
96
  gr.Markdown("# 🧸 NanoChat-ClimbMix-D12")
 
97
  gr.ChatInterface(
98
  fn=chat_fn,
99
+ examples=["Hi Toddler!", "Explain UPI.", "Tell me a joke."]
 
100
  )
101
 
102
  if __name__ == "__main__":
103
+ # Theme moved here to resolve UserWarning
104
+ demo.launch(
105
+ server_name="0.0.0.0",
106
+ server_port=7860,
107
+ theme=gr.themes.Soft(primary_hue="orange")
108
+ )