david167 committed on
Commit
2e7d584
·
1 Parent(s): 04a4f80

Fix CoT truncation: increase min_new_tokens to 1000, add generation logging, improve truncated JSON handling

Browse files
Files changed (1) hide show
  1. gradio_app.py +16 -2
gradio_app.py CHANGED
@@ -106,8 +106,8 @@ def generate_response(prompt, temperature=0.8):
106
 
107
  # Set minimum tokens based on request type
108
  if is_cot_request:
109
- min_tokens = 500 # Higher minimum for CoT to ensure complete responses
110
- logger.info("Detected Chain of Thinking request - using min_new_tokens=500")
111
  else:
112
  min_tokens = 200 # Standard minimum
113
 
@@ -147,6 +147,15 @@ def generate_response(prompt, temperature=0.8):
147
  # Decode the response
148
  generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
149
 
 
 
 
 
 
 
 
 
 
150
  # Post-decode guard: if a top-level JSON array closes, trim to the first full array
151
  # This helps prevent trailing prose like 'assistant' or 'Message'.
152
  try:
@@ -194,6 +203,11 @@ def generate_response(prompt, temperature=0.8):
194
  json_text = generated_text[start_idx:end_idx+1]
195
  logger.info(f"Extracted complete JSON array of length {len(json_text)}")
196
  generated_text = json_text
 
 
 
 
 
197
  except Exception as e:
198
  logger.warning(f"Error in JSON extraction: {e}")
199
  pass
 
106
 
107
  # Set minimum tokens based on request type
108
  if is_cot_request:
109
+ min_tokens = 1000 # Much higher minimum for CoT to ensure complete responses
110
+ logger.info("Detected Chain of Thinking request - using min_new_tokens=1000")
111
  else:
112
  min_tokens = 200 # Standard minimum
113
 
 
147
  # Decode the response
148
  generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
149
 
150
+ # Log generation details for debugging
151
+ input_length = inputs['input_ids'].shape[1]
152
+ output_length = outputs[0].shape[0]
153
+ generated_length = output_length - input_length
154
+ logger.info(f"Generation stats - Input: {input_length} tokens, Generated: {generated_length} tokens, Min required: {min_tokens}")
155
+
156
+ if generated_length < min_tokens:
157
+ logger.warning(f"Generated {generated_length} tokens but minimum was {min_tokens} - response may be truncated")
158
+
159
  # Post-decode guard: if a top-level JSON array closes, trim to the first full array
160
  # This helps prevent trailing prose like 'assistant' or 'Message'.
161
  try:
 
203
  json_text = generated_text[start_idx:end_idx+1]
204
  logger.info(f"Extracted complete JSON array of length {len(json_text)}")
205
  generated_text = json_text
206
+ elif start_idx is not None:
207
+ # Found start but no end - response was truncated
208
+ logger.warning("JSON array started but never closed - response truncated")
209
+ # Try to extract what we have and let the client handle it
210
+ generated_text = generated_text[start_idx:]
211
  except Exception as e:
212
  logger.warning(f"Error in JSON extraction: {e}")
213
  pass