david167 committed on
Commit
4185c2a
·
1 Parent(s): 0d85e38

Fix response truncation: disable early stopping, increase token limits to 4096, add debugging logs

Browse files
Files changed (1) hide show
  1. gradio_app.py +10 -3
gradio_app.py CHANGED
@@ -100,7 +100,7 @@ def chat_with_model(message, history, temperature):
100
  """
101
 
102
  # Generate response using the model directly
103
- inputs = model_manager.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
104
 
105
  # Force all inputs to the same device as the model
106
  if model_manager.device == "cuda:0":
@@ -110,13 +110,15 @@ def chat_with_model(message, history, temperature):
110
  with torch.no_grad():
111
  outputs = model_manager.model.generate(
112
  **inputs,
113
- max_new_tokens=2048,
114
  temperature=temperature,
115
  top_p=0.95,
116
  do_sample=True,
117
  num_beams=1,
118
  pad_token_id=model_manager.tokenizer.eos_token_id,
119
- early_stopping=True
 
 
120
  )
121
 
122
  # Decode the generated text and remove the input prompt
@@ -128,6 +130,7 @@ def chat_with_model(message, history, temperature):
128
  # Find the position after the assistant header
129
  response_start = full_text.find(assistant_start) + len(assistant_start)
130
  response = full_text[response_start:].strip()
 
131
  else:
132
  # Fallback: try to remove the original prompt
133
  try:
@@ -135,6 +138,10 @@ def chat_with_model(message, history, temperature):
135
  except:
136
  response = full_text.strip()
137
 
 
 
 
 
138
  if not response:
139
  response = "I couldn't generate a response. Please try a different prompt."
140
 
 
100
  """
101
 
102
  # Generate response using the model directly
103
+ inputs = model_manager.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
104
 
105
  # Force all inputs to the same device as the model
106
  if model_manager.device == "cuda:0":
 
110
  with torch.no_grad():
111
  outputs = model_manager.model.generate(
112
  **inputs,
113
+ max_new_tokens=4096,
114
  temperature=temperature,
115
  top_p=0.95,
116
  do_sample=True,
117
  num_beams=1,
118
  pad_token_id=model_manager.tokenizer.eos_token_id,
119
+ eos_token_id=model_manager.tokenizer.eos_token_id,
120
+ early_stopping=False, # Disable early stopping to prevent premature truncation
121
+ repetition_penalty=1.1 # Add slight repetition penalty to improve quality
122
  )
123
 
124
  # Decode the generated text and remove the input prompt
 
130
  # Find the position after the assistant header
131
  response_start = full_text.find(assistant_start) + len(assistant_start)
132
  response = full_text[response_start:].strip()
133
+ logger.info(f"Extracted response length: {len(response)}")
134
  else:
135
  # Fallback: try to remove the original prompt
136
  try:
 
138
  except:
139
  response = full_text.strip()
140
 
141
+ # Check if response ends abruptly (might indicate truncation)
142
+ if response and not response.endswith(('.', '!', '?', ':', ';')):
143
+ logger.warning(f"Response may be truncated - ends with: '{response[-20:]}'")
144
+
145
  if not response:
146
  response = "I couldn't generate a response. Please try a different prompt."
147