david167 committed on
Commit
04a4f80
·
1 Parent(s): 678e0f9

Improve Chain of Thinking support: increase min_new_tokens to 500 for CoT requests, improve JSON bracket tracking for nested objects

Browse files
Files changed (1) hide show
  1. gradio_app.py +87 -7
gradio_app.py CHANGED
@@ -84,12 +84,39 @@ def generate_response(prompt, temperature=0.8):
84
 
85
  """
86
 
87
- # Tokenize the input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  inputs = model_manager.tokenizer(
89
- formatted_prompt,
90
- return_tensors="pt",
91
- truncation=True,
92
- max_length=4096
93
  )
94
 
95
  # Move inputs to the same device as the model
@@ -101,23 +128,76 @@ def generate_response(prompt, temperature=0.8):
101
  with torch.no_grad():
102
  outputs = model_manager.model.generate(
103
  **inputs,
104
- max_new_tokens=8192,
105
  temperature=temperature,
106
  top_p=0.95,
107
  do_sample=True,
108
  num_beams=1,
109
  pad_token_id=model_manager.tokenizer.eos_token_id,
 
110
  eos_token_id=model_manager.tokenizer.eos_token_id,
111
  early_stopping=False,
112
  repetition_penalty=1.05,
113
  no_repeat_ngram_size=0,
114
  length_penalty=1.0,
115
- min_new_tokens=50
 
116
  )
117
 
118
  # Decode the response
119
  generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  # Extract just the assistant's response
122
  if "<|start_header_id|>assistant<|end_header_id|>" in generated_text:
123
  response = generated_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1].strip()
 
84
 
85
  """
86
 
87
+ # Determine context window and allocate space for input vs. generation
88
+ try:
89
+ max_ctx = getattr(model_manager.model.config, "max_position_embeddings", 8192)
90
+ except Exception:
91
+ max_ctx = 8192
92
+
93
+ # Reserve room for generation; cap to half the context as a safety default
94
+ safe_max_new = min(8192, max(max_ctx // 2, 256))
95
+ # If caller requested temperature, keep; we control new tokens internally
96
+ gen_max_new_tokens = min(safe_max_new, 8192)
97
+
98
+ # Allowed input tokens is context minus generation budget and a small buffer
99
+ allowed_input_tokens = max(512, max_ctx - gen_max_new_tokens - 64)
100
+
101
+ # Detect if this is a Chain of Thinking request and adjust min_new_tokens
102
+ is_cot_request = ("chain-of-thinking" in prompt.lower() or
103
+ "chain of thinking" in prompt.lower() or
104
+ "Return exactly this JSON array" in prompt or
105
+ ("verbatim" in prompt.lower() and "json array" in prompt.lower()))
106
+
107
+ # Set minimum tokens based on request type
108
+ if is_cot_request:
109
+ min_tokens = 500 # Higher minimum for CoT to ensure complete responses
110
+ logger.info("Detected Chain of Thinking request - using min_new_tokens=500")
111
+ else:
112
+ min_tokens = 200 # Standard minimum
113
+
114
+ # Tokenize the input with safe truncation
115
  inputs = model_manager.tokenizer(
116
+ formatted_prompt,
117
+ return_tensors="pt",
118
+ truncation=True,
119
+ max_length=allowed_input_tokens
120
  )
121
 
122
  # Move inputs to the same device as the model
 
128
  with torch.no_grad():
129
  outputs = model_manager.model.generate(
130
  **inputs,
131
+ max_new_tokens=gen_max_new_tokens,
132
  temperature=temperature,
133
  top_p=0.95,
134
  do_sample=True,
135
  num_beams=1,
136
  pad_token_id=model_manager.tokenizer.eos_token_id,
137
+ # Keep EOS but rely primarily on post-decode stop to capture full JSON
138
  eos_token_id=model_manager.tokenizer.eos_token_id,
139
  early_stopping=False,
140
  repetition_penalty=1.05,
141
  no_repeat_ngram_size=0,
142
  length_penalty=1.0,
143
+ # Dynamic minimum based on request type
144
+ min_new_tokens=min_tokens
145
  )
146
 
147
  # Decode the response
148
  generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
149
 
150
+ # Post-decode guard: if a top-level JSON array closes, trim to the first full array
151
+ # This helps prevent trailing prose like 'assistant' or 'Message'.
152
+ try:
153
+ # Track both bracket and brace depth to find first complete JSON structure
154
+ bracket_depth = 0 # [ ]
155
+ brace_depth = 0 # { }
156
+ in_string = False
157
+ escape_next = False
158
+ start_idx = None
159
+ end_idx = None
160
+
161
+ for i, ch in enumerate(generated_text):
162
+ # Handle string escaping
163
+ if escape_next:
164
+ escape_next = False
165
+ continue
166
+
167
+ if ch == '\\':
168
+ escape_next = True
169
+ continue
170
+
171
+ # Track if we're inside a string
172
+ if ch == '"' and not escape_next:
173
+ in_string = not in_string
174
+ continue
175
+
176
+ # Only count brackets/braces outside of strings
177
+ if not in_string:
178
+ if ch == '[':
179
+ if bracket_depth == 0 and brace_depth == 0 and start_idx is None:
180
+ start_idx = i
181
+ bracket_depth += 1
182
+ elif ch == ']':
183
+ bracket_depth = max(0, bracket_depth - 1)
184
+ if bracket_depth == 0 and brace_depth == 0 and start_idx is not None:
185
+ end_idx = i
186
+ break
187
+ elif ch == '{':
188
+ brace_depth += 1
189
+ elif ch == '}':
190
+ brace_depth = max(0, brace_depth - 1)
191
+
192
+ if start_idx is not None and end_idx is not None and end_idx > start_idx:
193
+ # Extract just the complete JSON array
194
+ json_text = generated_text[start_idx:end_idx+1]
195
+ logger.info(f"Extracted complete JSON array of length {len(json_text)}")
196
+ generated_text = json_text
197
+ except Exception as e:
198
+ logger.warning(f"Error in JSON extraction: {e}")
199
+ pass
200
+
201
  # Extract just the assistant's response
202
  if "<|start_header_id|>assistant<|end_header_id|>" in generated_text:
203
  response = generated_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1].strip()