File size: 4,973 Bytes
95bd81e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import re
import logging
from datetime import datetime
from agents.supervisor import build_supervisor_graph

# Dedicated logger for per-question trace lines, appended to gaia_agent.log
# in the current working directory.
logger = logging.getLogger("gaia_agent")
# Questions and answers are arbitrary Unicode text, so pin the log file to
# UTF-8 instead of relying on the platform's locale encoding (which can make
# the logging call fail on characters it cannot encode).
_log_handler = logging.FileHandler("gaia_agent.log", mode="a", encoding="utf-8")
_log_handler.setFormatter(logging.Formatter("%(asctime)s | %(message)s", datefmt="%H:%M:%S"))
logger.addHandler(_log_handler)
logger.setLevel(logging.INFO)

# Matches a string that consists of nothing but a routing name of the form
# transfer_to_<x>, handoff_to_<x> or route_to_<x> (case-insensitive).
INTERNAL_ROUTING_PATTERNS = re.compile(
    r"^transfer_to_\w+$|^handoff_to_\w+$|^route_to_\w+$", re.IGNORECASE
)


def _extract_answer(text: str) -> str:
    """Reduce a raw model reply to the bare answer string.

    Resolution order:

    1. Empty input, or input that is nothing but an internal routing name
       (``INTERNAL_ROUTING_PATTERNS``), yields ``""``.
    2. If the reply contains ``FINAL ANSWER:`` markers (case-insensitive),
       the remainder of the line after the *last* marker that has content
       is returned. Using the last marker keeps an echoed answer template
       earlier in the reply from shadowing the real answer.
    3. Otherwise the last non-empty, non-routing line is returned with a
       leading ``The answer is:`` / ``Answer:`` prefix removed. If no such
       line survives, the prefix-stripped reply as a whole is returned.

    Args:
        text: Raw reply text; may be empty.

    Returns:
        The extracted answer with surrounding whitespace removed.
    """
    if not text:
        return ""
    if INTERNAL_ROUTING_PATTERNS.match(text.strip()):
        return ""

    # Walk the markers back to front; a marker with nothing after it is
    # ignored so an earlier, populated marker (or the fallback) can win.
    markers = list(re.finditer(r"(?i)FINAL\s*ANSWER\s*:\s*", text))
    for marker in reversed(markers):
        candidate = text[marker.end():].split("\n", 1)[0].strip()
        if candidate:
            return candidate

    prefixes_to_strip = [
        r"(?i)^the\s+answer\s+is\s*:\s*",
        r"(?i)^answer\s*:\s*",
    ]

    def strip_prefixes(value: str) -> str:
        # Patterns are applied in order, so stacked prefixes are all removed.
        for pattern in prefixes_to_strip:
            value = re.sub(pattern, "", value).strip()
        return value

    cleaned = strip_prefixes(text.strip())

    # Fallback: the last line that is neither blank nor a routing name.
    last_line = next(
        (
            stripped
            for stripped in (line.strip() for line in reversed(cleaned.split("\n")))
            if stripped and not INTERNAL_ROUTING_PATTERNS.match(stripped)
        ),
        "",
    )
    last_line = strip_prefixes(last_line)

    return (last_line or cleaned).strip()


def _extract_trace(messages) -> tuple[list[str], list[str]]:
    """Summarise which agents and tools show up in a message history.

    Messages are classified by their class name: a named ``AIMessage`` whose
    name is not ``"supervisor"`` counts as an agent, a named ``ToolMessage``
    counts as a tool. Each name is reported once, in order of first
    appearance.

    Returns:
        ``(agent_names, tool_names)``.
    """
    agent_names: list[str] = []
    tool_names: list[str] = []

    for message in messages:
        label = getattr(message, "name", None)
        if not label:
            # Unnamed messages contribute to neither list.
            continue

        kind = type(message).__name__
        if kind == "AIMessage" and label != "supervisor":
            bucket = agent_names
        elif kind == "ToolMessage":
            bucket = tool_names
        else:
            continue

        if label not in bucket:
            bucket.append(label)

    return agent_names, tool_names


class GAIAAgent:
    def __init__(self):
        print("Initializing GAIAAgent with multi-agent supervisor...")
        self.graph = build_supervisor_graph()
        logger.info("--- Session started ---")
        print("GAIAAgent initialized successfully.")

    def __call__(self, question: str, task_id: str | None = None, file_name: str = "") -> str:
        print(f"\n{'='*60}")
        print(f"Question (first 100 chars): {question[:100]}...")
        print(f"Task ID: {task_id}")

        has_file = bool(file_name)
        print(f"Associated file: {'yes (' + file_name + ')' if has_file else 'no'}")

        prompt = question
        if has_file and task_id:
            prompt = (
                f"{question}\n\n"
                f"[IMPORTANT CONTEXT: This question has an associated file named '{file_name}'. "
                f"You MUST use the download_gaia_file tool with task_id='{task_id}' and "
                f"file_name='{file_name}' to download and process this file before answering.]"
            )
        elif task_id:
            prompt = (
                f"{question}\n\n"
                f"[Context: Task ID is '{task_id}'. If you need to download an associated file, "
                f"use the download_gaia_file tool with this task_id.]"
            )

        messages = [{"role": "user", "content": prompt}]

        try:
            result = self.graph.invoke(
                {"messages": messages},
                config={"recursion_limit": 50},
            )

            response_messages = result.get("messages", [])

            agents_used, tools_used = _extract_trace(response_messages)

            if response_messages:
                final_msg = response_messages[-1]
                raw_answer = (
                    final_msg.content
                    if hasattr(final_msg, "content")
                    else str(final_msg)
                )
            else:
                raw_answer = str(result)

            answer = _extract_answer(raw_answer)

            logger.info(
                f"Q: {question[:80]}... | "
                f"file={'yes' if has_file else 'no'} | "
                f"agents: {', '.join(agents_used) or 'none'} | "
                f"tools: {', '.join(tools_used) or 'none'} | "
                f"answer: {answer[:80]}"
            )

            print(f"Agents used: {agents_used}")
            print(f"Tools used: {tools_used}")
            print(f"Final answer: {answer}")
            print(f"{'='*60}\n")
            return answer

        except Exception as e:
            print(f"Error running agent: {e}")
            logger.info(f"Q: {question[:80]}... | ERROR: {e}")
            import traceback
            traceback.print_exc()
            return f"Error: {e}"