SolshineMisfit committed
Commit f3ea50e · verified · 1 Parent(s): 53ed647

Implemented a context management method that can monitor the token count and summarize or clear parts of the context when it gets too large
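
The helper added here estimates a prompt's size by splitting on whitespace and only rewrites prompts whose estimate exceeds a configurable threshold (30,000 tokens by default); anything under the limit passes through untouched. A minimal sketch of that gating check, reusing the names from the diff below (this snippet is illustrative and not part of the commit):

    def is_over_limit(prompt, max_allowed_tokens=30000):
        # Same crude whitespace-based token estimate that manage_context uses
        return len(prompt.split()) > max_allowed_tokens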

Files changed (1)
  1. app.py +118 -3
app.py CHANGED
@@ -359,12 +359,127 @@ model = HfApiModel(
 )
 
 # Add fallback logic that only activates if the primary model fails
+def manage_context(prompt, max_allowed_tokens=30000):
+    """Manages large contexts by summarizing or trimming when they get too big.
+
+    This helps avoid the 'inputs tokens + max_new_tokens must be <= 32768' error
+    by keeping the context size under control.
+
+    Args:
+        prompt: The full context/prompt that might be too large
+        max_allowed_tokens: Maximum number of tokens to allow before trimming
+
+    Returns:
+        A potentially shortened/summarized version of the prompt
+    """
+    # Rough token estimation (splitting on spaces is a crude approximation)
+    estimated_tokens = len(prompt.split())
+
+    # If below threshold, return as is
+    if estimated_tokens <= max_allowed_tokens:
+        return prompt
+
+    print(f"WARNING: Context size ({estimated_tokens} estimated tokens) exceeds limit ({max_allowed_tokens})")
+
+    # For extremely large prompts, we need more aggressive handling
+    if estimated_tokens > max_allowed_tokens * 1.5:
+        print("Performing aggressive context management")
+
+        # Approach 1: Keep only the most recent parts of the conversation
+        lines = prompt.strip().split('\n')
+
+        # Identify structural elements to keep
+        instruction_idx = -1
+        for i, line in enumerate(lines):
+            if "You are a" in line or "I want you to" in line:
+                instruction_idx = i
+
+        # Always keep the first part with instructions (system prompt)
+        keep_beginning = lines[:instruction_idx + 20] if instruction_idx >= 0 else lines[:50]
+
+        # Keep the most recent content (approximately half of the max tokens)
+        keep_end = lines[-int(max_allowed_tokens/15):]
+
+        # Add a note about trimming
+        middle_note = [
+            "",
+            "...",
+            "[Context has been trimmed to fit token limits]",
+            "...",
+            ""
+        ]
+
+        # Combine parts
+        shortened_prompt = "\n".join(keep_beginning + middle_note + keep_end)
+        print(f"Context reduced from ~{estimated_tokens} to ~{len(shortened_prompt.split())} estimated tokens")
+        return shortened_prompt
+
+    # Moderate size reduction for moderately oversized prompts
+    else:
+        print("Performing moderate context management")
+
+        # Split into lines for easier processing
+        sections = prompt.split("\n\n")
+
+        # Keep important sections like system instructions and recent content
+        # Identify which sections to keep or trim
+        keep_sections = []
+        trim_sections = []
+
+        # Process each section
+        for i, section in enumerate(sections):
+            # Always keep the first few sections (likely instructions)
+            if i < 3:
+                keep_sections.append(section)
+            # Keep the last several sections (most recent and relevant)
+            elif i > len(sections) - 8:
+                keep_sections.append(section)
+            # For code blocks, we should generally keep them
+            elif "```" in section:
+                keep_sections.append(section)
+            # For very short sections, keep them
+            elif len(section.split()) < 30:
+                keep_sections.append(section)
+            # For sections with likely important content, keep them
+            elif any(marker in section.lower() for marker in ["important", "key", "critical", "necessary", "must"]):
+                keep_sections.append(section)
+            # Otherwise, candidate for trimming
+            else:
+                trim_sections.append(section)
+
+        # If we still need to trim more, start removing some of the trim_sections
+        if len(" ".join(keep_sections).split()) > max_allowed_tokens * 0.8:
+            # Keep only a portion of the trim_sections
+            trim_to_keep = int(len(trim_sections) * 0.3)  # Keep 30%
+            trim_sections = trim_sections[:trim_to_keep]
+
+        # Build final prompt with a note about trimming
+        final_sections = keep_sections + ["[Some context has been summarized to fit token limits]"] + trim_sections
+        final_prompt = "\n\n".join(final_sections)
+
+        print(f"Context reduced from ~{estimated_tokens} to ~{len(final_prompt.split())} estimated tokens")
+        return final_prompt
+
+
+# Now update the try_model_call_with_fallbacks function to use this context management
 def try_model_call_with_fallbacks(prompt):
     """Try to use the primary model first, fall back to alternatives if it fails."""
     # First attempt with primary model
     try:
-        return model(prompt)
+        # Apply context management
+        managed_prompt = manage_context(prompt)
+
+        return original_call(managed_prompt)
     except Exception as primary_error:
+        # If it's a token limit error, try more aggressive management
+        if "Input validation error: inputs tokens + max_new_tokens" in str(primary_error):
+            try:
+                print("Token limit exceeded. Trying more aggressive context management...")
+                more_managed_prompt = manage_context(prompt, max_allowed_tokens=20000)
+                return original_call(more_managed_prompt)
+            except Exception:
+                print("Token reduction failed. Proceeding to fallback models...")
+
         print(f"Primary model call failed: {str(primary_error)}")
         print("Trying fallback models...")
 
@@ -392,11 +507,11 @@ def try_model_call_with_fallbacks(prompt):
         try:
            print(f"Trying fallback model: {fallback['display_name']}")
            client = InferenceClient(provider=fallback["provider"], api_key=api_key)
-           messages = [{"role": "user", "content": prompt}]
+           messages = [{"role": "user", "content": manage_context(prompt, 25000)}]  # Apply context management for fallbacks too
            completion = client.chat.completions.create(
                model=fallback["model_name"],
                messages=messages,
-               max_tokens=2096,
+               max_tokens=1800,
                temperature=0.5
            )
            print(f"Successfully used fallback model: {fallback['display_name']}")