Spaces:
Configuration error
Configuration error
| import re | |
| import logging | |
| from datetime import datetime | |
| from agents.supervisor import build_supervisor_graph | |
# Module logger: one line per question is appended to gaia_agent.log
# (relative to the current working directory) as "HH:MM:SS | message".
logger = logging.getLogger("gaia_agent")
# utf-8 so questions/answers containing non-ASCII text can always be written
# regardless of the platform default encoding; delay=True so the file is only
# opened when the first record is emitted (and a spare handler built on a
# re-import never opens it at all).
_log_handler = logging.FileHandler(
    "gaia_agent.log", mode="a", encoding="utf-8", delay=True
)
_log_handler.setFormatter(
    logging.Formatter("%(asctime)s | %(message)s", datefmt="%H:%M:%S")
)
# getLogger() returns the same object if this module is executed again
# (e.g. importlib.reload), so attach the file handler only once; otherwise
# each record would be written to the file once per import.
_existing_handler = next(
    (
        handler
        for handler in logger.handlers
        if isinstance(handler, logging.FileHandler)
        and handler.baseFilename == _log_handler.baseFilename
    ),
    None,
)
if _existing_handler is None:
    logger.addHandler(_log_handler)
else:
    _log_handler = _existing_handler
del _existing_handler
logger.setLevel(logging.INFO)
| INTERNAL_ROUTING_PATTERNS = re.compile( | |
| r"^transfer_to_\w+$|^handoff_to_\w+$|^route_to_\w+$", re.IGNORECASE | |
| ) | |
| def _extract_answer(text: str) -> str: | |
| if not text: | |
| return "" | |
| if INTERNAL_ROUTING_PATTERNS.match(text.strip()): | |
| return "" | |
| # Look for "FINAL ANSWER: ..." pattern anywhere in the text | |
| fa_match = re.search(r"(?i)FINAL\s*ANSWER\s*:\s*(.+)", text) | |
| if fa_match: | |
| return fa_match.group(1).strip() | |
| # Fallback: strip common prefixes from the last non-empty line | |
| prefixes_to_strip = [ | |
| r"(?i)^the\s+answer\s+is\s*:\s*", | |
| r"(?i)^answer\s*:\s*", | |
| ] | |
| cleaned = text.strip() | |
| for pattern in prefixes_to_strip: | |
| cleaned = re.sub(pattern, "", cleaned).strip() | |
| lines = cleaned.strip().split("\n") | |
| if lines: | |
| last_non_empty = "" | |
| for line in reversed(lines): | |
| stripped = line.strip() | |
| if stripped and not INTERNAL_ROUTING_PATTERNS.match(stripped): | |
| last_non_empty = stripped | |
| break | |
| for pattern in prefixes_to_strip: | |
| last_non_empty = re.sub(pattern, "", last_non_empty).strip() | |
| if last_non_empty: | |
| cleaned = last_non_empty | |
| return cleaned.strip() | |
| def _extract_trace(messages) -> tuple[list[str], list[str]]: | |
| """Walk the message list and collect which agents and tools were invoked.""" | |
| agents_used = [] | |
| tools_used = [] | |
| for msg in messages: | |
| msg_type = type(msg).__name__ | |
| name = getattr(msg, "name", None) | |
| if msg_type == "AIMessage" and name and name != "supervisor": | |
| if name not in agents_used: | |
| agents_used.append(name) | |
| if msg_type == "ToolMessage" and name: | |
| if name not in tools_used: | |
| tools_used.append(name) | |
| return agents_used, tools_used | |
class GAIAAgent:
    """Callable that answers one GAIA question per call via the supervisor graph.

    The graph is built once in ``__init__`` by ``build_supervisor_graph()``.
    Each call sends a single user message and reduces the graph's final
    message to an answer string with ``_extract_answer``.
    """

    def __init__(self, recursion_limit: int = 50):
        """Build the supervisor graph.

        Args:
            recursion_limit: Value passed as ``recursion_limit`` in the config
                of every ``graph.invoke`` call (default 50, as before).
        """
        print("Initializing GAIAAgent with multi-agent supervisor...")
        self.graph = build_supervisor_graph()
        self.recursion_limit = recursion_limit
        logger.info("--- Session started ---")
        print("GAIAAgent initialized successfully.")

    @staticmethod
    def _build_prompt(question: str, task_id: str | None, file_name: str) -> str:
        """Append task/file hints that tell the agents how to fetch attachments."""
        if file_name and task_id:
            return (
                f"{question}\n\n"
                f"[IMPORTANT CONTEXT: This question has an associated file named '{file_name}'. "
                f"You MUST use the download_gaia_file tool with task_id='{task_id}' and "
                f"file_name='{file_name}' to download and process this file before answering.]"
            )
        if task_id:
            return (
                f"{question}\n\n"
                f"[Context: Task ID is '{task_id}'. If you need to download an associated file, "
                f"use the download_gaia_file tool with this task_id.]"
            )
        return question

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten message content to a single string.

        Chat-model message content may be a plain string or a list of content
        blocks (strings, or dicts such as ``{"type": "text", "text": ...}``).
        Text parts are joined with newlines; non-text blocks (e.g. tool-use
        blocks) are dropped. ``None`` becomes ``""``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, str):
                    parts.append(item)
                elif isinstance(item, dict) and item.get("type") == "text":
                    parts.append(str(item.get("text", "")))
            return "\n".join(parts)
        return "" if content is None else str(content)

    def __call__(self, question: str, task_id: str | None = None, file_name: str = "") -> str:
        """Run the graph on *question* and return the extracted answer.

        Args:
            question: The question text.
            task_id: Task identifier, forwarded in the prompt so agents can
                call the ``download_gaia_file`` tool.
            file_name: Name of the associated file, or ``""`` if none.

        Returns:
            The extracted answer, or ``"Error: <exception>"`` if the graph run
            or answer extraction raised. Exceptions are not propagated.
        """
        print(f"\n{'='*60}")
        print(f"Question (first 100 chars): {question[:100]}...")
        print(f"Task ID: {task_id}")
        has_file = bool(file_name)
        print(f"Associated file: {'yes (' + file_name + ')' if has_file else 'no'}")

        prompt = self._build_prompt(question, task_id, file_name)
        messages = [{"role": "user", "content": prompt}]
        try:
            result = self.graph.invoke(
                {"messages": messages},
                config={"recursion_limit": self.recursion_limit},
            )
            response_messages = result.get("messages", [])
            agents_used, tools_used = _extract_trace(response_messages)
            if response_messages:
                final_msg = response_messages[-1]
                # Content may be a list of content blocks, not a str; flatten
                # it so _extract_answer always receives text.
                raw_answer = (
                    self._content_to_text(final_msg.content)
                    if hasattr(final_msg, "content")
                    else str(final_msg)
                )
            else:
                raw_answer = str(result)
            answer = _extract_answer(raw_answer)
            logger.info(
                f"Q: {question[:80]}... | "
                f"file={'yes' if has_file else 'no'} | "
                f"agents: {', '.join(agents_used) or 'none'} | "
                f"tools: {', '.join(tools_used) or 'none'} | "
                f"answer: {answer[:80]}"
            )
            print(f"Agents used: {agents_used}")
            print(f"Tools used: {tools_used}")
            print(f"Final answer: {answer}")
            print(f"{'='*60}\n")
            return answer
        except Exception as e:
            # Deliberate top-level boundary: callers receive an error string
            # instead of an exception. Record the failure at ERROR level with
            # its traceback in the log file.
            print(f"Error running agent: {e}")
            logger.exception(f"Q: {question[:80]}... | ERROR: {e}")
            import traceback
            traceback.print_exc()
            return f"Error: {e}"