suraj-self committed on
Commit
1efbc24
·
1 Parent(s): a5f28af
Files changed (1) hide show
  1. app.py +14 -29
app.py CHANGED
@@ -35,39 +35,24 @@ model.eval()
35
 
36
  def predict(message, history):
37
  try:
38
- # 1. Prepare token list
 
39
  tokens = [tokenizer.bos_token_id]
40
 
41
- # FIX: Explicitly extract 'content' string from Gradio objects
42
- for entry in history:
43
- if isinstance(entry, dict):
44
- # Gradio 5/6 format: {"role": "user", "content": "..."}
45
- role = entry.get("role")
46
- content = entry.get("content", "")
47
- if role == "user":
48
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(str(content)) + [tokenizer.user_end_id])
49
- elif role == "assistant":
50
- tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(str(content)) + [tokenizer.assistant_end_id])
51
-
52
- elif isinstance(entry, (list, tuple)):
53
- # Legacy format: [user_msg, assistant_msg]
54
- user_content, assistant_content = entry[0], entry[1]
55
- if user_content:
56
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(str(user_content)) + [tokenizer.user_end_id])
57
- if assistant_content:
58
- tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(str(assistant_content)) + [tokenizer.assistant_end_id])
59
-
60
- # 2. Add current user prompt
61
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(str(message)) + [tokenizer.user_end_id])
62
  tokens.append(tokenizer.assistant_start_id)
63
 
64
- # 3. Streaming Generation
65
  with torch.no_grad():
66
- # Pass as list to satisfy the nanochat assertion check
67
  output = model.generate(
68
  tokens,
69
  max_tokens=512,
70
- temperature=0.8,
71
  top_k=40
72
  )
73
 
@@ -76,17 +61,17 @@ def predict(message, history):
76
  token_id = token if isinstance(token, int) else token.item()
77
  char = tokenizer.decode([token_id])
78
 
79
- # Stop tags to prevent the model from talking to itself
80
  if any(tag in char for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>"]):
81
  break
82
 
83
  generated_text += char
 
84
  yield generated_text.strip()
85
 
86
  except Exception as e:
87
- # Log the exact error to the console for QA debugging
88
- print(f"Error details: {str(e)}")
89
- yield f"Toddler tantrum: {str(e)}"
90
 
91
  # Launching with Gradio 6.0 compatibility
92
  demo = gr.ChatInterface(
 
35
 
36
  def predict(message, history):
37
  try:
38
+ # 1. Stateless Prompt Construction
39
+ # We completely ignore 'history' to prevent the model from repeating old answers.
40
  tokens = [tokenizer.bos_token_id]
41
 
42
+ # We only encode the CURRENT message
43
+ user_content = str(message).strip()
44
+ tokens.extend([tokenizer.user_start_id] + tokenizer.encode(user_content) + [tokenizer.user_end_id])
45
+
46
+ # Add the signal for the assistant to start talking
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  tokens.append(tokenizer.assistant_start_id)
48
 
49
+ # 2. Streaming Generation
50
  with torch.no_grad():
51
+ # Pass as a Python list to satisfy the nanochat engine assertion
52
  output = model.generate(
53
  tokens,
54
  max_tokens=512,
55
+ temperature=0.8, # You can try 0.7 for more factual answers
56
  top_k=40
57
  )
58
 
 
61
  token_id = token if isinstance(token, int) else token.item()
62
  char = tokenizer.decode([token_id])
63
 
64
+ # Check for stop tags in the character stream
65
  if any(tag in char for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>"]):
66
  break
67
 
68
  generated_text += char
69
+ # Yielding the text as it generates for that "real-time" feel
70
  yield generated_text.strip()
71
 
72
  except Exception as e:
73
+ print(f"Stateless Predict Error: {str(e)}")
74
+ yield f"Toddler tantrum (Stateless): {str(e)}"
 
75
 
76
  # Launching with Gradio 6.0 compatibility
77
  demo = gr.ChatInterface(