suraj-self committed on
Commit
6383c22
·
1 Parent(s): e6eeb28
Files changed (1) hide show
  1. app.py +49 -34
app.py CHANGED
@@ -34,43 +34,58 @@ model.load_state_dict(state_dict, strict=False)
34
  model.eval()
35
 
36
  def predict(message, history):
37
- # 1. Prepare token list
38
- tokens = [tokenizer.bos_token_id]
39
- for human, assistant in history:
40
- if human:
41
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(human) + [tokenizer.user_end_id])
42
- if assistant:
43
- tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(assistant) + [tokenizer.assistant_end_id])
44
-
45
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(message) + [tokenizer.user_end_id])
46
- tokens.append(tokenizer.assistant_start_id)
47
-
48
- # --- THE FIX FOR ASSERTION ERROR ---
49
- # The error 'assert isinstance(tokens, list)' happens here.
50
- # We pass the tokens as a LIST, not a Tensor, to satisfy nanochat's requirements.
51
- # -----------------------------------
52
-
53
- with torch.no_grad():
54
- # Call generate with the LIST 'tokens'
55
- output = model.generate(
56
- tokens, # Passing as list [] instead of torch.tensor([[]])
57
- max_tokens=512,
58
- temperature=0.8,
59
- top_k=40
60
- )
61
 
62
- generated_text = ""
63
- # The Traceback shows model.generate is a generator (streaming)
64
- for token in output:
65
- # Handle if token is an int or a single-element tensor
66
- token_id = token if isinstance(token, int) else token.item()
67
- char = tokenizer.decode([token_id])
 
 
 
 
68
 
69
- if "<|assistant_end|>" in char:
70
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- generated_text += char
73
- yield generated_text.strip()
 
 
 
 
 
 
 
 
74
 
75
  # Launching with Gradio 6.0 compatibility
76
  demo = gr.ChatInterface(
 
34
  model.eval()
35
 
36
  def predict(message, history):
37
+ try:
38
+ # 1. Prepare token list
39
+ tokens = [tokenizer.bos_token_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ # FIX: Robust history handling for Gradio 5/6
42
+ for entry in history:
43
+ # Handle list of dicts format: {"role": "user", "content": "..."}
44
+ if isinstance(entry, dict):
45
+ role = entry.get("role")
46
+ content = entry.get("content", "")
47
+ if role == "user":
48
+ tokens.extend([tokenizer.user_start_id] + tokenizer.encode(content) + [tokenizer.user_end_id])
49
+ elif role == "assistant":
50
+ tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(content) + [tokenizer.assistant_end_id])
51
 
52
+ # Handle old list of lists format: [user_msg, assistant_msg]
53
+ elif isinstance(entry, (list, tuple)):
54
+ human, assistant = entry[0], entry[1]
55
+ if human:
56
+ tokens.extend([tokenizer.user_start_id] + tokenizer.encode(human) + [tokenizer.user_end_id])
57
+ if assistant:
58
+ tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(assistant) + [tokenizer.assistant_end_id])
59
+
60
+ # Add current user prompt
61
+ tokens.extend([tokenizer.user_start_id] + tokenizer.encode(message) + [tokenizer.user_end_id])
62
+ tokens.append(tokenizer.assistant_start_id)
63
+
64
+ # 2. Streaming Generation
65
+ with torch.no_grad():
66
+ # Pass as list to satisfy the nanochat assertion
67
+ output = model.generate(
68
+ tokens,
69
+ max_tokens=512,
70
+ temperature=0.8,
71
+ top_k=40
72
+ )
73
+
74
+ generated_text = ""
75
+ for token in output:
76
+ token_id = token if isinstance(token, int) else token.item()
77
+ char = tokenizer.decode([token_id])
78
 
79
+ # Stop if we hit the assistant end tag
80
+ if "<|assistant_end|>" in char or "<|end|>" in char:
81
+ break
82
+
83
+ generated_text += char
84
+ yield generated_text.strip()
85
+
86
+ except Exception as e:
87
+ print(f"CRITICAL ERROR: {e}")
88
+ yield f"Toddler tantrum: {str(e)}"
89
 
90
  # Launching with Gradio 6.0 compatibility
91
  demo = gr.ChatInterface(