Aggressive fix for CoT truncation: increase min_new_tokens to 1500, suppress EOS token for CoT requests, cap max_new_tokens
gradio_app.py +34 -18
gradio_app.py
CHANGED
@@ -106,8 +106,10 @@ def generate_response(prompt, temperature=0.8):
 
     # Set minimum tokens based on request type
     if is_cot_request:
-        min_tokens = …
-        …
+        min_tokens = 1500  # Even higher minimum for CoT to ensure complete responses
+        # Also reduce max_new_tokens to ensure we don't hit context limits
+        gen_max_new_tokens = min(gen_max_new_tokens, 2048)  # Cap at 2048 for CoT
+        logger.info(f"Detected Chain of Thinking request - using min_new_tokens={min_tokens}, max_new_tokens={gen_max_new_tokens}")
     else:
         min_tokens = 200  # Standard minimum
 
@@ -126,23 +128,37 @@ def generate_response(prompt, temperature=0.8):
 
     # Generate response
     with torch.no_grad():
-        outputs = model_manager.model.generate(
+        # For CoT requests, be more aggressive about preventing early stopping
+        if is_cot_request:
+            # Suppress EOS token for CoT to prevent early termination
+            eos_token_id = None
+            suppress_tokens = [model_manager.tokenizer.eos_token_id] if model_manager.tokenizer.eos_token_id is not None else None
+        else:
+            eos_token_id = model_manager.tokenizer.eos_token_id
+            suppress_tokens = None
+
+        generation_kwargs = {
             **inputs,
-            max_new_tokens=…,
-            temperature=…,
-            top_p=…,
-            do_sample=…,
-            num_beams=…,
-            pad_token_id=…,
-            …
+            "max_new_tokens": gen_max_new_tokens,
+            "temperature": temperature,
+            "top_p": 0.95,
+            "do_sample": True,
+            "num_beams": 1,
+            "pad_token_id": model_manager.tokenizer.eos_token_id,
+            "early_stopping": False,
+            "repetition_penalty": 1.05,
+            "no_repeat_ngram_size": 0,
+            "length_penalty": 1.0,
+            "min_new_tokens": min_tokens
+        }
+
+        # Add EOS suppression for CoT
+        if eos_token_id is not None:
+            generation_kwargs["eos_token_id"] = eos_token_id
+        if suppress_tokens is not None:
+            generation_kwargs["suppress_tokens"] = suppress_tokens
+
+        outputs = model_manager.model.generate(**generation_kwargs)
 
     # Decode the response
     generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
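For reference, the two generation knobs this commit leans on (min_new_tokens and suppress_tokens) can be exercised outside the app with the stock transformers generate() API. The sketch below is a minimal standalone example, not the Space's code: the tiny stand-in model and the prompt are assumptions made purely for illustration.

# Minimal sketch of the anti-truncation technique, under assumptions:
# "sshleifer/tiny-gpt2" and the prompt are placeholders, not the Space's model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tokenizer("Let's think step by step:", return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        do_sample=True,
        temperature=0.8,
        top_p=0.95,
        num_beams=1,
        max_new_tokens=64,   # scaled down from the commit's 2048 cap
        min_new_tokens=32,   # masks out EOS until 32 new tokens exist
        suppress_tokens=[tokenizer.eos_token_id],  # EOS logit forced to -inf every step
        pad_token_id=tokenizer.eos_token_id,
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Note that early_stopping only affects beam search, so with num_beams=1 it is a no-op; the effective guard against early termination here is the suppressed EOS logit. Suppressing EOS also means a CoT request always generates the full max_new_tokens, which is why the commit caps it at 2048 to avoid hitting context limits.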