Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -97,6 +97,80 @@ def retry_with_backoff(max_retries=None, base_delay=None):
|
|
| 97 |
return wrapper
|
| 98 |
return decorator
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
class SearchCache:
|
| 102 |
"""LRU cache for search results"""
|
|
@@ -517,120 +591,94 @@ class ValidateInput(BaseModel):
|
|
| 517 |
proposed_answer: str = Field(description="Answer to validate")
|
| 518 |
original_question: str = Field(description="Original question (first 100 chars)")
|
| 519 |
|
| 520 |
-
@tool(args_schema=
|
| 521 |
-
def validate_answer(
|
| 522 |
"""
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
ALWAYS use before final_answer_tool.
|
| 526 |
"""
|
| 527 |
start_time = time.time()
|
|
|
|
| 528 |
try:
|
| 529 |
-
print(f"✓ Validating: '{
|
| 530 |
|
| 531 |
-
issues = []
|
| 532 |
warnings = []
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
#
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
#
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
issues.append("❌ Remove excessive punctuation")
|
| 586 |
-
|
| 587 |
-
# 9. Check for quotes around answer
|
| 588 |
-
if (proposed_answer.startswith('"') and proposed_answer.endswith('"')) or \
|
| 589 |
-
(proposed_answer.startswith("'") and proposed_answer.endswith("'")):
|
| 590 |
-
suggestions.append("💡 Consider removing quotes around answer")
|
| 591 |
-
|
| 592 |
-
# 10. Check for multiple sentences when one expected
|
| 593 |
-
sentences = [s.strip() for s in proposed_answer.split('.') if s.strip()]
|
| 594 |
-
if len(sentences) > 3:
|
| 595 |
-
if not any(k in question_lower for k in ['explain', 'describe', 'why', 'how']):
|
| 596 |
-
warnings.append("⚠️ Multiple sentences. Question asks for simple answer?")
|
| 597 |
-
|
| 598 |
-
# 11. Sanity check: is it empty?
|
| 599 |
-
if not proposed_answer.strip():
|
| 600 |
-
issues.append("❌ Answer is empty!")
|
| 601 |
-
|
| 602 |
-
# 12. Check for units in measurement questions
|
| 603 |
-
unit_keywords = ['height', 'weight', 'distance', 'speed', 'temperature', 'size']
|
| 604 |
-
if any(k in question_lower for k in unit_keywords):
|
| 605 |
-
has_unit = any(u in proposed_answer.lower() for u in
|
| 606 |
-
['km', 'miles', 'kg', 'lbs', 'cm', 'inches', 'celsius',
|
| 607 |
-
'fahrenheit', 'mph', 'kph', 'meters', 'feet'])
|
| 608 |
-
if not has_unit and any(c.isdigit() for c in proposed_answer):
|
| 609 |
-
warnings.append("⚠️ Measurement question but no unit found")
|
| 610 |
-
|
| 611 |
-
# Build response
|
| 612 |
-
if issues:
|
| 613 |
-
result = "🚫 VALIDATION FAILED:\n" + "\n".join(issues)
|
| 614 |
-
if suggestions:
|
| 615 |
-
result += "\n\nSuggestions:\n" + "\n".join(suggestions)
|
| 616 |
-
result += "\n\nFix issues then retry validation."
|
| 617 |
-
elif warnings:
|
| 618 |
-
result = "⚠️ WARNINGS:\n" + "\n".join(warnings)
|
| 619 |
-
if suggestions:
|
| 620 |
-
result += "\n\nSuggestions:\n" + "\n".join(suggestions)
|
| 621 |
-
result += "\n\nProceed if confident, or refine answer."
|
| 622 |
-
elif suggestions:
|
| 623 |
-
result = "✅ PASSED with suggestions:\n" + "\n".join(suggestions)
|
| 624 |
-
result += "\n\nCall final_answer_tool() when ready."
|
| 625 |
else:
|
| 626 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
|
| 628 |
telemetry.record_call("validate_answer", time.time() - start_time, True)
|
| 629 |
return result
|
| 630 |
|
| 631 |
except Exception as e:
|
| 632 |
telemetry.record_call("validate_answer", time.time() - start_time, False)
|
| 633 |
-
raise
|
| 634 |
|
| 635 |
# =============================================================================
|
| 636 |
# CORE TOOLS
|
|
@@ -1882,16 +1930,28 @@ class FinalAnswerInput(BaseModel):
|
|
| 1882 |
|
| 1883 |
@tool(args_schema=FinalAnswerInput)
|
| 1884 |
def final_answer_tool(answer: str) -> str:
|
| 1885 |
-
"""Submit final answer"""
|
| 1886 |
start_time = time.time()
|
| 1887 |
|
| 1888 |
try:
|
| 1889 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1890 |
telemetry.record_call("final_answer_tool", time.time() - start_time, True)
|
| 1891 |
-
return answer
|
|
|
|
| 1892 |
except Exception as e:
|
| 1893 |
telemetry.record_call("final_answer_tool", time.time() - start_time, False)
|
| 1894 |
-
raise
|
| 1895 |
|
| 1896 |
|
| 1897 |
# =============================================================================
|
|
@@ -2297,34 +2357,34 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format.
|
|
| 2297 |
self.current_llm = "groq"
|
| 2298 |
|
| 2299 |
def prune_context_if_needed(state: AgentState) -> AgentState:
|
| 2300 |
-
|
| 2301 |
-
|
| 2302 |
-
|
| 2303 |
-
|
| 2304 |
-
|
| 2305 |
-
|
| 2306 |
-
|
| 2307 |
-
|
| 2308 |
-
|
| 2309 |
-
|
| 2310 |
-
|
| 2311 |
-
|
| 2312 |
-
|
| 2313 |
-
|
| 2314 |
-
|
| 2315 |
-
|
| 2316 |
-
|
| 2317 |
-
|
| 2318 |
-
|
| 2319 |
-
|
| 2320 |
-
|
| 2321 |
-
|
| 2322 |
-
|
| 2323 |
-
|
| 2324 |
-
|
| 2325 |
-
|
| 2326 |
-
|
| 2327 |
-
|
| 2328 |
|
| 2329 |
# Build agent graph
|
| 2330 |
def agent_node(state: AgentState):
|
|
|
|
| 97 |
return wrapper
|
| 98 |
return decorator
|
| 99 |
|
| 100 |
+
def normalize_answer(answer: str, question: str = "") -> str:
    """
    Normalize answer to match expected format.

    Steps, in order: trim whitespace, drop a leading "the answer is"-style
    prefix, tidy comma-separated lists (sorting them unless the question
    hints that order matters), capitalize a handful of one-word answers,
    expand "St."/"Dr."/"Mt." when the question asks for no abbreviations,
    drop a trailing period, and remove one pair of wrapping quotes.

    Args:
        answer: The answer to normalize. Falsy values are returned unchanged.
        question: Optional question text used to decide whether list order
            matters and whether abbreviations should be expanded. ``None``
            is treated like an empty string.

    Returns:
        The normalized answer string.
    """
    import re  # local import: the file's top-level imports are not visible here

    if not answer:
        return answer

    question_lower = (question or "").lower()
    answer = answer.strip()

    # Remove common prefixes (checked in this order; more than one may apply)
    prefixes_to_remove = (
        "the answer is:",
        "the answer is",
        "answer:",
        "final answer:",
        "result:",
    )
    for prefix in prefixes_to_remove:
        if answer.lower().startswith(prefix):
            answer = answer[len(prefix):].strip()

    # A single number written with thousands separators ("1,000", "$1,234.50",
    # "12,500 km") is not a list; splitting it on commas would corrupt it.
    # Ambiguous inputs like "100,200" are treated as one number.
    is_grouped_number = re.fullmatch(
        r"[^\d,]*\d{1,3}(?:,\d{3})+(?:\.\d+)?[^\d,]*", answer
    ) is not None

    # Handle lists
    if "," in answer and not is_grouped_number:
        items = [item.strip() for item in answer.split(",")]
        items = [item for item in items if item]

        # Determine if order matters based on question
        order_matters_keywords = (
            "first", "last", "before", "after", "sequence",
            "order", "chronological", "oldest", "newest",
            "in the form", "format",
        )
        order_matters = any(kw in question_lower for kw in order_matters_keywords)

        if not order_matters:
            # Sort alphabetically for consistency; casefold so that
            # capitalized items do not all sort ahead of lowercase ones.
            items.sort(key=str.casefold)
            print(" 📋 Sorted list alphabetically (order doesn't seem to matter)")
        else:
            print(" 📋 Kept original order (question specifies order)")

        # Normalize each item
        items = [item.strip().rstrip('.') for item in items]

        # Consistent spacing
        answer = ", ".join(items)

    # Single word capitalization
    if len(answer.split()) == 1:
        if answer.lower() in ('right', 'left', 'yes', 'no', 'true', 'false'):
            answer = answer.capitalize()

    # Handle "St." vs "Saint"
    if "without abbreviations" in question_lower:
        answer = answer.replace("St.", "Saint")
        answer = answer.replace("Dr.", "Doctor")
        answer = answer.replace("Mt.", "Mount")

    # Remove trailing period, except when the character before it is a digit
    if answer.endswith('.') and not (len(answer) > 1 and answer[-2].isdigit()):
        answer = answer[:-1]

    # Remove wrapping quotes
    if (answer.startswith('"') and answer.endswith('"')) or \
            (answer.startswith("'") and answer.endswith("'")):
        answer = answer[1:-1]

    return answer
|
| 174 |
|
| 175 |
class SearchCache:
|
| 176 |
"""LRU cache for search results"""
|
|
|
|
| 591 |
proposed_answer: str = Field(description="Answer to validate")
|
| 592 |
original_question: str = Field(description="Original question (first 100 chars)")
|
| 593 |
|
| 594 |
+
@tool(args_schema=ValidateAnswerInput)
def validate_answer(answer: str) -> str:
    """
    Validate answer format and provide warnings.
    Returns validation result with normalization suggestions.
    """
    # NOTE(review): the docstring above is presumably used by @tool as the
    # tool description shown to the model — confirm before rewording it.
    # NOTE(review): the schema class visible just above this function declares
    # `proposed_answer` / `original_question`, while this function takes a
    # single `answer` argument — confirm that ValidateAnswerInput exists and
    # that its fields match this signature.
    import re  # local import: the file's top-level imports are not visible here

    start_time = time.time()

    try:
        print(f"✓ Validating: '{answer[:50]}...'")

        warnings = []
        errors = []
        normalization_needed = []

        # Normalize for validation
        normalized = normalize_answer(answer)
        if normalized != answer:
            normalization_needed.append(f"Consider using normalized form: '{normalized}'")

        # Check 1: Empty answer
        if not answer or not answer.strip():
            errors.append("Answer is empty")

        answer_lower = answer.lower()

        # Check 2: Too long (probably explaining instead of answering)
        if len(answer) > 200:
            warnings.append("Answer is very long (>200 chars). Consider if question asks for brief response.")

        # Check 3: Contains question words.
        # Match whole words only, so answers such as "Show" or "Chow"
        # are not flagged because they happen to contain "how".
        if re.search(r"\b(?:what|who|when|where|why|how|which)\b", answer_lower):
            warnings.append("Answer contains question words. Make sure you're providing the answer, not rephrasing the question.")

        # Check 4 + Check 7: list ordering and list spacing
        if "," in answer:
            items = [item.strip() for item in answer.split(",")]
            if len(items) > 1:
                warnings.append(f"List detected with {len(items)} items. Verify order matches question requirements.")

            # Inconsistent spacing: some commas are followed by a space, some are not
            if ", " in answer and "," in answer.replace(", ", ""):
                normalization_needed.append("Inconsistent spacing in list. Use consistent ', ' format")

        # Check 5: Capitalization consistency
        if answer_lower in ('right', 'left', 'yes', 'no', 'true', 'false'):
            if not answer[0].isupper():
                normalization_needed.append(f"Consider capitalizing: '{answer.capitalize()}'")

        # Check 6: Abbreviations
        # NOTE(review): this looks for "without abbreviations"/"full" in the
        # ANSWER text; it reads as if the question was intended, but no
        # question is passed to this tool — confirm the intended input.
        if any(abbrev in answer_lower for abbrev in ('st.', 'dr.', 'mt.')):
            if "without abbreviations" in answer_lower or "full" in answer_lower:
                warnings.append("Question may ask for full form without abbreviations")

        # Build result
        result_parts = []

        if errors:
            result_parts.append("🚫 VALIDATION FAILED:")
            for error in errors:
                result_parts.append(f"❌ {error}")
            result_parts.append("Fix issues then retry validation.")
        else:
            result_parts.append("✅ VALIDATION PASSED!")

            if normalization_needed:
                result_parts.append("\n💡 NORMALIZATION SUGGESTIONS:")
                for suggestion in normalization_needed:
                    result_parts.append(f" • {suggestion}")

            # Only a clean pass (no warnings) tells the agent to submit
            if warnings:
                result_parts.append("\n⚠️ WARNINGS:")
                for warning in warnings:
                    result_parts.append(f"⚠️ {warning}")
                result_parts.append("Proceed if confident, or refine answer.")
            else:
                result_parts.append("Call final_answer_tool() now.")

        result = "\n".join(result_parts)

        telemetry.record_call("validate_answer", time.time() - start_time, True)
        return result

    except Exception as e:
        # Record the failed call, then surface it as the tool-level error type
        telemetry.record_call("validate_answer", time.time() - start_time, False)
        raise ToolError("validate_answer", e)
|
| 682 |
|
| 683 |
# =============================================================================
|
| 684 |
# CORE TOOLS
|
|
|
|
| 1930 |
|
| 1931 |
@tool(args_schema=FinalAnswerInput)
def final_answer_tool(answer: str) -> str:
    """Submit final answer with normalization"""
    # NOTE(review): the docstring above is presumably used by @tool as the
    # tool description shown to the model — confirm before rewording it.
    started_at = time.time()

    try:
        # The question text is not available inside this tool, so the answer
        # is normalized without question context (list order is not preserved
        # on the basis of the question, abbreviations are not expanded).
        cleaned = normalize_answer(answer)

        # Log the before/after pair only when normalization changed something
        if cleaned != answer:
            print("📝 Normalized answer:")
            print(f" Before: '{answer}'")
            print(f" After: '{cleaned}'")

        print(f"\n✅ FINAL: '{cleaned}'\n")

        telemetry.record_call("final_answer_tool", time.time() - started_at, True)
        return f"FINAL_ANSWER: {cleaned}"

    except Exception as e:
        # Record the failed call, then surface it as the tool-level error type
        telemetry.record_call("final_answer_tool", time.time() - started_at, False)
        raise ToolError("final_answer_tool", e)
|
| 1955 |
|
| 1956 |
|
| 1957 |
# =============================================================================
|
|
|
|
| 2357 |
self.current_llm = "groq"
|
| 2358 |
|
| 2359 |
def prune_context_if_needed(state: AgentState) -> AgentState:
    """
    Prune conversation history if it's getting too long.

    When ``state["messages"]`` holds more than MAX_MESSAGES entries, it is
    replaced by a leading SystemMessage (if the first entry is one) followed
    by the most recent messages, so that exactly MAX_MESSAGES remain.

    Args:
        state: Agent state mapping. Only its "messages" entry is read; when
            pruning happens that entry is replaced in place.

    Returns:
        The same ``state`` object that was passed in.
    """
    MAX_MESSAGES = 20  # Adjust based on your needs

    messages = state.get("messages", [])
    if len(messages) <= MAX_MESSAGES:
        return state

    print(f"⚠️ Context pruning: {len(messages)} messages → {MAX_MESSAGES}")

    # NOTE(review): the cut is purely positional. If the history pairs
    # tool-call messages with their tool results, a cut could separate such a
    # pair — confirm the downstream LLM call tolerates that.
    if isinstance(messages[0], SystemMessage):
        # Always keep the system message; it uses one of the MAX_MESSAGES slots
        state["messages"] = [messages[0]] + messages[1:][-(MAX_MESSAGES - 1):]
    else:
        # No system message to reserve a slot for: keep the full MAX_MESSAGES
        state["messages"] = messages[-MAX_MESSAGES:]

    return state
|
| 2388 |
|
| 2389 |
# Build agent graph
|
| 2390 |
def agent_node(state: AgentState):
|