suraj-self committed on
Commit
e6eeb28
·
1 Parent(s): db8631a
Files changed (1) hide show
  1. app.py +40 -54
app.py CHANGED
@@ -4,44 +4,18 @@ import gradio as gr
4
  from nanochat.gpt import GPT, GPTConfig
5
  from nanochat.tokenizer import RustBPETokenizer
6
 
7
- # Aggressive Path Finding
8
- # Since you have files in the root, we check '.' first
9
- possible_paths = [
10
- ".",
11
- "/app",
12
- os.path.expanduser("~/.cache/nanochat/tokenizer/")
13
- ]
14
-
15
- TOKENIZER_DIR = None
16
- for p in possible_paths:
17
- if os.path.exists(os.path.join(p, "token_bytes.pt")):
18
- TOKENIZER_DIR = p
19
- break
20
-
21
- if not TOKENIZER_DIR:
22
- # If still not found, we use root as a fallback
23
- TOKENIZER_DIR = "."
24
 
25
  print(f"--- System Initialization ---")
26
- print(f"Loading tokenizer from: {os.path.abspath(TOKENIZER_DIR)}")
27
-
28
- # Load Tokenizer
29
  tokenizer = RustBPETokenizer.from_directory(TOKENIZER_DIR)
30
 
31
- # Map IDs (These MUST exist in your vocabulary)
32
- try:
33
- tokenizer.bos_token_id = tokenizer.enc.encode_single_token("<|bos|>")
34
- tokenizer.user_start_id = tokenizer.enc.encode_single_token("<|user_start|>")
35
- tokenizer.user_end_id = tokenizer.enc.encode_single_token("<|user_end|>")
36
- tokenizer.assistant_start_id = tokenizer.enc.encode_single_token("<|assistant_start|>")
37
- tokenizer.assistant_end_id = tokenizer.enc.encode_single_token("<|assistant_end|>")
38
- except Exception as e:
39
- print(f"Warning: Special tokens not found in vocab. Error: {e}")
40
- # Fallback to standard GPT-2 tokens if yours are missing
41
- tokenizer.bos_token_id = 50256
42
- tokenizer.user_start_id = 50257
43
- tokenizer.user_end_id = 50258
44
- tokenizer.assistant_start_id = 50259
45
 
46
  # Model Setup
47
  config = GPTConfig(
@@ -53,7 +27,6 @@ config = GPTConfig(
53
  )
54
 
55
  model = GPT(config)
56
-
57
  print("Loading model weights...")
58
  state_dict = torch.load("model_000971.pt", map_location="cpu")
59
  state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
@@ -61,37 +34,50 @@ model.load_state_dict(state_dict, strict=False)
61
  model.eval()
62
 
63
  def predict(message, history):
 
64
  tokens = [tokenizer.bos_token_id]
65
  for human, assistant in history:
66
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(human) + [tokenizer.user_end_id])
 
67
  if assistant:
68
  tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(assistant) + [tokenizer.assistant_end_id])
69
 
70
  tokens.extend([tokenizer.user_start_id] + tokenizer.encode(message) + [tokenizer.user_end_id])
71
  tokens.append(tokenizer.assistant_start_id)
72
 
73
- input_ids = torch.tensor([tokens], dtype=torch.long)
 
 
 
74
 
75
  with torch.no_grad():
76
- output = model.generate(input_ids, max_tokens=512, temperature=0.8)
 
 
 
 
 
 
77
 
78
- # Generator vs Tensor handling
79
- if isinstance(output, torch.Tensor):
80
- new_tokens = output[0][input_ids.shape[1]:]
81
- response = tokenizer.decode(new_tokens.tolist())
82
- for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>"]:
83
- response = response.split(tag)[0]
84
- yield response.strip()
85
- else:
86
- generated_text = ""
87
- for token in output:
88
- token_id = token if isinstance(token, int) else token.item()
89
- char = tokenizer.decode([token_id])
90
- if "<|assistant_end|>" in char: break
91
- generated_text += char
92
- yield generated_text.strip()
93
 
94
- demo = gr.ChatInterface(fn=predict, title="🧸 NanoChat-D12")
 
 
 
 
 
95
 
96
  if __name__ == "__main__":
97
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
4
  from nanochat.gpt import GPT, GPTConfig
5
  from nanochat.tokenizer import RustBPETokenizer
6
 
# Tokenizer files are stored in the root of the Space, next to app.py.
TOKENIZER_DIR = "."

print("--- System Initialization ---")
tokenizer = RustBPETokenizer.from_directory(TOKENIZER_DIR)

# Resolve each chat special token to its id once at start-up and cache it as
# an attribute on the tokenizer, where predict() reads it. There is no
# fallback: if a token is missing from the vocabulary the lookup raises here,
# at import time, rather than producing a malformed prompt later.
for _attr_name, _special_token in (
    ("bos_token_id", "<|bos|>"),
    ("user_start_id", "<|user_start|>"),
    ("user_end_id", "<|user_end|>"),
    ("assistant_start_id", "<|assistant_start|>"),
    ("assistant_end_id", "<|assistant_end|>"),
):
    setattr(
        tokenizer,
        _attr_name,
        tokenizer.enc.encode_single_token(_special_token),
    )
 
 
 
 
 
 
 
 
19
 
20
  # Model Setup
21
  config = GPTConfig(
 
27
  )
28
 
29
  model = GPT(config)
 
30
  print("Loading model weights...")
31
  state_dict = torch.load("model_000971.pt", map_location="cpu")
32
  state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
 
34
  model.eval()
35
 
def _message_text(content):
    """Return the plain text carried by one chat message's content.

    Accepts a plain string, a dict content part carrying a "text" entry, or
    a list/tuple of such parts (concatenated). Anything else, including
    None and non-text parts, contributes an empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        text = content.get("text")
        return text if isinstance(text, str) else ""
    if isinstance(content, (list, tuple)):
        return "".join(_message_text(part) for part in content)
    return ""


def _history_turns(history):
    """Yield (role, text) for every non-empty turn in a chat history.

    Supports both history layouts a chat UI may hand over: a list of
    [user, assistant] pairs, and a list of {"role": ..., "content": ...}
    dicts. Turns with no text and roles other than user/assistant are
    skipped.
    """
    for entry in history or []:
        if isinstance(entry, dict):
            role = entry.get("role")
            text = _message_text(entry.get("content"))
            if role in ("user", "assistant") and text:
                yield role, text
            continue
        human, assistant = entry
        human_text = _message_text(human)
        if human_text:
            yield "user", human_text
        assistant_text = _message_text(assistant)
        if assistant_text:
            yield "assistant", assistant_text


def predict(message, history):
    """Stream the assistant's reply to ``message`` given the prior ``history``.

    Args:
        message: The latest user message.
        history: Earlier turns, either as [user, assistant] pairs or as
            role/content dicts.

    Yields:
        The reply decoded so far (whitespace-stripped), once per generated
        token, so the caller can render it incrementally.
    """
    # 1. Build the prompt as a flat list of token ids.
    tokens = [tokenizer.bos_token_id]
    for role, text in _history_turns(history):
        if role == "user":
            tokens.extend(
                [tokenizer.user_start_id]
                + tokenizer.encode(text)
                + [tokenizer.user_end_id]
            )
        else:
            tokens.extend(
                [tokenizer.assistant_start_id]
                + tokenizer.encode(text)
                + [tokenizer.assistant_end_id]
            )

    tokens.extend(
        [tokenizer.user_start_id]
        + tokenizer.encode(_message_text(message))
        + [tokenizer.user_end_id]
    )
    tokens.append(tokenizer.assistant_start_id)

    # 2. model.generate asserts that it receives a plain Python list (not a
    # tensor) and returns a generator, so nothing is sampled until the
    # stream is advanced.
    stream = iter(
        model.generate(
            tokens,
            max_tokens=512,
            temperature=0.8,
            top_k=40,
        )
    )

    # Stop on the end-of-assistant marker. BOS is also treated as a stop
    # (assumption: the model emits it only to start a new document).
    stop_ids = {tokenizer.assistant_end_id, tokenizer.bos_token_id}
    exhausted = object()
    generated_ids = []
    while True:
        # Advance the generator inside no_grad so each sampling step runs
        # without autograd. We never yield while inside the context, so the
        # grad-mode change cannot leak into the caller between chunks.
        with torch.no_grad():
            token = next(stream, exhausted)
        if token is exhausted:
            break

        # Handle the token being an int or a single-element tensor.
        token_id = token if isinstance(token, int) else token.item()
        if token_id in stop_ids:
            break

        generated_ids.append(token_id)
        # Decode the whole reply each step: a character whose UTF-8 bytes
        # span several tokens cannot be decoded correctly token by token.
        yield tokenizer.decode(generated_ids).strip()
 
 
 
74
 
# Chat UI built around the `predict` generator defined in this file.
demo = gr.ChatInterface(
    predict,
    description="Running on CPU. Optimized for Saint Iberis weights.",
    title="🧸 NanoChat-D12",
)

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")