Spaces:
Runtime error
Runtime error
File size: 9,014 Bytes
1777acb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
"""A LangGraph-based agent implementation."""
import re
import sys
import json
from pathlib import Path
from datetime import datetime
from langchain_core.messages import AIMessage, HumanMessage
from .graph import AgentState, build_agent_graph
def ensure_valid_answer(answer: str) -> str:
    """Return *answer* stripped of surrounding whitespace, or a fallback.

    The fallback string ``"Unable to determine answer"`` is returned whenever
    *answer* is falsy (``None``, ``""``), is not a ``str``, or consists only
    of whitespace, so callers always receive a non-empty string.
    """
    fallback = "Unable to determine answer"

    # Guard clauses: reject anything that cannot yield a usable string.
    if not answer:
        return fallback
    if not isinstance(answer, str):
        return fallback

    trimmed = answer.strip()
    return trimmed if trimmed else fallback
class TeeOutput:
    """Write-only stream that duplicates everything to a terminal and a file.

    Intended to be assigned to ``sys.stdout`` so that ``print`` output is
    both shown on the console and persisted to *file_path*.

    Args:
        file_path: Path of the log file to open (UTF-8 encoded).
        mode: ``open`` mode for the log file. NOTE: the mode also selects the
            console stream -- ``'a'`` mirrors to ``sys.stdout``, any other
            mode mirrors to ``sys.stderr``. This is kept as-is for backward
            compatibility.

    The object can be used as a context manager; the log file is closed on
    exit. The console stream is never closed.
    """

    def __init__(self, file_path, mode='a'):
        self.file = open(file_path, mode, encoding='utf-8')
        self.terminal = sys.stdout if mode == 'a' else sys.stderr

    def write(self, message):
        """Write *message* to both sinks and return the number of characters
        written, as the text-stream protocol expects."""
        self.terminal.write(message)
        written = self.file.write(message)
        # Flush on every write so the log survives an abrupt process exit.
        self.file.flush()
        return written

    def flush(self):
        """Flush the console stream and, if still open, the log file."""
        self.terminal.flush()
        if not self.file.closed:
            self.file.flush()

    def close(self):
        """Close the log file; safe to call more than once."""
        if not self.file.closed:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __getattr__(self, name):
        # Only invoked for attributes not found normally. Delegating to the
        # console stream lets code that probes ``sys.stdout`` (``isatty()``,
        # ``encoding``, ``fileno()`` ...) keep working once this object has
        # replaced it. Guard our own attributes to avoid infinite recursion
        # if ``__init__`` failed before they were set.
        if name in ('terminal', 'file'):
            raise AttributeError(name)
        return getattr(self.terminal, name)
class BasicAgent:
    """A LangGraph-powered agent that uses tools to answer questions.

    Calling an instance with a question runs the compiled graph, takes the
    content of the last ``AIMessage`` in the result, post-processes it with
    :meth:`_clean_answer` and returns a non-empty string. Answers may be
    cached in a JSON file keyed by the exact question text, and stdout may be
    mirrored to a timestamped log file.
    """

    # Keys tried, in priority order, when the model replies with a JSON object.
    _ANSWER_KEYS = (
        'answer', 'arguments', 'vegetables', 'surname',
        'value', 'result', 'submitted_answer',
    )

    # Leading boilerplate stripped from an answer, applied once each in order.
    _PREFIX_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r'^(the answer is|answer:|final answer:|thus,|therefore,|so,|hence,)\s*',
            r'^(the\s+)?(correct\s+)?(number|city|country|name|value|total|result)\s+(is|are|was|were)\s*',
            # List markers such as "1. foo". Whitespace and following content
            # are required so "3.14" and a bare "42." are left intact.
            r'^\d+\.\s+(?=\S)',
            # Bullets. A dash needs trailing whitespace so "-5" keeps its sign.
            r'^(?:•\s*|-\s+)',
        )
    )

    # A period ends a sentence only before whitespace or end of text, so
    # decimals ("3.14") and dotted names ("example.com") are not split.
    _SENTENCE_END = re.compile(r'\.(?=\s|$)')

    # Characters removed when an answer is a plain number.
    _NUMBER_NOISE = str.maketrans('', '', '$€£¥%,')

    def __init__(self, log_to_file=True, use_cache=True, cache_file="agent_cache.json") -> None:
        """Initialize the agent with the compiled graph.

        Args:
            log_to_file: When true, ``sys.stdout`` is replaced by a
                ``TeeOutput`` writing to ``agent_run_<timestamp>.log``.
                Call :meth:`close` to undo this.
            use_cache: When true, answers are read from and written to
                *cache_file*.
            cache_file: Path of the JSON answer cache.
        """
        self.graph = build_agent_graph()
        self.log_file = None
        self._previous_stdout = sys.stdout
        self.use_cache = use_cache
        self.cache_file = Path(cache_file)
        self.answer_cache = {}  # Cache for question -> answer mapping

        # Load cache from disk if it exists
        if self.use_cache:
            self._load_cache()

        # Set up logging to file
        if log_to_file:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            log_filename = f"agent_run_{timestamp}.log"
            self.log_file = TeeOutput(log_filename, 'w')
            sys.stdout = self.log_file
            print(f"📝 Logging to: {log_filename}\n")

        if self.use_cache and self.answer_cache:
            print(f"💾 Loaded {len(self.answer_cache)} cached answers from {self.cache_file}\n")

    def close(self) -> None:
        """Stop file logging: restore the previous stdout and close the log.

        Safe to call repeatedly, and a no-op when file logging was disabled.
        """
        log_file = getattr(self, "log_file", None)
        if log_file is None:
            return
        # Only restore stdout if nobody replaced it after us.
        if sys.stdout is log_file:
            sys.stdout = getattr(self, "_previous_stdout", sys.__stdout__)
        log_file.close()
        self.log_file = None

    def _load_cache(self):
        """Load the answer cache from disk (best effort).

        A missing file leaves the current cache untouched. An unreadable
        file, invalid JSON, or JSON whose top level is not an object resets
        the cache to empty and prints a warning instead of raising.
        """
        try:
            if not self.cache_file.exists():
                return
            with open(self.cache_file, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
            # Lookups index the cache by question, so only a mapping is valid.
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(loaded).__name__}"
                )
            self.answer_cache = loaded
        except (OSError, ValueError) as e:
            print(f"⚠️ Warning: Could not load cache from {self.cache_file}: {e}")
            self.answer_cache = {}

    def _save_cache(self):
        """Save the answer cache to disk (best effort; failures only warn)."""
        try:
            with open(self.cache_file, 'w', encoding='utf-8') as f:
                json.dump(self.answer_cache, f, indent=2, ensure_ascii=False)
        except (OSError, TypeError, ValueError) as e:
            print(f"⚠️ Warning: Could not save cache to {self.cache_file}: {e}")

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten message content to plain text.

        Message content may be a string or a list of parts; string parts and
        the ``"text"`` field of dict parts are concatenated, other parts are
        ignored. ``None`` becomes ``""``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    parts.append(part["text"])
            return "".join(parts)
        return "" if content is None else str(content)

    def _extract_from_json(self, text: str) -> str:
        """Return the answer value from a JSON tool-call style reply.

        Handles structures like ``{"name":"FINISH","answer":"value"}``.
        Preferred keys are tried first; only if none yields a value is the
        first other non-``name`` value used. *text* is returned unchanged if
        it is not valid JSON or holds no usable value.
        """
        try:
            parsed = json.loads(text)
        except (ValueError, TypeError):
            return text
        if not isinstance(parsed, dict):
            return text

        for key in self._ANSWER_KEYS:
            value = parsed.get(key)
            if value and value != "FINISH":
                return str(value)

        # Fallback only when no preferred key matched, so it can never
        # override a value found above.
        if 'name' in parsed:
            for key, value in parsed.items():
                if key not in ('name', 'FINISH') and value and value != "FINISH":
                    return str(value)
        return text

    def _clean_answer(self, answer: str, question: str) -> str:
        """Reduce a raw model reply to the literal answer (GAIA scoring rules).

        Steps, in order: strip code fences, unwrap JSON tool-call replies,
        drop leading boilerplate/list markers, pick the first short sentence
        when there are several, drop a trailing parenthetical, compact
        comma-separated lists when the question asks for one, normalise plain
        numbers (currency symbols, ``%`` and thousands separators removed)
        and finally strip wrapping quotes.

        Args:
            answer: Raw reply text; non-strings are converted with ``str``.
            question: The question, used only to detect list formatting.

        Returns:
            The cleaned answer, possibly empty.
        """
        if not isinstance(answer, str):
            answer = "" if answer is None else str(answer)
        answer = answer.strip()

        # Remove code-block fences, keeping the fenced content.
        if answer.startswith('```'):
            kept = [line for line in answer.split('\n') if not line.startswith('```')]
            answer = '\n'.join(kept).strip()

        # Remove JSON structures like {"name":"FINISH","answer":"value"}
        if answer.startswith('{') and ('"name"' in answer or '"FINISH"' in answer):
            answer = self._extract_from_json(answer)

        # Remove common prefixes and explanatory phrases
        for pattern in self._PREFIX_PATTERNS:
            answer = pattern.sub('', answer).strip()

        # If answer contains multiple sentences, keep the first short one
        # that carries a number or is at most five words long.
        sentences = self._SENTENCE_END.split(answer)
        if len(sentences) > 1:
            for sent in sentences:
                sent = sent.strip()
                if not sent:
                    # An empty fragment (e.g. after a trailing period) must
                    # never replace the answer.
                    continue
                if len(sent) < 50 and (any(ch.isdigit() for ch in sent) or len(sent.split()) <= 5):
                    answer = sent
                    break

        # Remove trailing explanations in parentheses
        answer = re.sub(r'\s*\([^)]*\)\s*$', '', answer)

        # If the question asks for a comma-separated list, ensure no spaces after commas
        lowered_question = question.lower()
        if 'comma' in lowered_question and ('list' in lowered_question or 'separated' in lowered_question):
            answer = re.sub(r',\s+', ',', answer)

        # Clean numbers: remove currency symbols and commas
        if len(answer.split()) <= 5 and any(ch.isdigit() for ch in answer):
            candidate = answer.translate(self._NUMBER_NOISE).strip()
            try:
                float(candidate)
            except ValueError:
                pass  # Not a pure number, keep original
            else:
                answer = candidate

        # Final cleanup: remove quotes if they wrap the entire answer
        return answer.strip('"\'')

    def __call__(self, question: str) -> str:
        """Invoke the agent with a question and return the answer.

        Never raises: any exception is reported on stdout and turned into an
        ``"Agent failed with error: ..."`` string. Only real (non-empty)
        answers are written to the cache, so failed questions are retried on
        the next run.
        """
        try:
            print("\n" + "="*80)
            print(f"📋 QUESTION: {question[:150]}...")
            print("="*80)

            # Check cache first
            if self.use_cache and question in self.answer_cache:
                cached_answer = self.answer_cache[question]
                print("\n💾 Using cached answer (no LLM call!)")
                print(f"\n🎯 FINAL ANSWER: {cached_answer}")
                print("="*80 + "\n")
                return cached_answer

            # Create the initial state with the user's question
            state: AgentState = {"messages": [HumanMessage(content=question)]}

            # Run the graph with increased recursion limit
            print("\n🚀 Starting agent execution...")
            result = self.graph.invoke(state, config={"recursion_limit": 50})

            # Extract the final answer from the last AI message
            for message in reversed(result["messages"]):
                if isinstance(message, AIMessage):
                    raw_answer = self._content_to_text(message.content)
                    # Clean the answer based on GAIA scoring rules
                    cleaned_answer = self._clean_answer(raw_answer, question)
                    # Ensure answer is never empty
                    validated_answer = ensure_valid_answer(cleaned_answer)

                    # Cache and persist immediately, but never cache the
                    # placeholder produced for an empty answer.
                    if self.use_cache and cleaned_answer.strip():
                        self.answer_cache[question] = validated_answer
                        self._save_cache()

                    print(f"\n🎯 FINAL ANSWER: {validated_answer}")
                    print("="*80 + "\n")
                    return validated_answer

            print("\n⚠️ No answer found")
            print("="*80 + "\n")
            return ensure_valid_answer("")
        except Exception as e:
            # Top-level boundary: callers expect a string answer, not a raise.
            print(f"\n❌ ERROR: {e}")
            print("="*80 + "\n")
            return ensure_valid_answer(f"Agent failed with error: {e}")
|