timbmg committed on
Commit 97025d7 · unverified · 1 Parent(s): 1db96ea

Add context length error handling and buffer management for prompt construction in LLM inference process

Files changed (1)
  1. app.py +243 -128
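
The change retries LLM calls that fail with a context-length error, rebuilding the prompt with a progressively smaller share of the model's context window. A minimal standalone sketch of the back-off schedule implied by the new constants (constant values copied from the diff below; the loop here is only an illustration, not code from app.py):

# Back-off schedule implied by the new constants (values copied from the diff);
# standalone illustration, not part of app.py.
MAX_TOKENS_BUFFER = 0.9       # start by budgeting 90% of the context window for the prompt
MIN_TOKENS_BUFFER = 0.5       # give up once the budget would drop below 50%
BUFFER_REDUCTION_STEP = 0.05  # shrink by 5 percentage points per retry
MAX_BUFFER_RETRIES = 10       # hard cap on retry attempts

buffer = MAX_TOKENS_BUFFER
schedule = []
while len(schedule) < MAX_BUFFER_RETRIES and buffer - BUFFER_REDUCTION_STEP >= MIN_TOKENS_BUFFER - 1e-9:
    buffer -= BUFFER_REDUCTION_STEP
    schedule.append(round(buffer, 2))

print(schedule)  # [0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5]

Eight reductions reach the 50% floor, so the cap of 10 retries is enough to exhaust the schedule.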
app.py CHANGED
@@ -45,7 +45,10 @@ st.set_page_config(
 
 # Constants
 MAX_CONTEXT_SIZE = 131072 # Default max context
-MAX_TOKENS_BUFFER = 0.9 # Use 90% of max tokens
+MAX_TOKENS_BUFFER = 0.9 # Initial buffer (existing)
+MIN_TOKENS_BUFFER = 0.5 # Minimum buffer before giving up
+BUFFER_REDUCTION_STEP = 0.05 # How much to reduce each retry (5%)
+MAX_BUFFER_RETRIES = 10 # Maximum retry attempts
 
 
 def _redact_secrets(text: str, secrets: list[str | None]) -> str:
@@ -68,6 +71,107 @@ def _safe_model_config_for_session(model_config: dict | None) -> dict | None:
     return safe
 
 
+def _is_context_length_error(error_msg: str) -> bool:
+    """
+    Check if an error message indicates a context length error.
+
+    Args:
+        error_msg: The error message string
+
+    Returns:
+        True if it's a context length error, False otherwise
+    """
+    error_lower = error_msg.lower()
+    return (
+        "maximum context length" in error_lower
+        or "requested about" in error_lower
+        or ("context length is" in error_lower and "you requested" in error_lower)
+    )
+
+
+def _build_prompt_with_buffer(
+    buffer_factor: float,
+    paper_text: str,
+    code_loader: CodeLoader | None,
+    code_text: str | None,
+    model_config: dict,
+    token_counter: TokenCounter,
+) -> tuple[str, str, int, int]:
+    """
+    Build prompt with a specific buffer factor.
+
+    Args:
+        buffer_factor: Buffer factor to use (e.g., 0.9 for 90%)
+        paper_text: The paper text
+        code_loader: CodeLoader instance (if using GitHub repo)
+        code_text: Raw code text (if using uploaded file)
+        model_config: Model configuration dictionary
+        token_counter: TokenCounter instance
+
+    Returns:
+        Tuple of (final_prompt, code_prompt, final_tokens, max_tokens_for_completion)
+    """
+    max_context = model_config["max_context"]
+
+    # Calculate tokens for paper + prompt template
+    prompt_template = Prompt("discrepancy_generation")
+    intermediate_prompt = prompt_template(paper=paper_text, code="")
+    tokens_intermediate_prompt = token_counter(intermediate_prompt)
+
+    # Calculate remaining tokens for code using the provided buffer factor
+    max_total_tokens = int(max_context * buffer_factor)
+    remaining_code_tokens = max_total_tokens - tokens_intermediate_prompt
+
+    logger.info(f"Tokens in intermediate prompt: {tokens_intermediate_prompt}")
+    logger.info(f"Remaining tokens for code (buffer {buffer_factor:.1%}): {remaining_code_tokens}")
+
+    # Get code prompt with token limit
+    if code_loader:
+        # Use CodeLoader for GitHub repos
+        code_prompt = code_loader.get_code_prompt(
+            token_counter=token_counter,
+            max_tokens=remaining_code_tokens,
+        )
+    else:
+        # Truncate code text to fit within token limit
+        code_prompt = ""
+        code_tokens = 0
+        code_lines = code_text.split('\n')
+
+        for line in code_lines:
+            line_with_newline = line + '\n'
+            line_tokens = token_counter(line_with_newline)
+            if code_tokens + line_tokens > remaining_code_tokens:
+                logger.warning(f"Truncating code at {code_tokens} tokens (limit: {remaining_code_tokens})")
+                break
+            code_prompt += line_with_newline
+            code_tokens += line_tokens
+
+    # Construct final prompt
+    final_prompt = prompt_template(paper=paper_text, code=code_prompt)
+    final_tokens = token_counter(final_prompt)
+    logger.info(f"Total tokens in final prompt: {final_tokens}")
+
+    # Calculate max_tokens for completion (respecting model's context limit)
+    # Leave some buffer for safety (use 95% of remaining context)
+    remaining_for_completion = max_context - final_tokens
+
+    if remaining_for_completion <= 0:
+        raise ValueError(
+            f"Prompt too long: {final_tokens} tokens exceeds model's context limit of {max_context} tokens"
+        )
+
+    # Use 95% of remaining to be safe, but ensure at least some tokens
+    max_tokens_for_completion = max(1, int(remaining_for_completion * 0.95))
+
+    logger.info(
+        f"Max context: {max_context}, Input tokens: {final_tokens}, "
+        f"Remaining: {remaining_for_completion}, Max completion tokens: {max_tokens_for_completion}"
+    )
+
+    return final_prompt, code_prompt, final_tokens, max_tokens_for_completion
+
+
 def validate_urls(arxiv_url: str, github_url: str) -> tuple[bool, str]:
     """Validate input URLs."""
     if not arxiv_url:
@@ -210,76 +314,27 @@ def process_discrepancy_detection(
             state="running",
         )
 
-        # Step 5: Calculate tokens and prepare prompt
+        # Step 5: Calculate tokens and prepare prompt (initial build)
         step_start = time.time()
         status.update(label="📝 Preparing prompt...", state="running")
+
+        # Create token counter (needed for both Step 5 and Step 6 retry loop)
+        tokenizer_name = model_config["tokenizer"]
+        token_counter = TokenCounter(model=tokenizer_name)
+
         try:
-            # Use provided model config
-            tokenizer_name = model_config["tokenizer"]
-            max_context = model_config["max_context"]
-
-            token_counter = TokenCounter(model=tokenizer_name)
-
-            # Calculate tokens for paper + prompt template
-            prompt_template = Prompt("discrepancy_generation")
-            intermediate_prompt = prompt_template(paper=paper_text, code="")
-            tokens_intermediate_prompt = token_counter(intermediate_prompt)
-
-            # Calculate remaining tokens for code
-            max_total_tokens = int(max_context * MAX_TOKENS_BUFFER)
-            remaining_code_tokens = max_total_tokens - tokens_intermediate_prompt
-
-            logger.info(f"Tokens in intermediate prompt: {tokens_intermediate_prompt}")
-            logger.info(f"Remaining tokens for code: {remaining_code_tokens}")
-
-            # Get code prompt with token limit
-            if code_loader:
-                # Use CodeLoader for GitHub repos
-                code_prompt = code_loader.get_code_prompt(
-                    token_counter=token_counter,
-                    max_tokens=remaining_code_tokens,
-                )
-            else:
-                # Truncate code text to fit within token limit
-                # Simple approach: count tokens as we add content
-                code_prompt = ""
-                code_tokens = 0
-                code_lines = code_text.split('\n')
-
-                for line in code_lines:
-                    line_with_newline = line + '\n'
-                    line_tokens = token_counter(line_with_newline)
-                    if code_tokens + line_tokens > remaining_code_tokens:
-                        logger.warning(f"Truncating code at {code_tokens} tokens (limit: {remaining_code_tokens})")
-                        break
-                    code_prompt += line_with_newline
-                    code_tokens += line_tokens
+            # Build initial prompt with default buffer
+            final_prompt, code_prompt, final_tokens, max_tokens_for_completion = _build_prompt_with_buffer(
+                buffer_factor=MAX_TOKENS_BUFFER,
+                paper_text=paper_text,
+                code_loader=code_loader,
+                code_text=code_text,
+                model_config=model_config,
+                token_counter=token_counter,
+            )
 
             results["code_prompt"] = code_prompt
-
-            # Construct final prompt
-            final_prompt = prompt_template(paper=paper_text, code=code_prompt)
             results["prompt"] = final_prompt
-
-            final_tokens = token_counter(final_prompt)
-            logger.info(f"Total tokens in final prompt: {final_tokens}")
-
-            # Calculate max_tokens for completion (respecting model's context limit)
-            # Leave some buffer for safety (use 95% of remaining context)
-            max_context = model_config["max_context"]
-            remaining_for_completion = max_context - final_tokens
-
-            if remaining_for_completion <= 0:
-                error_msg = f"Prompt too long: {final_tokens} tokens exceeds model's context limit of {max_context} tokens"
-                logger.error(error_msg)
-                results["error"] = error_msg
-                status.update(label="❌ Prompt too long", state="error")
-                return results
-
-            # Use 95% of remaining to be safe, but ensure at least some tokens
-            max_tokens_for_completion = max(1, int(remaining_for_completion * 0.95))
-
-            logger.info(f"Max context: {max_context}, Input tokens: {final_tokens}, Remaining: {remaining_for_completion}, Max completion tokens: {max_tokens_for_completion}")
 
             step_time = time.time() - step_start
             step_timings["Prompt Preparation"] = step_time
@@ -295,75 +350,135 @@ def process_discrepancy_detection(
             status.update(label="❌ Error preparing prompt", state="error")
             return results
 
-        # Step 6: Detect discrepancies with LLM
+        # Step 6: Detect discrepancies with LLM (with retry on context length errors)
        step_start = time.time()
        status.update(label="🤖\uFE0F Detecting discrepancies (this may take a while)...", state="running")
-        try:
-            # Extract model configuration
-            model = model_config["model"]
-            api_key = model_config.get("api_key")
-            api_base = model_config.get("api_base")
-            max_context = model_config.get("max_context")
-
-            llm = LLM(
-                model=model,
-                api_key=api_key,
-                api_base=api_base,
-                temperature=1.0,
-                top_p=1.0,
-                reasoning_effort="high",
-                max_context=max_context,
-                max_tokens=max_tokens_for_completion, # Respect model's context limit
-            )
-
-            response = llm(final_prompt)
-            results["llm_response"] = response
-
-            # Extract content from response
-            choices = response.get("choices", [])
-            if not choices:
-                raise ValueError("No choices in LLM response")
-
-            content = (
-                choices[0]
-                .get("message", {})
-                .get("content", "")
-            )
-
-            if not content:
-                raise ValueError("Empty content in LLM response")
-
-            # Parse discrepancies
-            discrepancies = parse_discrepancies(content)
-            results["discrepancies"] = discrepancies
-
-            step_time = time.time() - step_start
-            step_timings["LLM Inference"] = step_time
-            total_time = sum(step_timings.values())
-
-            st.write(f"✅ LLM inference: {step_time:.1f}s")
-            st.write("---")
-            st.write(f"**Total time: {total_time:.1f}s**")
-
-            if discrepancies:
-                count = len(discrepancies)
-                discrepancy_text = "discrepancy" if count == 1 else "discrepancies"
-                status.update(
-                    label=f"✅ Complete! Found {count} {discrepancy_text} ({total_time:.1f}s total)",
-                    state="complete",
+
+        # Retry configuration
+        initial_buffer = MAX_TOKENS_BUFFER # 0.9
+        min_buffer = MIN_TOKENS_BUFFER # 0.5
+        buffer_reduction_step = BUFFER_REDUCTION_STEP # 0.05
+        max_retries = MAX_BUFFER_RETRIES # 10
+
+        current_buffer = initial_buffer
+        retry_count = 0
+        success = False
+        current_final_prompt = final_prompt
+        current_max_tokens_for_completion = max_tokens_for_completion
+
+        while not success and current_buffer >= min_buffer and retry_count < max_retries:
+            try:
+                # Rebuild prompt with current buffer (if retry, otherwise use existing)
+                if retry_count > 0:
+                    status.update(
+                        label=f"🔄 Retrying with reduced buffer ({current_buffer:.1%})...",
+                        state="running"
+                    )
+                    st.write(f"🔄 Retrying with reduced buffer ({current_buffer:.1%})...")
+
+                    # Rebuild prompt with reduced buffer
+                    current_final_prompt, code_prompt, final_tokens, current_max_tokens_for_completion = _build_prompt_with_buffer(
+                        buffer_factor=current_buffer,
+                        paper_text=paper_text,
+                        code_loader=code_loader,
+                        code_text=code_text,
+                        model_config=model_config,
+                        token_counter=token_counter,
+                    )
+                    results["code_prompt"] = code_prompt
+                    results["prompt"] = current_final_prompt
+
+                # Extract model configuration
+                model = model_config["model"]
+                api_key = model_config.get("api_key")
+                api_base = model_config.get("api_base")
+                max_context = model_config.get("max_context")
+
+                llm = LLM(
+                    model=model,
+                    api_key=api_key,
+                    api_base=api_base,
+                    temperature=1.0,
+                    top_p=1.0,
+                    reasoning_effort="high",
+                    max_context=max_context,
+                    max_tokens=current_max_tokens_for_completion, # Respect model's context limit
                 )
-            else:
-                status.update(
-                    label=f"✅ Complete! No discrepancies found ({total_time:.1f}s total)",
-                    state="complete",
+
+                response = llm(current_final_prompt)
+                results["llm_response"] = response
+
+                # Extract content from response
+                choices = response.get("choices", [])
+                if not choices:
+                    raise ValueError("No choices in LLM response")
+
+                content = (
+                    choices[0]
+                    .get("message", {})
+                    .get("content", "")
                 )
-
-        except Exception as e:
-            api_key = model_config.get("api_key") if isinstance(model_config, dict) else None
-            error_msg = f"Error during LLM inference: {_redact_secrets(str(e), [api_key])}"
+
+                if not content:
+                    raise ValueError("Empty content in LLM response")
+
+                # Parse discrepancies
+                discrepancies = parse_discrepancies(content)
+                results["discrepancies"] = discrepancies
+
+                step_time = time.time() - step_start
+                step_timings["LLM Inference"] = step_time
+                total_time = sum(step_timings.values())
+
+                st.write(f"✅ LLM inference: {step_time:.1f}s")
+                st.write("---")
+                st.write(f"**Total time: {total_time:.1f}s**")
+
+                if discrepancies:
+                    count = len(discrepancies)
+                    discrepancy_text = "discrepancy" if count == 1 else "discrepancies"
+                    status.update(
+                        label=f"✅ Complete! Found {count} {discrepancy_text} ({total_time:.1f}s total)",
+                        state="complete",
+                    )
+                else:
+                    status.update(
+                        label=f"✅ Complete! No discrepancies found ({total_time:.1f}s total)",
+                        state="complete",
+                    )
+
+                success = True
+
+            except Exception as e:
+                error_msg = str(e)
+                api_key = model_config.get("api_key") if isinstance(model_config, dict) else None
+                redacted_error = _redact_secrets(error_msg, [api_key])
+
+                # Check if it's a context length error
+                if _is_context_length_error(error_msg) and current_buffer > min_buffer:
+                    retry_count += 1
+                    current_buffer -= buffer_reduction_step
+                    logger.warning(
+                        f"Context length error detected. Retrying with buffer {current_buffer:.1%} "
+                        f"(attempt {retry_count}/{max_retries})"
+                    )
+                    continue # Retry with smaller buffer
+                else:
+                    # Not a context length error, or we've exhausted retries
+                    logger.error(f"Error during LLM inference: {redacted_error}")
+                    results["error"] = f"Error during LLM inference: {redacted_error}"
+                    status.update(label="❌ Error during inference", state="error")
+                    return results
+
+        # If we exhausted retries or hit minimum buffer
+        if not success:
+            error_msg = (
+                f"Could not fit prompt within context limits after {retry_count} retries. "
+                f"Minimum buffer ({min_buffer:.1%}) reached."
+            )
             logger.error(error_msg)
             results["error"] = error_msg
-            status.update(label="❌ Error during inference", state="error")
+            status.update(label="❌ Prompt too large for model", state="error")
             return results
 
     except Exception as e:
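
The retry path only triggers when _is_context_length_error matches the provider's error text; it is a case-insensitive substring check. A quick illustrative check (the sample message mimics the usual OpenAI-style wording and is hypothetical, not taken from this repository):

def _is_context_length_error(error_msg: str) -> bool:
    # Same predicate as in the diff above: case-insensitive substring matching.
    error_lower = error_msg.lower()
    return (
        "maximum context length" in error_lower
        or "requested about" in error_lower
        or ("context length is" in error_lower and "you requested" in error_lower)
    )

# Hypothetical provider message in the usual OpenAI-style wording; real messages vary by backend.
sample = (
    "This model's maximum context length is 131072 tokens. "
    "However, you requested about 140000 tokens."
)
assert _is_context_length_error(sample)
assert not _is_context_length_error("Rate limit exceeded. Please retry later.")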
 
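
Stripped of the Streamlit status updates and result bookkeeping, Step 6 reduces to the retry pattern below. This is a simplified sketch under stated assumptions: build_prompt, call_llm, and is_context_error are hypothetical stand-ins for _build_prompt_with_buffer, the LLM call, and _is_context_length_error in app.py.

from typing import Callable


def generate_with_backoff(
    build_prompt: Callable[[float], str],     # stand-in for _build_prompt_with_buffer
    call_llm: Callable[[str], str],           # stand-in for the LLM call
    is_context_error: Callable[[str], bool],  # stand-in for _is_context_length_error
    initial_buffer: float = 0.9,
    min_buffer: float = 0.5,
    step: float = 0.05,
    max_retries: int = 10,
) -> str:
    """Shrink the prompt buffer and retry while the provider reports a context overflow."""
    buffer = initial_buffer
    retries = 0
    prompt = build_prompt(buffer)
    while True:
        try:
            return call_llm(prompt)
        except Exception as exc:
            # Only context-length errors are retried, and only while there is room to shrink.
            if not (is_context_error(str(exc)) and buffer > min_buffer and retries < max_retries):
                raise
            retries += 1
            buffer -= step
            prompt = build_prompt(buffer)  # rebuild the prompt with a smaller context share

Unlike the in-app version, this sketch re-raises instead of recording results["error"], which keeps the control flow easier to follow.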