timbmg committed
Commit 384bb2f · unverified · 1 Parent(s): 97025d7

Refactor prompt construction and token management in the LLM inference process, enhancing context handling and logging for code prompts

Files changed (3):
  1. app.py +112 -195
  2. core/code_loader_demo.py +23 -6
  3. core/llm_demo.py +4 -1
app.py CHANGED
@@ -34,6 +34,10 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
+# Constants
+CONTEXT_BUFFER_FACTOR = 0.9
+MAX_CONTEXT_SIZE = 131072 # Default max context
+
 # Page configuration
 st.set_page_config(
     page_title="SciCoQA Paper- Code Discrepancy Detection",
@@ -43,14 +47,6 @@ st.set_page_config(
 )
 
 
-# Constants
-MAX_CONTEXT_SIZE = 131072 # Default max context
-MAX_TOKENS_BUFFER = 0.9 # Initial buffer (existing)
-MIN_TOKENS_BUFFER = 0.5 # Minimum buffer before giving up
-BUFFER_REDUCTION_STEP = 0.05 # How much to reduce each retry (5%)
-MAX_BUFFER_RETRIES = 10 # Maximum retry attempts
-
-
 def _redact_secrets(text: str, secrets: list[str | None]) -> str:
     """Best-effort redaction for secrets that may appear in exception strings/logs."""
     redacted = text
@@ -71,37 +67,17 @@ def _safe_model_config_for_session(model_config: dict | None) -> dict | None:
     return safe
 
 
-def _is_context_length_error(error_msg: str) -> bool:
-    """
-    Check if an error message indicates a context length error.
-
-    Args:
-        error_msg: The error message string
-
-    Returns:
-        True if it's a context length error, False otherwise
-    """
-    error_lower = error_msg.lower()
-    return (
-        "maximum context length" in error_lower
-        or "requested about" in error_lower
-        or ("context length is" in error_lower and "you requested" in error_lower)
-    )
-
-
-def _build_prompt_with_buffer(
-    buffer_factor: float,
+def _build_prompt(
     paper_text: str,
     code_loader: CodeLoader | None,
     code_text: str | None,
     model_config: dict,
     token_counter: TokenCounter,
-) -> tuple[str, str, int, int]:
+) -> tuple[str, str, int]:
     """
-    Build prompt with a specific buffer factor.
+    Build prompt by counting tokens and truncating code until prompt + paper + code < CONTEXT_BUFFER_FACTOR * model context length.
 
     Args:
-        buffer_factor: Buffer factor to use (e.g., 0.9 for 90%)
         paper_text: The paper text
         code_loader: CodeLoader instance (if using GitHub repo)
         code_text: Raw code text (if using uploaded file)
@@ -109,21 +85,31 @@ def _build_prompt_with_buffer(
        token_counter: TokenCounter instance
 
     Returns:
-        Tuple of (final_prompt, code_prompt, final_tokens, max_tokens_for_completion)
+        Tuple of (final_prompt, code_prompt, final_tokens)
     """
     max_context = model_config["max_context"]
+    max_total_tokens = int(max_context * CONTEXT_BUFFER_FACTOR)
 
-    # Calculate tokens for paper + prompt template
+    # Build prompt template
     prompt_template = Prompt("discrepancy_generation")
-    intermediate_prompt = prompt_template(paper=paper_text, code="")
-    tokens_intermediate_prompt = token_counter(intermediate_prompt)
 
-    # Calculate remaining tokens for code using the provided buffer factor
-    max_total_tokens = int(max_context * buffer_factor)
-    remaining_code_tokens = max_total_tokens - tokens_intermediate_prompt
+    # Calculate tokens for template + paper
+    template_with_paper = prompt_template(paper=paper_text, code="")
+    tokens_template_and_paper = token_counter(template_with_paper)
+
+    # Calculate remaining tokens for code
+    remaining_code_tokens = max_total_tokens - tokens_template_and_paper
 
-    logger.info(f"Tokens in intermediate prompt: {tokens_intermediate_prompt}")
-    logger.info(f"Remaining tokens for code (buffer {buffer_factor:.1%}): {remaining_code_tokens}")
+    if remaining_code_tokens <= 0:
+        raise ValueError(
+            f"Paper text too long: {tokens_template_and_paper} tokens exceeds "
+            f"90% of context limit ({max_total_tokens} tokens)"
+        )
+
+    logger.info(
+        f"Template + paper tokens: {tokens_template_and_paper}, "
+        f"Remaining for code: {remaining_code_tokens}"
+    )
 
     # Get code prompt with token limit
     if code_loader:
@@ -136,40 +122,31 @@ def _build_prompt_with_buffer(
         # Truncate code text to fit within token limit
         code_prompt = ""
         code_tokens = 0
-        code_lines = code_text.split('\n')
-
-        for line in code_lines:
-            line_with_newline = line + '\n'
-            line_tokens = token_counter(line_with_newline)
-            if code_tokens + line_tokens > remaining_code_tokens:
-                logger.warning(f"Truncating code at {code_tokens} tokens (limit: {remaining_code_tokens})")
-                break
-            code_prompt += line_with_newline
-            code_tokens += line_tokens
+        if code_text and remaining_code_tokens > 0:
+            code_lines = code_text.split('\n')
+
+            for line in code_lines:
+                line_with_newline = line + '\n'
+                line_tokens = token_counter(line_with_newline)
+                if code_tokens + line_tokens > remaining_code_tokens:
+                    logger.warning(f"Truncating code at {code_tokens} tokens (limit: {remaining_code_tokens})")
+                    break
+                code_prompt += line_with_newline
+                code_tokens += line_tokens
 
-    # Construct final prompt
+    # Construct final prompt and verify it's within limit
     final_prompt = prompt_template(paper=paper_text, code=code_prompt)
     final_tokens = token_counter(final_prompt)
-    logger.info(f"Total tokens in final prompt: {final_tokens}")
 
-    # Calculate max_tokens for completion (respecting model's context limit)
-    # Leave some buffer for safety (use 95% of remaining context)
-    remaining_for_completion = max_context - final_tokens
-
-    if remaining_for_completion <= 0:
+    if final_tokens > max_total_tokens:
         raise ValueError(
-            f"Prompt too long: {final_tokens} tokens exceeds model's context limit of {max_context} tokens"
+            f"Final prompt too long: {final_tokens} tokens exceeds "
+            f"90% of context limit ({max_total_tokens} tokens)"
        )
 
-    # Use 95% of remaining to be safe, but ensure at least some tokens
-    max_tokens_for_completion = max(1, int(remaining_for_completion * 0.95))
-
-    logger.info(
-        f"Max context: {max_context}, Input tokens: {final_tokens}, "
-        f"Remaining: {remaining_for_completion}, Max completion tokens: {max_tokens_for_completion}"
-    )
-
-    return final_prompt, code_prompt, final_tokens, max_tokens_for_completion
+    logger.info(f"Final prompt tokens: {final_tokens} (limit: {max_total_tokens})")
+
+    return final_prompt, code_prompt, final_tokens
 
 
 def validate_urls(arxiv_url: str, github_url: str) -> tuple[bool, str]:
@@ -314,18 +291,17 @@ def process_discrepancy_detection(
             state="running",
         )
 
-    # Step 5: Calculate tokens and prepare prompt (initial build)
+    # Step 5: Calculate tokens and prepare prompt
     step_start = time.time()
     status.update(label="📝 Preparing prompt...", state="running")
 
-    # Create token counter (needed for both Step 5 and Step 6 retry loop)
+    # Create token counter
    tokenizer_name = model_config["tokenizer"]
    token_counter = TokenCounter(model=tokenizer_name)

    try:
-        # Build initial prompt with default buffer
-        final_prompt, code_prompt, final_tokens, max_tokens_for_completion = _build_prompt_with_buffer(
-            buffer_factor=MAX_TOKENS_BUFFER,
+        # Build prompt with simple token counting
+        final_prompt, code_prompt, final_tokens = _build_prompt(
            paper_text=paper_text,
            code_loader=code_loader,
            code_text=code_text,
@@ -338,7 +314,7 @@ def process_discrepancy_detection(
 
     step_time = time.time() - step_start
     step_timings["Prompt Preparation"] = step_time
-    st.write(f"✅ Prompt prepared: {step_time:.1f}s ({final_tokens:,} tokens, max output: {max_tokens_for_completion:,} tokens)")
+    st.write(f"✅ Prompt prepared: {step_time:.1f}s ({final_tokens:,} tokens)")
     status.update(
         label=f"✅ Prompt prepared ({step_time:.1f}s, {final_tokens:,} tokens)",
         state="running",
@@ -350,135 +326,76 @@ def process_discrepancy_detection(
         status.update(label="❌ Error preparing prompt", state="error")
         return results
 
-    # Step 6: Detect discrepancies with LLM (with retry on context length errors)
+    # Step 6: Detect discrepancies with LLM
     step_start = time.time()
     status.update(label="🤖\uFE0F Detecting discrepancies (this may take a while)...", state="running")
 
-    # Retry configuration
-    initial_buffer = MAX_TOKENS_BUFFER # 0.9
-    min_buffer = MIN_TOKENS_BUFFER # 0.5
-    buffer_reduction_step = BUFFER_REDUCTION_STEP # 0.05
-    max_retries = MAX_BUFFER_RETRIES # 5
-
-    current_buffer = initial_buffer
-    retry_count = 0
-    success = False
-    current_final_prompt = final_prompt
-    current_max_tokens_for_completion = max_tokens_for_completion
-
-    while not success and current_buffer >= min_buffer and retry_count < max_retries:
-        try:
-            # Rebuild prompt with current buffer (if retry, otherwise use existing)
-            if retry_count > 0:
-                status.update(
-                    label=f"🔄 Retrying with reduced buffer ({current_buffer:.1%})...",
-                    state="running"
-                )
-                st.write(f"🔄 Retrying with reduced buffer ({current_buffer:.1%})...")
-
-                # Rebuild prompt with reduced buffer
-                current_final_prompt, code_prompt, final_tokens, current_max_tokens_for_completion = _build_prompt_with_buffer(
-                    buffer_factor=current_buffer,
-                    paper_text=paper_text,
-                    code_loader=code_loader,
-                    code_text=code_text,
-                    model_config=model_config,
-                    token_counter=token_counter,
-                )
-                results["code_prompt"] = code_prompt
-                results["prompt"] = current_final_prompt
-
-            # Extract model configuration
-            model = model_config["model"]
-            api_key = model_config.get("api_key")
-            api_base = model_config.get("api_base")
-            max_context = model_config.get("max_context")
-
-            llm = LLM(
-                model=model,
-                api_key=api_key,
-                api_base=api_base,
-                temperature=1.0,
-                top_p=1.0,
-                reasoning_effort="high",
-                max_context=max_context,
-                max_tokens=current_max_tokens_for_completion, # Respect model's context limit
-            )
-
-            response = llm(current_final_prompt)
-            results["llm_response"] = response
-
-            # Extract content from response
-            choices = response.get("choices", [])
-            if not choices:
-                raise ValueError("No choices in LLM response")
-
-            content = (
-                choices[0]
-                .get("message", {})
-                .get("content", "")
-            )
-
-            if not content:
-                raise ValueError("Empty content in LLM response")
-
-            # Parse discrepancies
-            discrepancies = parse_discrepancies(content)
-            results["discrepancies"] = discrepancies
-
-            step_time = time.time() - step_start
-            step_timings["LLM Inference"] = step_time
-            total_time = sum(step_timings.values())
-
-            st.write(f"✅ LLM inference: {step_time:.1f}s")
-            st.write("---")
-            st.write(f"**Total time: {total_time:.1f}s**")
-
-            if discrepancies:
-                count = len(discrepancies)
-                discrepancy_text = "discrepancy" if count == 1 else "discrepancies"
-                status.update(
-                    label=f"✅ Complete! Found {count} {discrepancy_text} ({total_time:.1f}s total)",
-                    state="complete",
-                )
-            else:
-                status.update(
-                    label=f"✅ Complete! No discrepancies found ({total_time:.1f}s total)",
-                    state="complete",
-                )
-
-            success = True
-
-        except Exception as e:
-            error_msg = str(e)
-            api_key = model_config.get("api_key") if isinstance(model_config, dict) else None
-            redacted_error = _redact_secrets(error_msg, [api_key])
-
-            # Check if it's a context length error
-            if _is_context_length_error(error_msg) and current_buffer > min_buffer:
-                retry_count += 1
-                current_buffer -= buffer_reduction_step
-                logger.warning(
-                    f"Context length error detected. Retrying with buffer {current_buffer:.1%} "
-                    f"(attempt {retry_count}/{max_retries})"
-                )
-                continue # Retry with smaller buffer
-            else:
-                # Not a context length error, or we've exhausted retries
-                logger.error(f"Error during LLM inference: {redacted_error}")
-                results["error"] = f"Error during LLM inference: {redacted_error}"
-                status.update(label="❌ Error during inference", state="error")
-                return results
-
-    # If we exhausted retries or hit minimum buffer
-    if not success:
-        error_msg = (
-            f"Could not fit prompt within context limits after {retry_count} retries. "
-            f"Minimum buffer ({min_buffer:.1%}) reached."
-        )
-        logger.error(error_msg)
-        results["error"] = error_msg
-        status.update(label="❌ Prompt too large for model", state="error")
+    try:
+        # Extract model configuration
+        model = model_config["model"]
+        api_key = model_config.get("api_key")
+        api_base = model_config.get("api_base")
+        max_context = model_config.get("max_context")
+
+        llm = LLM(
+            model=model,
+            api_key=api_key,
+            api_base=api_base,
+            temperature=1.0,
+            top_p=1.0,
+            reasoning_effort="high",
+            max_context=max_context,
+        )
+
+        response = llm(final_prompt)
+        results["llm_response"] = response
+
+        # Extract content from response
+        choices = response.get("choices", [])
+        if not choices:
+            raise ValueError("No choices in LLM response")
+
+        content = (
+            choices[0]
+            .get("message", {})
+            .get("content", "")
+        )
+
+        if not content:
+            raise ValueError("Empty content in LLM response")
+
+        # Parse discrepancies
+        discrepancies = parse_discrepancies(content)
+        results["discrepancies"] = discrepancies
+
+        step_time = time.time() - step_start
+        step_timings["LLM Inference"] = step_time
+        total_time = sum(step_timings.values())
+
+        st.write(f"✅ LLM inference: {step_time:.1f}s")
+        st.write("---")
+        st.write(f"**Total time: {total_time:.1f}s**")
+
+        if discrepancies:
+            count = len(discrepancies)
+            discrepancy_text = "discrepancy" if count == 1 else "discrepancies"
+            status.update(
+                label=f"✅ Complete! Found {count} {discrepancy_text} ({total_time:.1f}s total)",
+                state="complete",
+            )
+        else:
+            status.update(
+                label=f"✅ Complete! No discrepancies found ({total_time:.1f}s total)",
+                state="complete",
+            )
+
+    except Exception as e:
+        error_msg = str(e)
+        api_key = model_config.get("api_key") if isinstance(model_config, dict) else None
+        redacted_error = _redact_secrets(error_msg, [api_key])
+        logger.error(f"Error during LLM inference: {redacted_error}")
+        results["error"] = f"Error during LLM inference: {redacted_error}"
+        status.update(label="❌ Error during inference", state="error")
        return results
 
    except Exception as e:
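
The net effect of the app.py change: the old retry loop that rebuilt the prompt with a progressively smaller buffer is gone, replaced by a single pass that budgets CONTEXT_BUFFER_FACTOR (90%) of the model's context window up front and truncates the code section line by line to fit. Below is a minimal illustrative sketch of that budgeting logic; it is not part of the commit, and a whitespace tokenizer stands in for the app's TokenCounter while a format string stands in for its Prompt template.

CONTEXT_BUFFER_FACTOR = 0.9  # same 90% budget the commit introduces

def count_tokens(text: str) -> int:
    # Stand-in for TokenCounter; the app counts with the model's tokenizer.
    return len(text.split())

def build_prompt(template: str, paper: str, code: str, max_context: int) -> str:
    budget = int(max_context * CONTEXT_BUFFER_FACTOR)
    remaining = budget - count_tokens(template.format(paper=paper, code=""))
    if remaining <= 0:
        raise ValueError("Paper alone exceeds the prompt budget")
    kept, used = [], 0
    for line in code.split("\n"):
        line_tokens = count_tokens(line + "\n")
        if used + line_tokens > remaining:
            break  # drop the first line that would overflow, and everything after it
        kept.append(line)
        used += line_tokens
    return template.format(paper=paper, code="\n".join(kept))

print(build_prompt("Paper:\n{paper}\n\nCode:\n{code}", "word " * 40, "x = 1\ny = 2", 60))

Because the prompt is now guaranteed to fit under the 90% cap, the commit also stops computing a max_tokens value for the completion and lets the provider's default apply (see the core/llm_demo.py change below).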
core/code_loader_demo.py CHANGED
@@ -255,6 +255,12 @@ class CodeLoader:
         """Generate code prompt with repo tree and file contents."""
         code_prompt = "Repo tree:\n" + self.get_repo_tree() + "\n\n"
         tokens = token_counter(code_prompt) if token_counter is not None else 0
+
+        if token_counter is not None and max_tokens is not None:
+            logger.info(
+                f"Building code prompt: repo tree tokens={tokens}, max_tokens={max_tokens}, "
+                f"remaining for files={max_tokens - tokens}"
+            )
 
         files_to_replace = {}
         if code_changes:
@@ -275,18 +281,29 @@ class CodeLoader:
             if token_counter is not None:
                 logger.debug(f"Adding file: {file_path}")
                 num_tokens = token_counter(code_file)
+                # Check if adding this file would exceed the limit BEFORE adding it
+                if max_tokens and (tokens + num_tokens) > max_tokens:
+                    logger.warning(
+                        f"Truncating. Max tokens reached for {self.github_url}. "
+                        f"Current tokens: {tokens}, File tokens: {num_tokens}, "
+                        f"Max tokens for code is {max_tokens}"
+                    )
+                    break
                 tokens += num_tokens
                 logger.debug(
                     f"Number of tokens in file: {num_tokens}. "
                     f"Total number of tokens in code prompt: {tokens}"
                 )
-                if max_tokens and tokens > max_tokens:
-                    logger.warning(
-                        f"Truncating. Max tokens reached for {self.github_url}. "
-                        f"Max tokens for code is {max_tokens}"
-                    )
-                    break
             code_prompt += code_file
+
+        # Log final code prompt size
+        if token_counter is not None:
+            final_code_tokens = token_counter(code_prompt)
+            logger.info(
+                f"Code prompt built: {final_code_tokens} tokens "
+                f"(max was {max_tokens if max_tokens else 'unlimited'})"
+            )
+
         return code_prompt

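The behavioral fix in this file: the budget check now runs before a file is appended, where previously the file was counted and appended first, so the code prompt could overshoot max_tokens by up to one whole file. A small sketch of the check-before-add pattern follows; the names (pack_files, count) are hypothetical, not the loader's real API.

def pack_files(files: dict[str, str], max_tokens: int, count) -> str:
    """Accumulate file bodies without ever exceeding max_tokens."""
    prompt, used = "", 0
    for path, body in files.items():
        n = count(body)
        if used + n > max_tokens:  # test BEFORE appending, as the new code does
            print(f"Truncating at {used} tokens; {path} ({n} tokens) would overflow")
            break
        prompt += body
        used += n
    return prompt

files = {"a.py": "print('a')\n", "b.py": "print('b')\n" * 100}
packed = pack_files(files, max_tokens=20, count=lambda text: len(text.split()))
print(len(packed))  # only a.py fits
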
core/llm_demo.py CHANGED
@@ -72,8 +72,11 @@ class LLM:
         kwargs = {
             "model": self.model,
             "messages": [{"role": "user", "content": prompt}],
-            "max_tokens": self.max_tokens,
         }
+
+        # Only set max_tokens if explicitly provided (let API use default if None)
+        if self.max_tokens is not None:
+            kwargs["max_tokens"] = self.max_tokens
 
         # Let LiteLLM drop unsupported params per-provider/model (e.g., GPT-5 rejecting top_p)
         if self.drop_params:
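
This change makes max_tokens genuinely optional: the key is added to the request kwargs only when a value was provided, so the provider's default completion limit applies when it is None. A minimal sketch of the conditional-kwargs pattern, with build_kwargs as an illustrative helper (and an illustrative model name), not the class's actual method:

def build_kwargs(model: str, prompt: str, max_tokens: int | None) -> dict:
    kwargs = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    # Omit the key entirely when unset, rather than passing max_tokens=None.
    if max_tokens is not None:
        kwargs["max_tokens"] = max_tokens
    return kwargs

assert "max_tokens" not in build_kwargs("gpt-4o", "hi", None)
assert build_kwargs("gpt-4o", "hi", 256)["max_tokens"] == 256

Omitting the key keeps the request identical to one that never mentioned max_tokens, which matches the commit's stated intent ("let API use default if None").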