leonwoo's picture
Upload 22 files
1777acb verified
"""A LangGraph-based agent implementation."""
import re
import sys
import json
from pathlib import Path
from datetime import datetime
from langchain_core.messages import AIMessage, HumanMessage
from .graph import AgentState, build_agent_graph
def ensure_valid_answer(answer: str) -> str:
    """Return *answer* without surrounding whitespace, or a placeholder.

    Anything that is not a non-blank string (``None``, non-``str`` values,
    ``""``, whitespace-only text) is replaced by the literal
    ``"Unable to determine answer"`` so callers always get a non-empty string.
    """
    fallback = "Unable to determine answer"
    if isinstance(answer, str) and answer:
        stripped = answer.strip()
        if stripped != "":
            return stripped
    return fallback
class TeeOutput:
    """Text stream that duplicates everything written to a console stream and a file.

    Meant to be assigned to ``sys.stdout`` (or ``sys.stderr``) so printed
    output is shown on the console and also recorded in a log file.

    Args:
        file_path: Path of the log file to open (UTF-8).
        mode: Open mode for the log file only (``'a'`` appends, ``'w'``
            truncates). It no longer decides which console stream is
            mirrored; previously ``'w'`` silently mirrored to ``sys.stderr``.
        stream: Console stream to mirror to. Defaults to the ``sys.stdout``
            that is current when the object is created.
    """

    def __init__(self, file_path, mode='a', stream=None):
        # Bind the terminal before opening the file so __getattr__ can never
        # recurse on a half-initialised instance if open() raises.
        self.terminal = sys.stdout if stream is None else stream
        self.file = open(file_path, mode, encoding='utf-8')

    def write(self, message):
        """Write *message* to the console and the log; return its length like a text stream."""
        self.terminal.write(message)
        if not self.file.closed:
            self.file.write(message)
            # Flush eagerly so the log is complete even if the process dies.
            self.file.flush()
        return len(message)

    def flush(self):
        """Flush the console stream and, if still open, the log file."""
        self.terminal.flush()
        if not self.file.closed:
            self.file.flush()

    def close(self):
        """Close the log file (safe to call repeatedly); the console stream stays open."""
        self.file.close()

    def __getattr__(self, name):
        # Delegate stream attributes that callers probe on sys.stdout
        # (e.g. ``isatty``, ``encoding``, ``fileno``) to the console stream.
        terminal = self.__dict__.get("terminal")
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
class BasicAgent:
    """A LangGraph-powered agent that uses tools to answer questions.

    Raw model output is post-processed into the terse literal form expected by
    GAIA-style exact-match scoring. Answers can optionally be memoized in a
    JSON file keyed by the exact question text.
    """

    # Recursion limit passed to the compiled graph on each invocation.
    _RECURSION_LIMIT = 50

    # Prefixes / list markers removed (case-insensitively, in order) from answers.
    _PREFIX_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r'^(the answer is|answer:|final answer:|thus,|therefore,|so,|hence,)\s*',
            r'^(the\s+)?(correct\s+)?(number|city|country|name|value|total|result)\s+(is|are|was|were)\s*',
            # Enumerators like "1. Paris": whitespace after the dot is required,
            # otherwise "42." would be erased and "3.14" would become "14".
            r'^\d+\.\s+(?=\S)',
            # Bullets: "-" only when followed by whitespace so "-5" keeps its sign.
            r'^(?:•\s*|-\s+)',
        )
    )
    # A sentence-ending period is followed by whitespace or end of text; a bare
    # "." would also split decimals such as "3.14".
    _SENTENCE_END_RE = re.compile(r'\.(?=\s|$)')
    _TRAILING_PAREN_RE = re.compile(r'\s*\([^)]*\)\s*$')
    _COMMA_SPACE_RE = re.compile(r',\s+')
    # Numbers whose commas are thousands separators, e.g. "1,234" or "-12,345.6".
    _THOUSANDS_RE = re.compile(r'[-+]?\d{1,3}(?:,\d{3})+(?:\.\d+)?')
    # Keys checked, in priority order, when the model returns a JSON tool-call blob.
    _JSON_ANSWER_KEYS = ('answer', 'arguments', 'vegetables', 'surname', 'value', 'result', 'submitted_answer')

    def __init__(self, log_to_file=True, use_cache=True, cache_file="agent_cache.json") -> None:
        """Build the agent graph, load the answer cache and optionally start logging.

        Args:
            log_to_file: When true, ``sys.stdout`` is replaced by a ``TeeOutput``
                writing to ``agent_run_<timestamp>.log`` in the working
                directory. Call :meth:`close` to restore the original stdout.
            use_cache: When true, answers are memoized by exact question text
                and persisted to *cache_file*.
            cache_file: Path of the JSON answer cache.
        """
        self.graph = build_agent_graph()
        self.log_file = None
        self._original_stdout = sys.stdout  # restored by close()
        self.use_cache = use_cache
        self.cache_file = Path(cache_file)
        self.answer_cache = {}  # question text -> final answer

        if self.use_cache:
            self._load_cache()

        if log_to_file:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            log_filename = f"agent_run_{timestamp}.log"
            self.log_file = TeeOutput(log_filename, 'w')
            sys.stdout = self.log_file
            print(f"📝 Logging to: {log_filename}\n")

        if self.use_cache and self.answer_cache:
            print(f"💾 Loaded {len(self.answer_cache)} cached answers from {self.cache_file}\n")

    def close(self) -> None:
        """Stop logging: restore ``sys.stdout`` and close the log file (idempotent)."""
        log_file = getattr(self, "log_file", None)
        if log_file is None:
            return
        if sys.stdout is log_file:
            sys.stdout = getattr(self, "_original_stdout", None) or sys.__stdout__
        log_file.close()
        self.log_file = None

    def _load_cache(self):
        """Load the answer cache from disk; an unreadable or malformed file yields an empty cache."""
        try:
            if not self.cache_file.exists():
                return
            with open(self.cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
            print(f"⚠️ Warning: Could not load cache from {self.cache_file}: {e}")
            self.answer_cache = {}
            return
        if not isinstance(data, dict):
            # A list/str would later fail on `self.answer_cache[question] = ...`.
            print(f"⚠️ Warning: Could not load cache from {self.cache_file}: expected a JSON object")
            self.answer_cache = {}
            return
        self.answer_cache = data

    def _save_cache(self):
        """Persist the answer cache to disk (best effort, failures are only reported).

        The JSON is written to a sibling ``.tmp`` file first and then moved over
        the cache file, so an interrupted write cannot corrupt the existing cache.
        """
        tmp_path = self.cache_file.with_name(self.cache_file.name + ".tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.answer_cache, f, indent=2, ensure_ascii=False)
            tmp_path.replace(self.cache_file)
        except (OSError, TypeError, ValueError) as e:
            print(f"⚠️ Warning: Could not save cache to {self.cache_file}: {e}")

    @staticmethod
    def _message_text(content) -> str:
        """Flatten message content to text.

        ``content`` may be a plain string or a list of string / ``{"text": ...}``
        parts; other values are stringified and ``None`` becomes ``""``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    parts.append(part["text"])
            return "".join(parts)
        return "" if content is None else str(content)

    @classmethod
    def _answer_from_mapping(cls, mapping: dict):
        """Return the answer value held in a parsed JSON object, or ``None`` if none is found."""
        # Preferred keys first; only if none is usable fall back to "any value".
        for key in cls._JSON_ANSWER_KEYS:
            value = mapping.get(key)
            if value and value != "FINISH":
                return cls._stringify_json_value(value)
        if 'name' in mapping:
            for key, value in mapping.items():
                if key not in ('name', 'FINISH') and value and value != "FINISH":
                    return cls._stringify_json_value(value)
        return None

    @classmethod
    def _stringify_json_value(cls, value) -> str:
        """Convert a JSON value to text, unwrapping nested objects such as ``{"arguments": {"answer": ...}}``."""
        if isinstance(value, dict):
            nested = cls._answer_from_mapping(value)
            if nested is not None:
                return nested
        return str(value)

    @classmethod
    def _extract_json_answer(cls, text: str) -> str:
        """Return the answer embedded in a JSON tool-call blob, or *text* unchanged."""
        try:
            parsed = json.loads(text)
        except ValueError:
            return text
        if not isinstance(parsed, dict):
            return text
        found = cls._answer_from_mapping(parsed)
        return text if found is None else found

    @classmethod
    def _pick_key_sentence(cls, answer: str) -> str:
        """If *answer* has several sentences, return the first short, informative one."""
        sentences = cls._SENTENCE_END_RE.split(answer)
        if len(sentences) <= 1:
            return answer
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue  # never let an empty fragment become the answer
            if len(sentence) < 50 and (any(ch.isdigit() for ch in sentence) or len(sentence.split()) <= 5):
                return sentence
        return answer

    @classmethod
    def _normalize_number(cls, answer: str) -> str:
        """Strip currency/percent symbols and thousands separators from short numeric answers.

        The cleaned form is used only if it parses as a float. Commas are removed
        only when they are thousands separators, so lists such as "3,5,7" survive.
        """
        if len(answer.split()) > 5 or not any(ch.isdigit() for ch in answer):
            return answer
        candidate = answer
        for symbol in ('$', '€', '£', '¥', '%'):
            candidate = candidate.replace(symbol, '')
        candidate = candidate.strip()
        if cls._THOUSANDS_RE.fullmatch(candidate):
            candidate = candidate.replace(',', '')
        try:
            float(candidate)
        except ValueError:
            return answer  # not a pure number, keep the original text
        return candidate

    def _clean_answer(self, answer: str, question: str) -> str:
        """
        Clean the answer based on GAIA scoring rules.

        Removes code fences, JSON tool-call wrappers, explanatory prefixes,
        extra sentences and trailing parentheticals, normalizes comma-separated
        lists and plain numbers, and strips wrapping quotes.

        Args:
            answer: Raw model output (non-strings are stringified, ``None`` -> "").
            question: The original question; used to detect list-format requests.

        Returns:
            The cleaned answer (possibly empty).
        """
        if not isinstance(answer, str):
            answer = "" if answer is None else str(answer)
        answer = answer.strip()

        # Drop Markdown code fences but keep the fenced content.
        if answer.startswith('```'):
            answer = '\n'.join(
                line for line in answer.split('\n') if not line.startswith('```')
            ).strip()

        # Unwrap JSON structures like {"name": "FINISH", "answer": "value"}.
        if answer.startswith('{') and ('"name"' in answer or '"FINISH"' in answer):
            answer = self._extract_json_answer(answer)

        for pattern in self._PREFIX_PATTERNS:
            answer = pattern.sub('', answer)
        answer = answer.strip()

        answer = self._pick_key_sentence(answer)

        # Remove a trailing explanation in parentheses.
        answer = self._TRAILING_PAREN_RE.sub('', answer)

        # Comma-separated list requested: no spaces after commas.
        lowered = question.lower()
        if 'comma' in lowered and ('list' in lowered or 'separated' in lowered):
            answer = self._COMMA_SPACE_RE.sub(',', answer)

        answer = self._normalize_number(answer)

        # Remove quotes wrapping the answer.
        return answer.strip('"\'')

    def __call__(self, question: str) -> str:
        """Answer *question* using the cache or the agent graph; never returns an empty string.

        Any exception is caught and reported as the answer text
        ``"Agent failed with error: ..."`` (such answers are not cached).
        """
        try:
            print("\n" + "=" * 80)
            print(f"📋 QUESTION: {question[:150]}...")
            print("=" * 80)

            if self.use_cache and question in self.answer_cache:
                cached_answer = self.answer_cache[question]
                print("\n💾 Using cached answer (no LLM call!)")
                print(f"\n🎯 FINAL ANSWER: {cached_answer}")
                print("=" * 80 + "\n")
                return cached_answer

            state: AgentState = {"messages": [HumanMessage(content=question)]}
            print("\n🚀 Starting agent execution...")
            result = self.graph.invoke(state, config={"recursion_limit": self._RECURSION_LIMIT})

            # The final answer is the content of the last AI message.
            final_message = next(
                (m for m in reversed(result["messages"]) if isinstance(m, AIMessage)),
                None,
            )
            if final_message is None:
                print("\n⚠️ No answer found")
                print("=" * 80 + "\n")
                return ensure_valid_answer("")

            raw_answer = self._message_text(final_message.content)
            validated_answer = ensure_valid_answer(self._clean_answer(raw_answer, question))

            if self.use_cache:
                self.answer_cache[question] = validated_answer
                self._save_cache()  # persist immediately so a crash loses nothing

            print(f"\n🎯 FINAL ANSWER: {validated_answer}")
            print("=" * 80 + "\n")
            return validated_answer
        except Exception as e:  # top-level boundary: always return a string answer
            print(f"\n❌ ERROR: {e}")
            print("=" * 80 + "\n")
            return ensure_valid_answer(f"Agent failed with error: {e}")