gabejavitt commited on
Commit
ac72f74
·
verified ·
1 Parent(s): f6e496f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -133
app.py CHANGED
@@ -97,6 +97,80 @@ def retry_with_backoff(max_retries=None, base_delay=None):
97
  return wrapper
98
  return decorator
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  class SearchCache:
102
  """LRU cache for search results"""
@@ -517,120 +591,94 @@ class ValidateInput(BaseModel):
517
  proposed_answer: str = Field(description="Answer to validate")
518
  original_question: str = Field(description="Original question (first 100 chars)")
519
 
520
- @tool(args_schema=ValidateInput)
521
- def validate_answer(proposed_answer: str, original_question: str) -> str:
522
  """
523
- ENHANCED: Validate answer before submission with comprehensive checks.
524
-
525
- ALWAYS use before final_answer_tool.
526
  """
527
  start_time = time.time()
 
528
  try:
529
- print(f"✓ Validating: '{proposed_answer[:50]}...'")
530
 
531
- issues = []
532
  warnings = []
533
- suggestions = []
534
-
535
- # 1. Check conversational fluff
536
- fluff = ["the answer is", "based on", "according to", "i found", "here is",
537
- "here's", "after searching", "from my research", "the result is"]
538
- if any(p in proposed_answer.lower() for p in fluff):
539
- issues.append("❌ Remove conversational text - answer ONLY")
540
-
541
- # 2. Check code fences
542
- if "```" in proposed_answer:
543
- issues.append("❌ Remove code fences (```)")
544
-
545
- # 3. Check markdown formatting
546
- if proposed_answer.startswith('#') or '**' in proposed_answer:
547
- issues.append("❌ Remove markdown formatting")
548
-
549
- # 4. Check length appropriateness
550
- question_lower = original_question.lower()
551
- if len(proposed_answer) > 500:
552
- if not any(k in question_lower for k in ['explain', 'describe', 'why', 'how does']):
553
- warnings.append("⚠️ Answer very long. Question asks for short answer?")
554
-
555
- # 5. Check for number questions
556
- number_keywords = ["how many", "what number", "count", "total", "sum",
557
- "what year", "when did", "what date"]
558
- if any(k in question_lower for k in number_keywords):
559
- if not any(c.isdigit() for c in proposed_answer):
560
- issues.append("❌ Question asks for number but answer has no digits")
561
- else:
562
- # Extract just the number(s)
563
- import re
564
- numbers = re.findall(r'\d+(?:\.\d+)?', proposed_answer)
565
- if numbers and len(proposed_answer) > 50:
566
- suggestions.append(f"💡 Consider just the number(s): {', '.join(numbers)}")
567
-
568
- # 6. Check for list questions
569
- list_keywords = ["list", "what are", "name the", "which"]
570
- if any(k in question_lower for k in list_keywords):
571
- if '\n' in proposed_answer or len(proposed_answer.split(',')) > 1:
572
- # Good, it's formatted as a list
573
- pass
574
- else:
575
- warnings.append("⚠️ Question might ask for multiple items")
576
-
577
- # 7. Check for yes/no questions
578
- if question_lower.startswith(('is ', 'does ', 'did ', 'can ', 'will ', 'was ', 'were ', 'are ')):
579
- if proposed_answer.lower() not in ['yes', 'no', 'true', 'false']:
580
- if not proposed_answer.lower().startswith(('yes', 'no')):
581
- warnings.append("⚠️ Question seems yes/no. Answer should start with yes/no?")
582
-
583
- # 8. Check for excessive punctuation
584
- if proposed_answer.count('!') > 2 or proposed_answer.count('?') > 1:
585
- issues.append("❌ Remove excessive punctuation")
586
-
587
- # 9. Check for quotes around answer
588
- if (proposed_answer.startswith('"') and proposed_answer.endswith('"')) or \
589
- (proposed_answer.startswith("'") and proposed_answer.endswith("'")):
590
- suggestions.append("💡 Consider removing quotes around answer")
591
-
592
- # 10. Check for multiple sentences when one expected
593
- sentences = [s.strip() for s in proposed_answer.split('.') if s.strip()]
594
- if len(sentences) > 3:
595
- if not any(k in question_lower for k in ['explain', 'describe', 'why', 'how']):
596
- warnings.append("⚠️ Multiple sentences. Question asks for simple answer?")
597
-
598
- # 11. Sanity check: is it empty?
599
- if not proposed_answer.strip():
600
- issues.append("❌ Answer is empty!")
601
-
602
- # 12. Check for units in measurement questions
603
- unit_keywords = ['height', 'weight', 'distance', 'speed', 'temperature', 'size']
604
- if any(k in question_lower for k in unit_keywords):
605
- has_unit = any(u in proposed_answer.lower() for u in
606
- ['km', 'miles', 'kg', 'lbs', 'cm', 'inches', 'celsius',
607
- 'fahrenheit', 'mph', 'kph', 'meters', 'feet'])
608
- if not has_unit and any(c.isdigit() for c in proposed_answer):
609
- warnings.append("⚠️ Measurement question but no unit found")
610
-
611
- # Build response
612
- if issues:
613
- result = "🚫 VALIDATION FAILED:\n" + "\n".join(issues)
614
- if suggestions:
615
- result += "\n\nSuggestions:\n" + "\n".join(suggestions)
616
- result += "\n\nFix issues then retry validation."
617
- elif warnings:
618
- result = "⚠️ WARNINGS:\n" + "\n".join(warnings)
619
- if suggestions:
620
- result += "\n\nSuggestions:\n" + "\n".join(suggestions)
621
- result += "\n\nProceed if confident, or refine answer."
622
- elif suggestions:
623
- result = "✅ PASSED with suggestions:\n" + "\n".join(suggestions)
624
- result += "\n\nCall final_answer_tool() when ready."
625
  else:
626
- result = "✅ VALIDATION PASSED! Call final_answer_tool() now."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627
 
628
  telemetry.record_call("validate_answer", time.time() - start_time, True)
629
  return result
630
 
631
  except Exception as e:
632
  telemetry.record_call("validate_answer", time.time() - start_time, False)
633
- raise
634
 
635
  # =============================================================================
636
  # CORE TOOLS
@@ -1882,16 +1930,28 @@ class FinalAnswerInput(BaseModel):
1882
 
1883
  @tool(args_schema=FinalAnswerInput)
1884
  def final_answer_tool(answer: str) -> str:
1885
- """Submit final answer"""
1886
  start_time = time.time()
1887
 
1888
  try:
1889
- print(f"✅ FINAL ANSWER: {answer}")
 
 
 
 
 
 
 
 
 
 
 
1890
  telemetry.record_call("final_answer_tool", time.time() - start_time, True)
1891
- return answer
 
1892
  except Exception as e:
1893
  telemetry.record_call("final_answer_tool", time.time() - start_time, False)
1894
- raise
1895
 
1896
 
1897
  # =============================================================================
@@ -2297,34 +2357,34 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format.
2297
  self.current_llm = "groq"
2298
 
2299
  def prune_context_if_needed(state: AgentState) -> AgentState:
2300
- """
2301
- Prune conversation history if it's getting too long.
2302
- Keeps system message + recent history to stay under token limits.
2303
- """
2304
- messages = state.get("messages", [])
2305
-
2306
- # Keep first message (system prompt) + last N messages
2307
- MAX_MESSAGES = 20 # Adjust based on your needs
2308
-
2309
- if len(messages) > MAX_MESSAGES:
2310
- print(f"⚠️ Context pruning: {len(messages)} messages → {MAX_MESSAGES}")
2311
-
2312
- # Always keep system message (if it exists)
2313
- system_msg = None
2314
- if messages and isinstance(messages[0], SystemMessage):
2315
- system_msg = messages[0]
2316
- messages = messages[1:]
2317
-
2318
- # Keep only recent messages
2319
- recent_messages = messages[-(MAX_MESSAGES-1):]
2320
-
2321
- # Reconstruct
2322
- if system_msg:
2323
- state["messages"] = [system_msg] + recent_messages
2324
- else:
2325
- state["messages"] = recent_messages
2326
-
2327
- return state
2328
 
2329
  # Build agent graph
2330
  def agent_node(state: AgentState):
 
97
  return wrapper
98
  return decorator
99
 
100
+ def normalize_answer(answer: str, question: str = "") -> str:
101
+ """
102
+ Normalize answer to match expected format.
103
+
104
+ Args:
105
+ answer: The answer to normalize
106
+ question: Optional question text to determine if order matters
107
+ """
108
+ if not answer:
109
+ return answer
110
+
111
+ original = answer
112
+ answer = answer.strip()
113
+
114
+ # Remove common prefixes
115
+ prefixes_to_remove = [
116
+ "the answer is:",
117
+ "the answer is",
118
+ "answer:",
119
+ "final answer:",
120
+ "result:",
121
+ ]
122
+ for prefix in prefixes_to_remove:
123
+ if answer.lower().startswith(prefix):
124
+ answer = answer[len(prefix):].strip()
125
+
126
+ # Handle lists
127
+ if "," in answer:
128
+ items = [item.strip() for item in answer.split(",")]
129
+ items = [item for item in items if item]
130
+
131
+ # Determine if order matters based on question
132
+ order_matters_keywords = [
133
+ "first", "last", "before", "after", "sequence",
134
+ "order", "chronological", "oldest", "newest",
135
+ "in the form", "format"
136
+ ]
137
+
138
+ order_matters = any(kw in question.lower() for kw in order_matters_keywords)
139
+
140
+ if not order_matters:
141
+ # Sort alphabetically for consistency
142
+ items.sort()
143
+ print(f" 📋 Sorted list alphabetically (order doesn't seem to matter)")
144
+ else:
145
+ print(f" 📋 Kept original order (question specifies order)")
146
+
147
+ # Normalize each item
148
+ items = [item.strip().rstrip('.') for item in items]
149
+
150
+ # Consistent spacing
151
+ answer = ", ".join(items)
152
+
153
+ # Single word capitalization
154
+ if len(answer.split()) == 1:
155
+ if answer.lower() in ['right', 'left', 'yes', 'no', 'true', 'false']:
156
+ answer = answer.capitalize()
157
+
158
+ # Handle "St." vs "Saint"
159
+ if "without abbreviations" in question.lower():
160
+ answer = answer.replace("St.", "Saint")
161
+ answer = answer.replace("Dr.", "Doctor")
162
+ answer = answer.replace("Mt.", "Mount")
163
+
164
+ # Remove trailing period (unless decimal)
165
+ if answer.endswith('.') and not (len(answer) > 1 and answer[-2].isdigit()):
166
+ answer = answer[:-1]
167
+
168
+ # Remove wrapping quotes
169
+ if (answer.startswith('"') and answer.endswith('"')) or \
170
+ (answer.startswith("'") and answer.endswith("'")):
171
+ answer = answer[1:-1]
172
+
173
+ return answer
174
 
175
  class SearchCache:
176
  """LRU cache for search results"""
 
591
  proposed_answer: str = Field(description="Answer to validate")
592
  original_question: str = Field(description="Original question (first 100 chars)")
593
 
594
+ @tool(args_schema=ValidateAnswerInput)
595
+ def validate_answer(answer: str) -> str:
596
  """
597
+ Validate answer format and provide warnings.
598
+ Returns validation result with normalization suggestions.
 
599
  """
600
  start_time = time.time()
601
+
602
  try:
603
+ print(f"✓ Validating: '{answer[:50]}...'")
604
 
 
605
  warnings = []
606
+ errors = []
607
+ normalization_needed = []
608
+
609
+ # Normalize for validation
610
+ normalized = normalize_answer(answer)
611
+
612
+ if normalized != answer:
613
+ normalization_needed.append(f"Consider using normalized form: '{normalized}'")
614
+
615
+ # Check 1: Empty answer
616
+ if not answer or not answer.strip():
617
+ errors.append("Answer is empty")
618
+
619
+ # Check 2: Too long (probably explaining instead of answering)
620
+ if len(answer) > 200:
621
+ warnings.append("Answer is very long (>200 chars). Consider if question asks for brief response.")
622
+
623
+ # Check 3: Contains question words
624
+ question_words = ['what', 'who', 'when', 'where', 'why', 'how', 'which']
625
+ if any(word in answer.lower() for word in question_words):
626
+ warnings.append("Answer contains question words. Make sure you're providing the answer, not rephrasing the question.")
627
+
628
+ # Check 4: List ordering
629
+ if "," in answer:
630
+ items = [item.strip() for item in answer.split(",")]
631
+ if len(items) > 1:
632
+ warnings.append(f"List detected with {len(items)} items. Verify order matches question requirements.")
633
+
634
+ # Check 5: Capitalization consistency
635
+ if answer.lower() in ['right', 'left', 'yes', 'no', 'true', 'false']:
636
+ if not answer[0].isupper():
637
+ normalization_needed.append(f"Consider capitalizing: '{answer.capitalize()}'")
638
+
639
+ # Check 6: Abbreviations
640
+ if any(abbrev in answer.lower() for abbrev in ['st.', 'dr.', 'mt.']):
641
+ if "without abbreviations" in str(answer).lower() or "full" in str(answer).lower():
642
+ warnings.append("Question may ask for full form without abbreviations")
643
+
644
+ # Check 7: Spacing in lists
645
+ if "," in answer:
646
+ # Check for inconsistent spacing
647
+ if ", " in answer and "," in answer.replace(", ", ""):
648
+ normalization_needed.append("Inconsistent spacing in list. Use consistent ', ' format")
649
+
650
+ # Build result
651
+ result_parts = []
652
+
653
+ if errors:
654
+ result_parts.append("🚫 VALIDATION FAILED:")
655
+ for error in errors:
656
+ result_parts.append(f"❌ {error}")
657
+ result_parts.append("Fix issues then retry validation.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
658
  else:
659
+ result_parts.append("✅ VALIDATION PASSED!")
660
+
661
+ if normalization_needed:
662
+ result_parts.append("\n💡 NORMALIZATION SUGGESTIONS:")
663
+ for suggestion in normalization_needed:
664
+ result_parts.append(f" • {suggestion}")
665
+
666
+ if warnings:
667
+ result_parts.append("\n⚠️ WARNINGS:")
668
+ for warning in warnings:
669
+ result_parts.append(f"⚠️ {warning}")
670
+ result_parts.append("Proceed if confident, or refine answer.")
671
+ else:
672
+ result_parts.append("Call final_answer_tool() now.")
673
+
674
+ result = "\n".join(result_parts)
675
 
676
  telemetry.record_call("validate_answer", time.time() - start_time, True)
677
  return result
678
 
679
  except Exception as e:
680
  telemetry.record_call("validate_answer", time.time() - start_time, False)
681
+ raise ToolError("validate_answer", e)
682
 
683
  # =============================================================================
684
  # CORE TOOLS
 
1930
 
1931
  @tool(args_schema=FinalAnswerInput)
1932
  def final_answer_tool(answer: str) -> str:
1933
+ """Submit final answer with normalization"""
1934
  start_time = time.time()
1935
 
1936
  try:
1937
+ # Get question from state (you'll need to pass this through)
1938
+ # For now, normalize without question context
1939
+ original_answer = answer
1940
+ answer = normalize_answer(answer)
1941
+
1942
+ if answer != original_answer:
1943
+ print(f"📝 Normalized answer:")
1944
+ print(f" Before: '{original_answer}'")
1945
+ print(f" After: '{answer}'")
1946
+
1947
+ print(f"\n✅ FINAL: '{answer}'\n")
1948
+
1949
  telemetry.record_call("final_answer_tool", time.time() - start_time, True)
1950
+ return f"FINAL_ANSWER: {answer}"
1951
+
1952
  except Exception as e:
1953
  telemetry.record_call("final_answer_tool", time.time() - start_time, False)
1954
+ raise ToolError("final_answer_tool", e)
1955
 
1956
 
1957
  # =============================================================================
 
2357
  self.current_llm = "groq"
2358
 
2359
  def prune_context_if_needed(state: AgentState) -> AgentState:
2360
+ """
2361
+ Prune conversation history if it's getting too long.
2362
+ Keeps system message + recent history to stay under token limits.
2363
+ """
2364
+ messages = state.get("messages", [])
2365
+
2366
+ # Keep first message (system prompt) + last N messages
2367
+ MAX_MESSAGES = 20 # Adjust based on your needs
2368
+
2369
+ if len(messages) > MAX_MESSAGES:
2370
+ print(f"⚠️ Context pruning: {len(messages)} messages → {MAX_MESSAGES}")
2371
+
2372
+ # Always keep system message (if it exists)
2373
+ system_msg = None
2374
+ if messages and isinstance(messages[0], SystemMessage):
2375
+ system_msg = messages[0]
2376
+ messages = messages[1:]
2377
+
2378
+ # Keep only recent messages
2379
+ recent_messages = messages[-(MAX_MESSAGES-1):]
2380
+
2381
+ # Reconstruct
2382
+ if system_msg:
2383
+ state["messages"] = [system_msg] + recent_messages
2384
+ else:
2385
+ state["messages"] = recent_messages
2386
+
2387
+ return state
2388
 
2389
  # Build agent graph
2390
  def agent_node(state: AgentState):