david167 committed on
Commit
14f445d
·
1 Parent(s): b394386

MAXIMUM TOKEN SETTINGS: Use 131k context, 16k max_new_tokens, 2k min_tokens for CoT - eliminate all truncation

Browse files
Files changed (1) hide show
  1. gradio_app.py +34 -49
gradio_app.py CHANGED
@@ -84,34 +84,33 @@ def generate_response(prompt, temperature=0.8):
84
 
85
  """
86
 
87
- # Determine context window and allocate space for input vs. generation
88
  try:
89
- max_ctx = getattr(model_manager.model.config, "max_position_embeddings", 8192)
90
  except Exception:
91
- max_ctx = 8192
92
-
93
- # Reserve room for generation; cap to half the context as a safety default
94
- safe_max_new = min(8192, max(max_ctx // 2, 256))
95
- # If caller requested temperature, keep; we control new tokens internally
96
- gen_max_new_tokens = min(safe_max_new, 8192)
97
-
98
- # Allowed input tokens is context minus generation budget and a small buffer
99
- allowed_input_tokens = max(512, max_ctx - gen_max_new_tokens - 64)
100
 
101
- # Detect if this is a Chain of Thinking request and adjust min_new_tokens
 
 
102
  is_cot_request = ("chain-of-thinking" in prompt.lower() or
103
  "chain of thinking" in prompt.lower() or
104
  "Return exactly this JSON array" in prompt or
105
  ("verbatim" in prompt.lower() and "json array" in prompt.lower()))
106
 
107
- # Set minimum tokens based on request type
108
  if is_cot_request:
109
- min_tokens = 1500 # Even higher minimum for CoT to ensure complete responses
110
- # Also reduce max_new_tokens to ensure we don't hit context limits
111
- gen_max_new_tokens = min(gen_max_new_tokens, 2048) # Cap at 2048 for CoT
112
- logger.info(f"Detected Chain of Thinking request - using min_new_tokens={min_tokens}, max_new_tokens={gen_max_new_tokens}")
 
 
113
  else:
114
- min_tokens = 200 # Standard minimum
 
 
 
115
 
116
  # Tokenize the input with safe truncation
117
  inputs = model_manager.tokenizer(
@@ -126,39 +125,25 @@ def generate_response(prompt, temperature=0.8):
126
  model_device = next(model_manager.model.parameters()).device
127
  inputs = {k: v.to(model_device) for k, v in inputs.items()}
128
 
129
- # Generate response
130
  with torch.no_grad():
131
- # For CoT requests, be more aggressive about preventing early stopping
132
- if is_cot_request:
133
- # Suppress EOS token for CoT to prevent early termination
134
- eos_token_id = None
135
- suppress_tokens = [model_manager.tokenizer.eos_token_id] if model_manager.tokenizer.eos_token_id is not None else None
136
- else:
137
- eos_token_id = model_manager.tokenizer.eos_token_id
138
- suppress_tokens = None
139
-
140
- generation_kwargs = {
141
  **inputs,
142
- "max_new_tokens": gen_max_new_tokens,
143
- "temperature": temperature,
144
- "top_p": 0.95,
145
- "do_sample": True,
146
- "num_beams": 1,
147
- "pad_token_id": model_manager.tokenizer.eos_token_id,
148
- "early_stopping": False,
149
- "repetition_penalty": 1.05,
150
- "no_repeat_ngram_size": 0,
151
- "length_penalty": 1.0,
152
- "min_new_tokens": min_tokens
153
- }
154
-
155
- # Add EOS suppression for CoT
156
- if eos_token_id is not None:
157
- generation_kwargs["eos_token_id"] = eos_token_id
158
- if suppress_tokens is not None:
159
- generation_kwargs["suppress_tokens"] = suppress_tokens
160
-
161
- outputs = model_manager.model.generate(**generation_kwargs)
162
 
163
  # Decode the response
164
  generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
84
 
85
  """
86
 
87
+ # Determine context window and USE ABSOLUTE MAXIMUM
88
  try:
89
+ max_ctx = getattr(model_manager.model.config, "max_position_embeddings", 131072) # Llama 3.1 supports up to 131k
90
  except Exception:
91
+ max_ctx = 131072 # Use maximum possible
 
 
 
 
 
 
 
 
92
 
93
+ logger.info(f"Model max context: {max_ctx} tokens")
94
+
95
+ # Detect if this is a Chain of Thinking request
96
  is_cot_request = ("chain-of-thinking" in prompt.lower() or
97
  "chain of thinking" in prompt.lower() or
98
  "Return exactly this JSON array" in prompt or
99
  ("verbatim" in prompt.lower() and "json array" in prompt.lower()))
100
 
101
+ # MAXIMIZE GENERATION TOKENS - use most of context for generation
102
  if is_cot_request:
103
+ # For CoT, use MAXIMUM possible generation tokens
104
+ gen_max_new_tokens = 16384 # Very high limit for complete responses
105
+ min_tokens = 2000 # High minimum to force complete generation
106
+ # Allow most of context for input
107
+ allowed_input_tokens = max_ctx - gen_max_new_tokens - 100 # Small safety buffer
108
+ logger.info(f"CoT REQUEST - MAXIMIZED: min_tokens={min_tokens}, max_new_tokens={gen_max_new_tokens}, input_limit={allowed_input_tokens}")
109
  else:
110
+ # Standard requests
111
+ gen_max_new_tokens = 8192
112
+ min_tokens = 200
113
+ allowed_input_tokens = max_ctx - gen_max_new_tokens - 100
114
 
115
  # Tokenize the input with safe truncation
116
  inputs = model_manager.tokenizer(
 
125
  model_device = next(model_manager.model.parameters()).device
126
  inputs = {k: v.to(model_device) for k, v in inputs.items()}
127
 
128
+ # Generate response with MAXIMUM settings
129
  with torch.no_grad():
130
+ outputs = model_manager.model.generate(
 
 
 
 
 
 
 
 
 
131
  **inputs,
132
+ max_new_tokens=gen_max_new_tokens,
133
+ min_new_tokens=min_tokens,
134
+ temperature=temperature,
135
+ top_p=0.95,
136
+ do_sample=True,
137
+ num_beams=1,
138
+ pad_token_id=model_manager.tokenizer.eos_token_id,
139
+ eos_token_id=model_manager.tokenizer.eos_token_id,
140
+ early_stopping=False, # Never stop early
141
+ repetition_penalty=1.05,
142
+ no_repeat_ngram_size=0,
143
+ length_penalty=1.0,
144
+ # Force generation to continue
145
+ use_cache=True
146
+ )
 
 
 
 
 
147
 
148
  # Decode the response
149
  generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)