david167 committed
Commit 342694d · 1 Parent(s): 4ad994e

Improve model generation parameters and add logging - fix response truncation issues

Files changed (1)
  1. gradio_app.py +15 -3
gradio_app.py CHANGED
@@ -224,15 +224,18 @@ def chat_with_model(message, history, temperature, json_mode=False, json_templat
     with torch.no_grad():
         outputs = model_manager.model.generate(
             **inputs,
-            max_new_tokens=4096,
+            max_new_tokens=2048,  # Reduced but sufficient for JSON responses
             temperature=temperature,
             top_p=0.95,
             do_sample=True,
             num_beams=1,
             pad_token_id=model_manager.tokenizer.eos_token_id,
             eos_token_id=model_manager.tokenizer.eos_token_id,
-            early_stopping=False,  # Disable early stopping to prevent premature truncation
-            repetition_penalty=1.1  # Add slight repetition penalty to improve quality
+            early_stopping=False,  # Disable early stopping
+            repetition_penalty=1.05,  # Lighter repetition penalty
+            no_repeat_ngram_size=0,  # Disable n-gram repetition blocking
+            length_penalty=1.0,  # Neutral length penalty
+            min_new_tokens=50  # Ensure minimum response length
         )

     # Decode response
@@ -245,9 +248,18 @@ def chat_with_model(message, history, temperature, json_mode=False, json_templat
     # Fallback: try to remove the prompt by length
     response = generated_text[len(prompt):].strip()

+    # Log response length for debugging
+    logger.info(f"Generated response length: {len(response)} characters")
+    logger.info(f"Response preview: {response[:200]}...")
+
     # Process JSON response if in JSON mode
     if json_mode and response:
+        original_response = response
         response = prettify_json_response(response)
+        if response != original_response:
+            logger.info(f"JSON processing applied. New length: {len(response)}")
+        else:
+            logger.info("JSON processing had no effect - no valid JSON found")

     # Add to history
     history.append({"role": "user", "content": message})
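
For reference, a minimal standalone sketch of the new sampling setup, assuming a Hugging Face transformers causal LM. The model name, prompt, and fixed temperature below are illustrative stand-ins (the diff does not show which model gradio_app.py loads, and temperature comes from the UI slider):

import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Illustrative small instruct model; the real model is not shown in this diff.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Return a JSON object describing the weather.", return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,      # new cap from this commit
        min_new_tokens=50,        # new floor, so replies are never near-empty
        temperature=0.7,          # stand-in for the UI slider value
        top_p=0.95,
        do_sample=True,
        num_beams=1,
        repetition_penalty=1.05,  # lighter penalty from this commit
        no_repeat_ngram_size=0,   # n-gram repetition blocking disabled
        length_penalty=1.0,       # neutral; only affects beam search
        early_stopping=False,     # likewise only meaningful for beam search
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

# Slice off the prompt tokens instead of trimming by string length.
response = tokenizer.decode(
    outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
).strip()
logger.info(f"Generated response length: {len(response)} characters")

Note that with num_beams=1, early_stopping and length_penalty are effectively no-ops in transformers; the truncation behavior is governed by max_new_tokens, min_new_tokens, and the eos_token_id stopping condition.

The "no effect" log message in the second hunk only works if prettify_json_response returns its input unchanged when the response does not parse as JSON. The helper's implementation is not part of this diff; a hypothetical sketch consistent with that logging:

import json

def prettify_json_response(response: str) -> str:
    # Hypothetical helper: re-indent the response if it parses as JSON,
    # otherwise return it unchanged so the caller's equality check can
    # report that no valid JSON was found.
    try:
        return json.dumps(json.loads(response), indent=2, ensure_ascii=False)
    except json.JSONDecodeError:
        return response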