suraj-self commited on
Commit
a5f28af
·
1 Parent(s): 6383c22
Files changed (1) hide show
  1. app.py +18 -17
app.py CHANGED
@@ -38,32 +38,32 @@ def predict(message, history):
38
  # 1. Prepare token list
39
  tokens = [tokenizer.bos_token_id]
40
 
41
- # FIX: Robust history handling for Gradio 5/6
42
  for entry in history:
43
- # Handle list of dicts format: {"role": "user", "content": "..."}
44
  if isinstance(entry, dict):
 
45
  role = entry.get("role")
46
  content = entry.get("content", "")
47
  if role == "user":
48
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(content) + [tokenizer.user_end_id])
49
  elif role == "assistant":
50
- tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(content) + [tokenizer.assistant_end_id])
51
 
52
- # Handle old list of lists format: [user_msg, assistant_msg]
53
  elif isinstance(entry, (list, tuple)):
54
- human, assistant = entry[0], entry[1]
55
- if human:
56
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(human) + [tokenizer.user_end_id])
57
- if assistant:
58
- tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(assistant) + [tokenizer.assistant_end_id])
 
59
 
60
- # Add current user prompt
61
- tokens.extend([tokenizer.user_start_id] + tokenizer.encode(message) + [tokenizer.user_end_id])
62
  tokens.append(tokenizer.assistant_start_id)
63
 
64
- # 2. Streaming Generation
65
  with torch.no_grad():
66
- # Pass as list to satisfy the nanochat assertion
67
  output = model.generate(
68
  tokens,
69
  max_tokens=512,
@@ -76,15 +76,16 @@ def predict(message, history):
76
  token_id = token if isinstance(token, int) else token.item()
77
  char = tokenizer.decode([token_id])
78
 
79
- # Stop if we hit the assistant end tag
80
- if "<|assistant_end|>" in char or "<|end|>" in char:
81
  break
82
 
83
  generated_text += char
84
  yield generated_text.strip()
85
 
86
  except Exception as e:
87
- print(f"CRITICAL ERROR: {e}")
 
88
  yield f"Toddler tantrum: {str(e)}"
89
 
90
  # Launching with Gradio 6.0 compatibility
 
38
  # 1. Prepare token list
39
  tokens = [tokenizer.bos_token_id]
40
 
41
+ # FIX: Explicitly extract 'content' string from Gradio objects
42
  for entry in history:
 
43
  if isinstance(entry, dict):
44
+ # Gradio 5/6 format: {"role": "user", "content": "..."}
45
  role = entry.get("role")
46
  content = entry.get("content", "")
47
  if role == "user":
48
+ tokens.extend([tokenizer.user_start_id] + tokenizer.encode(str(content)) + [tokenizer.user_end_id])
49
  elif role == "assistant":
50
+ tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(str(content)) + [tokenizer.assistant_end_id])
51
 
 
52
  elif isinstance(entry, (list, tuple)):
53
+ # Legacy format: [user_msg, assistant_msg]
54
+ user_content, assistant_content = entry[0], entry[1]
55
+ if user_content:
56
+ tokens.extend([tokenizer.user_start_id] + tokenizer.encode(str(user_content)) + [tokenizer.user_end_id])
57
+ if assistant_content:
58
+ tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(str(assistant_content)) + [tokenizer.assistant_end_id])
59
 
60
+ # 2. Add current user prompt
61
+ tokens.extend([tokenizer.user_start_id] + tokenizer.encode(str(message)) + [tokenizer.user_end_id])
62
  tokens.append(tokenizer.assistant_start_id)
63
 
64
+ # 3. Streaming Generation
65
  with torch.no_grad():
66
+ # Pass as list to satisfy the nanochat assertion check
67
  output = model.generate(
68
  tokens,
69
  max_tokens=512,
 
76
  token_id = token if isinstance(token, int) else token.item()
77
  char = tokenizer.decode([token_id])
78
 
79
+ # Stop tags to prevent the model from talking to itself
80
+ if any(tag in char for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>"]):
81
  break
82
 
83
  generated_text += char
84
  yield generated_text.strip()
85
 
86
  except Exception as e:
87
+ # Log the exact error to the console for QA debugging
88
+ print(f"Error details: {str(e)}")
89
  yield f"Toddler tantrum: {str(e)}"
90
 
91
  # Launching with Gradio 6.0 compatibility