david167 committed
Commit b394386 · 1 Parent(s): 2e7d584

Aggressive fix for CoT truncation: increase min_new_tokens to 1500, suppress EOS token for CoT requests, cap max_new_tokens

Files changed (1)
  1. gradio_app.py +34 -18
gradio_app.py CHANGED
@@ -106,8 +106,10 @@ def generate_response(prompt, temperature=0.8):
 
     # Set minimum tokens based on request type
     if is_cot_request:
-        min_tokens = 1000  # Much higher minimum for CoT to ensure complete responses
-        logger.info("Detected Chain of Thinking request - using min_new_tokens=1000")
+        min_tokens = 1500  # Even higher minimum for CoT to ensure complete responses
+        # Also reduce max_new_tokens to ensure we don't hit context limits
+        gen_max_new_tokens = min(gen_max_new_tokens, 2048)  # Cap at 2048 for CoT
+        logger.info(f"Detected Chain of Thinking request - using min_new_tokens={min_tokens}, max_new_tokens={gen_max_new_tokens}")
     else:
         min_tokens = 200  # Standard minimum
 
@@ -126,23 +128,37 @@ def generate_response(prompt, temperature=0.8):
 
     # Generate response
     with torch.no_grad():
-        outputs = model_manager.model.generate(
-            **inputs,
-            max_new_tokens=gen_max_new_tokens,
-            temperature=temperature,
-            top_p=0.95,
-            do_sample=True,
-            num_beams=1,
-            pad_token_id=model_manager.tokenizer.eos_token_id,
-            # Keep EOS but rely primarily on post-decode stop to capture full JSON
-            eos_token_id=model_manager.tokenizer.eos_token_id,
-            early_stopping=False,
-            repetition_penalty=1.05,
-            no_repeat_ngram_size=0,
-            length_penalty=1.0,
-            # Dynamic minimum based on request type
-            min_new_tokens=min_tokens
-        )
+        # For CoT requests, be more aggressive about preventing early stopping
+        if is_cot_request:
+            # Suppress EOS token for CoT to prevent early termination
+            eos_token_id = None
+            suppress_tokens = [model_manager.tokenizer.eos_token_id] if model_manager.tokenizer.eos_token_id is not None else None
+        else:
+            eos_token_id = model_manager.tokenizer.eos_token_id
+            suppress_tokens = None
+
+        generation_kwargs = {
+            **inputs,
+            "max_new_tokens": gen_max_new_tokens,
+            "temperature": temperature,
+            "top_p": 0.95,
+            "do_sample": True,
+            "num_beams": 1,
+            "pad_token_id": model_manager.tokenizer.eos_token_id,
+            "early_stopping": False,
+            "repetition_penalty": 1.05,
+            "no_repeat_ngram_size": 0,
+            "length_penalty": 1.0,
+            "min_new_tokens": min_tokens
+        }
+
+        # Add EOS suppression for CoT
+        if eos_token_id is not None:
+            generation_kwargs["eos_token_id"] = eos_token_id
+        if suppress_tokens is not None:
+            generation_kwargs["suppress_tokens"] = suppress_tokens
+
+        outputs = model_manager.model.generate(**generation_kwargs)
 
     # Decode the response
     generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
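
For context, the sketch below (not part of this repo) shows how the three generation knobs this commit relies on behave together in the Hugging Face transformers generate() API: min_new_tokens forces decoding to continue past early stop attempts, suppress_tokens masks the EOS id so it can never be sampled, and max_new_tokens caps the response length. The model name, prompt, and token budgets are placeholder values for illustration; the commit itself uses min_new_tokens=1500 and a 2048-token cap for CoT requests.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model; the app resolves its own model/tokenizer through model_manager.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Think step by step:", return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        do_sample=True,
        temperature=0.8,
        top_p=0.95,
        min_new_tokens=64,                         # the commit raises this to 1500 for CoT requests
        max_new_tokens=128,                        # the commit caps this at 2048 for CoT requests
        suppress_tokens=[tokenizer.eos_token_id],  # EOS logits are masked, so it is never sampled
        pad_token_id=tokenizer.eos_token_id,
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Because EOS is suppressed for CoT requests, the only stopping criterion left is max_new_tokens, which is why the commit also caps gen_max_new_tokens at 2048 to stay within the context window.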