ivxivx commited on
Commit
2dca75f
·
unverified ·
1 Parent(s): cb75d9f

chore: parse assistant message

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -77,20 +77,24 @@ def predict(message, history):
77
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
78
  # 3. Generate response
79
  outputs = model.generate(**inputs, max_new_tokens=100)
80
- decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
81
 
82
- # print(f"Response: {response}, outputs: {outputs}")
 
83
 
84
  # Extract only the assistant's message (after the last user message)
85
- # This works for most chat templates that append the assistant's reply at the end
86
- if "<|im_start|>assistant" in decoded:
 
 
 
87
  response = decoded.split("<|im_start|>assistant")[-1]
88
- # Remove possible end tokens or markers
89
  response = response.replace("<|im_end|>", "").strip()
90
  else:
91
  # Fallback: just return the decoded output
92
  response = decoded.strip()
93
-
94
  return response
95
 
96
  demo = gr.ChatInterface(predict, type="messages", examples=examples)
 
77
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
78
  # 3. Generate response
79
  outputs = model.generate(**inputs, max_new_tokens=100)
80
+ # skip_special_tokens=False: we want to keep special tokens for easier parsing
81
+ decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
82
 
83
+ # print(f"decoded: {decoded}\n")
84
+ # print(f"outputs: {outputs}\n")
85
 
86
  # Extract only the assistant's message (after the last user message)
87
+ if "<|start_header_id|>assistant<|end_header_id|>" in decoded:
88
+ response = decoded.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
89
+ response = response.replace("<|eot_id|>", "").strip()
90
+ elif "<|im_start|>assistant" in decoded:
91
+ # This works for most chat templates that append the assistant's reply at the end
92
  response = decoded.split("<|im_start|>assistant")[-1]
 
93
  response = response.replace("<|im_end|>", "").strip()
94
  else:
95
  # Fallback: just return the decoded output
96
  response = decoded.strip()
97
+
98
  return response
99
 
100
  demo = gr.ChatInterface(predict, type="messages", examples=examples)