Spaces:
Sleeping
Sleeping
MAXIMUM TOKEN SETTINGS: Use 131k context, 16k max_new_tokens, 2k min_tokens for CoT - eliminate all truncation
Browse files- gradio_app.py +34 -49
gradio_app.py
CHANGED
|
@@ -84,34 +84,33 @@ def generate_response(prompt, temperature=0.8):
|
|
| 84 |
|
| 85 |
"""
|
| 86 |
|
| 87 |
-
# Determine context window and
|
| 88 |
try:
|
| 89 |
-
max_ctx = getattr(model_manager.model.config, "max_position_embeddings",
|
| 90 |
except Exception:
|
| 91 |
-
max_ctx =
|
| 92 |
-
|
| 93 |
-
# Reserve room for generation; cap to half the context as a safety default
|
| 94 |
-
safe_max_new = min(8192, max(max_ctx // 2, 256))
|
| 95 |
-
# If caller requested temperature, keep; we control new tokens internally
|
| 96 |
-
gen_max_new_tokens = min(safe_max_new, 8192)
|
| 97 |
-
|
| 98 |
-
# Allowed input tokens is context minus generation budget and a small buffer
|
| 99 |
-
allowed_input_tokens = max(512, max_ctx - gen_max_new_tokens - 64)
|
| 100 |
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
is_cot_request = ("chain-of-thinking" in prompt.lower() or
|
| 103 |
"chain of thinking" in prompt.lower() or
|
| 104 |
"Return exactly this JSON array" in prompt or
|
| 105 |
("verbatim" in prompt.lower() and "json array" in prompt.lower()))
|
| 106 |
|
| 107 |
-
#
|
| 108 |
if is_cot_request:
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
| 113 |
else:
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
# Tokenize the input with safe truncation
|
| 117 |
inputs = model_manager.tokenizer(
|
|
@@ -126,39 +125,25 @@ def generate_response(prompt, temperature=0.8):
|
|
| 126 |
model_device = next(model_manager.model.parameters()).device
|
| 127 |
inputs = {k: v.to(model_device) for k, v in inputs.items()}
|
| 128 |
|
| 129 |
-
# Generate response
|
| 130 |
with torch.no_grad():
|
| 131 |
-
|
| 132 |
-
if is_cot_request:
|
| 133 |
-
# Suppress EOS token for CoT to prevent early termination
|
| 134 |
-
eos_token_id = None
|
| 135 |
-
suppress_tokens = [model_manager.tokenizer.eos_token_id] if model_manager.tokenizer.eos_token_id is not None else None
|
| 136 |
-
else:
|
| 137 |
-
eos_token_id = model_manager.tokenizer.eos_token_id
|
| 138 |
-
suppress_tokens = None
|
| 139 |
-
|
| 140 |
-
generation_kwargs = {
|
| 141 |
**inputs,
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
generation_kwargs["eos_token_id"] = eos_token_id
|
| 158 |
-
if suppress_tokens is not None:
|
| 159 |
-
generation_kwargs["suppress_tokens"] = suppress_tokens
|
| 160 |
-
|
| 161 |
-
outputs = model_manager.model.generate(**generation_kwargs)
|
| 162 |
|
| 163 |
# Decode the response
|
| 164 |
generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
| 84 |
|
| 85 |
"""
|
| 86 |
|
| 87 |
+
# Determine context window and USE ABSOLUTE MAXIMUM
|
| 88 |
try:
|
| 89 |
+
max_ctx = getattr(model_manager.model.config, "max_position_embeddings", 131072) # Llama 3.1 supports up to 131k
|
| 90 |
except Exception:
|
| 91 |
+
max_ctx = 131072 # Use maximum possible
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
+
logger.info(f"Model max context: {max_ctx} tokens")
|
| 94 |
+
|
| 95 |
+
# Detect if this is a Chain of Thinking request
|
| 96 |
is_cot_request = ("chain-of-thinking" in prompt.lower() or
|
| 97 |
"chain of thinking" in prompt.lower() or
|
| 98 |
"Return exactly this JSON array" in prompt or
|
| 99 |
("verbatim" in prompt.lower() and "json array" in prompt.lower()))
|
| 100 |
|
| 101 |
+
# MAXIMIZE GENERATION TOKENS - use most of context for generation
|
| 102 |
if is_cot_request:
|
| 103 |
+
# For CoT, use MAXIMUM possible generation tokens
|
| 104 |
+
gen_max_new_tokens = 16384 # Very high limit for complete responses
|
| 105 |
+
min_tokens = 2000 # Passed as min_new_tokens: EOS is blocked until 2000 new tokens exist; this does not by itself guarantee a complete answer
|
| 106 |
+
# Allow most of context for input
|
| 107 |
+
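allowed_input_tokens = max_ctx - gen_max_new_tokens - 100 # 100-token safety buffer. NOTE(review): goes negative when max_ctx < 16484; the previous code clamped this with max(512, ...)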
|
| 108 |
+
logger.info(f"CoT REQUEST - MAXIMIZED: min_tokens={min_tokens}, max_new_tokens={gen_max_new_tokens}, input_limit={allowed_input_tokens}")
|
| 109 |
else:
|
| 110 |
+
# Standard requests
|
| 111 |
+
gen_max_new_tokens = 8192
|
| 112 |
+
min_tokens = 200
|
| 113 |
+
allowed_input_tokens = max_ctx - gen_max_new_tokens - 100
|
| 114 |
|
| 115 |
# Tokenize the input with safe truncation
|
| 116 |
inputs = model_manager.tokenizer(
|
|
|
|
| 125 |
model_device = next(model_manager.model.parameters()).device
|
| 126 |
inputs = {k: v.to(model_device) for k, v in inputs.items()}
|
| 127 |
|
| 128 |
+
# Generate response with MAXIMUM settings
|
| 129 |
with torch.no_grad():
|
| 130 |
+
outputs = model_manager.model.generate(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
**inputs,
|
| 132 |
+
max_new_tokens=gen_max_new_tokens,
|
| 133 |
+
min_new_tokens=min_tokens,
|
| 134 |
+
temperature=temperature,
|
| 135 |
+
top_p=0.95,
|
| 136 |
+
do_sample=True,
|
| 137 |
+
num_beams=1,
|
| 138 |
+
pad_token_id=model_manager.tokenizer.eos_token_id,
|
| 139 |
+
eos_token_id=model_manager.tokenizer.eos_token_id,
|
| 140 |
+
early_stopping=False, # Only controls beam-search stopping; has no effect with num_beams=1
|
| 141 |
+
repetition_penalty=1.05,
|
| 142 |
+
no_repeat_ngram_size=0,
|
| 143 |
+
length_penalty=1.0,
|
| 144 |
+
# Reuse cached key/values to speed up decoding (does not affect when generation stops)
|
| 145 |
+
use_cache=True
|
| 146 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
# Decode the response
|
| 149 |
generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
|