"""A LangGraph-based agent implementation.""" import re import sys import json from pathlib import Path from datetime import datetime from langchain_core.messages import AIMessage, HumanMessage from .graph import AgentState, build_agent_graph def ensure_valid_answer(answer: str) -> str: """Ensure answer is never None or empty.""" if not answer or not isinstance(answer, str) or answer.strip() == "": return "Unable to determine answer" return answer.strip() class TeeOutput: """Redirect stdout/stderr to both console and file.""" def __init__(self, file_path, mode='a'): self.file = open(file_path, mode, encoding='utf-8') self.terminal = sys.stdout if mode == 'a' else sys.stderr def write(self, message): self.terminal.write(message) self.file.write(message) self.file.flush() def flush(self): self.terminal.flush() self.file.flush() def close(self): self.file.close() class BasicAgent: """A LangGraph-powered agent that uses tools to answer questions.""" def __init__(self, log_to_file=True, use_cache=True, cache_file="agent_cache.json") -> None: """Initialize the agent with the compiled graph.""" self.graph = build_agent_graph() self.log_file = None self.use_cache = use_cache self.cache_file = Path(cache_file) self.answer_cache = {} # Cache for question -> answer mapping # Load cache from disk if it exists if self.use_cache: self._load_cache() # Set up logging to file if log_to_file: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") log_filename = f"agent_run_{timestamp}.log" self.log_file = TeeOutput(log_filename, 'w') sys.stdout = self.log_file print(f"📝 Logging to: {log_filename}\n") if self.use_cache and self.answer_cache: print(f"💾 Loaded {len(self.answer_cache)} cached answers from {self.cache_file}\n") def _load_cache(self): """Load answer cache from disk.""" try: if self.cache_file.exists(): with open(self.cache_file, 'r', encoding='utf-8') as f: self.answer_cache = json.load(f) except Exception as e: print(f"⚠️ Warning: Could not load cache from {self.cache_file}: {e}") self.answer_cache = {} def _save_cache(self): """Save answer cache to disk.""" try: with open(self.cache_file, 'w', encoding='utf-8') as f: json.dump(self.answer_cache, f, indent=2, ensure_ascii=False) except Exception as e: print(f"⚠️ Warning: Could not save cache to {self.cache_file}: {e}") def _clean_answer(self, answer: str, question: str) -> str: """ Clean the answer based on GAIA scoring rules. Aggressively removes explanatory text to provide only the literal answer. """ answer = answer.strip() # Remove JSON formatting and code blocks if answer.startswith('```'): # Extract content from code blocks lines = answer.split('\n') answer = '\n'.join([l for l in lines if not l.startswith('```')]) answer = answer.strip() # Remove JSON structures like {"name":"FINISH","answer":"value"} if answer.startswith('{') and ('"name"' in answer or '"FINISH"' in answer): try: import json # Try to parse as JSON parsed = json.loads(answer) # Extract the actual answer value from various possible keys for key in ['answer', 'arguments', 'vegetables', 'surname', 'value', 'result', 'submitted_answer']: if key in parsed and parsed[key] and parsed[key] != "FINISH": answer = str(parsed[key]) break # If still has "name" field, it's probably still JSON - extract any non-name value if isinstance(parsed, dict) and 'name' in parsed: for key, value in parsed.items(): if key != 'name' and key != 'FINISH' and value and value != "FINISH": answer = str(value) break except: pass # Remove common prefixes and explanatory phrases patterns_to_remove = [ r'^(the answer is|answer:|final answer:|thus,|therefore,|so,|hence,)\s*', r'^(the\s+)?(correct\s+)?(number|city|country|name|value|total|result)\s+(is|are|was|were)\s*', r'^\d+\.\s*', # Remove leading numbers like "1. " or "2. " r'^[-•]\s*', # Remove bullet points ] for pattern in patterns_to_remove: answer = re.sub(pattern, '', answer, flags=re.IGNORECASE) answer = answer.strip() # If answer contains multiple sentences, try to extract just the key info sentences = answer.split('.') if len(sentences) > 1: # Look for the shortest sentence that contains key info for sent in sentences: sent = sent.strip() # If it's short and contains a number or key word, use it if len(sent) < 50 and (any(char.isdigit() for char in sent) or len(sent.split()) <= 5): answer = sent break # Remove trailing explanations in parentheses answer = re.sub(r'\s*\([^)]*\)\s*$', '', answer) # If the question asks for a comma-separated list, ensure no spaces after commas if 'comma' in question.lower() and ('list' in question.lower() or 'separated' in question.lower()): answer = re.sub(r',\s+', ',', answer) # Clean numbers: remove currency symbols and commas if len(answer.split()) <= 5: # Short answer, likely a number if any(char.isdigit() for char in answer): cleaned = answer for symbol in ['$', '€', '£', '¥', '%', ',']: cleaned = cleaned.replace(symbol, '') # If after cleaning it's still a valid number, use the cleaned version try: float(cleaned.strip()) answer = cleaned.strip() except ValueError: pass # Not a pure number, keep original # Final cleanup: remove quotes if they wrap the entire answer answer = answer.strip('"\'') return answer def __call__(self, question: str) -> str: """Invoke the agent with a question and return the answer.""" try: print("\n" + "="*80) print(f"📋 QUESTION: {question[:150]}...") print("="*80) # Check cache first if self.use_cache and question in self.answer_cache: cached_answer = self.answer_cache[question] print("\n💾 Using cached answer (no LLM call!)") print(f"\n🎯 FINAL ANSWER: {cached_answer}") print("="*80 + "\n") return cached_answer # Create the initial state with the user's question state: AgentState = {"messages": [HumanMessage(content=question)]} # Run the graph with increased recursion limit print("\n🚀 Starting agent execution...") result = self.graph.invoke(state, config={"recursion_limit": 50}) # Extract the final answer from the last AI message for message in reversed(result["messages"]): if isinstance(message, AIMessage): raw_answer = message.content # Clean the answer based on GAIA scoring rules cleaned_answer = self._clean_answer(raw_answer, question) # Ensure answer is never empty validated_answer = ensure_valid_answer(cleaned_answer) # Cache the answer and save to disk if self.use_cache: self.answer_cache[question] = validated_answer self._save_cache() # Persist to disk immediately print(f"\n🎯 FINAL ANSWER: {validated_answer}") print("="*80 + "\n") return validated_answer print("\n⚠️ No answer found") print("="*80 + "\n") return ensure_valid_answer("") except Exception as e: print(f"\n❌ ERROR: {e}") print("="*80 + "\n") return ensure_valid_answer(f"Agent failed with error: {e}")