File size: 7,284 Bytes
9bc6cdb
 
a437070
2debc20
9bc6cdb
e2a7e20
9bc6cdb
3c18169
2debc20
7db7bee
2debc20
 
1b356ff
 
 
 
e2a7e20
 
d4518ca
e2a7e20
a37949b
5404812
640e948
 
2debc20
9bc6cdb
 
 
7db7bee
b8c3c10
 
 
 
 
 
7db7bee
b8c3c10
 
 
7db7bee
 
 
 
b8c3c10
fa470ab
b8c3c10
 
 
 
fa470ab
b8c3c10
 
7db7bee
2debc20
 
9bc6cdb
a813a8a
 
 
 
 
 
 
8f493ce
 
da6caa3
 
e79a6e2
21444f7
e79a6e2
8b20388
e79a6e2
8f493ce
 
a813a8a
 
 
f75c084
 
 
 
 
 
 
 
 
 
 
 
9bc6cdb
f75c084
9bc6cdb
 
2debc20
6cb286c
9bc6cdb
 
2debc20
9bc6cdb
 
2debc20
9bc6cdb
2debc20
9bc6cdb
2debc20
9bc6cdb
 
6cb286c
0a3b4fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a502d76
 
 
 
 
 
 
 
 
 
0a3b4fb
 
 
 
 
 
 
 
b22532d
0a3b4fb
9bc6cdb
2debc20
 
7db7bee
 
 
 
 
 
 
 
 
 
 
 
c47f78f
9bc6cdb
 
c47f78f
 
 
 
 
 
 
 
 
9bc6cdb
ac0b6fe
 
 
dd22060
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage, AnyMessage, SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import START, StateGraph
from langchain_openai import ChatOpenAI
from tools import all_tools
import inspect
import os
import re

# 1. Setup once
# Fail fast at import time if the API key is absent, rather than on the first request.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("Missing OPENAI_API_KEY environment variable.")

# Deterministic (temperature=0) chat model shared by the whole agent graph.
chat = ChatOpenAI(
    model="gpt-3.5-turbo",
    openai_api_key=OPENAI_API_KEY,
    temperature=0,
)

# Same model with every tool from tools.all_tools registered for tool-calling.
chat_with_tools = chat.bind_tools(all_tools)

# 2. Define the agent state
class AgentState(TypedDict):
    """Graph state: the running conversation history."""
    # add_messages is a reducer: node outputs are appended to the list
    # instead of replacing it on each graph step.
    messages: Annotated[list[AnyMessage], add_messages]

def extract_gaia_answer(text: str) -> str:
    """
    Extract just the final answer from a model response, in raw lowercase form.

    Tries, in order:
      1. An explicit 'The answer is: ...' prefix.
      2. An explicit 'Answer: ...' prefix.
      3. A line made only of simple characters (letters, digits, spaces,
         commas, hyphens).
    Falls back to the first short non-empty line, then to the whole text.

    Surrounding markdown emphasis/quotes (e.g. '**right**') and a trailing
    sentence period are stripped, since GAIA scoring is an exact string
    match and the model is prompted to answer in bold.

    Args:
        text: Full assistant response, possibly multi-line.

    Returns:
        The extracted answer, stripped and lowercased.
    """
    def _clean(answer: str) -> str:
        # Drop markdown emphasis/backticks/quotes wrapping the answer, then a
        # trailing sentence period — both would break exact-match comparison.
        answer = answer.strip().strip("*_`\"'").strip()
        return answer.rstrip(".").strip().lower()

    patterns = [
        r"The answer is:\s*(.+)",
        r"Answer:\s*(.+)",
        r"^([a-z0-9\s,\-]+)$",  # simple raw line (numbers, text)
    ]
    for pattern in patterns:
        match = re.search(pattern, text.strip(), re.IGNORECASE | re.MULTILINE)
        if match:
            return _clean(match.group(1))

    # Fallback: first non-empty line, if short enough to plausibly be the answer.
    lines = [l.strip() for l in text.strip().splitlines() if l.strip()]
    if lines and len(lines[0]) < 80:
        return _clean(lines[0])

    # Final fallback: the whole text, cleaned and lowercased.
    return _clean(text)

# 3. Assistant node

def assistant(state: AgentState):
    """
    Single LLM step of the agent graph.

    Builds a system prompt that embeds a description of every available tool,
    prepends it to the conversation so far, and invokes the tool-enabled model.

    Args:
        state: Current graph state; state["messages"] is the running history.

    Returns:
        Partial state update: {"messages": [<AIMessage from the model>]},
        merged into the state by the add_messages reducer.
    """
    # One "name(signature):\n    description" entry per tool, so the model
    # knows each tool's calling convention.
    tool_descriptions = "\n".join([
        f"{tool.name}{inspect.signature(tool.func)}:\n    {tool.description.strip()}"
        for tool in all_tools
    ])

    sys_msg = SystemMessage(
        content=(
            "You are a helpful AI assistant who solves GAIA benchmark questions using step-by-step reasoning.\n"
            "Before answering, always think out loud and plan your approach.\n"
            "Use tools when you lack information or need external data. Only use audio or transcription tools if the user clearly provides or references an audio file.\n"
            "Do not assume the existence of files or media unless they are explicitly mentioned. Do not call tools like transcription unless an actual file or media reference is present.\n"
            "After every tool call, always analyze the result and continue reasoning to arrive at a final answer.\n"
            # BUG FIX: this line was missing its trailing "\n", so in the
            # concatenated prompt it ran straight into the next instruction.
            "If the question is unclear, incomplete, or missing context, respond with:  **'The question is incomplete โ€” please provide more information.'**\n"
            "Never treat tool outputs as final โ€” interpret them and continue solving the task step-by-step.\n"
            "When the question specifies an answer format (e.g., a number, list, or code), respond **only with the final answer** in the required format. Do not add explanations, quotes, or set notation. Output exactly what is requested.\n"
            "Finish with a clear and concise answer, such as 'The answer is: right'.\n"
            "\nAvailable tools:\n"
            f"{tool_descriptions}"
        )
    )

    input_msgs = [sys_msg] + state["messages"]
    # Debug trace of what the model receives (content truncated to 200 chars).
    print("\n๐Ÿง  Assistant received messages:")
    for msg in input_msgs:
        print(f"๐Ÿ”น {msg.__class__.__name__}: {getattr(msg, 'content', '')[:200]}")

    output = chat_with_tools.invoke(input_msgs)

    print("\n๐Ÿ—ฃ๏ธ Assistant response:")
    print("-" * 40)
    print(getattr(output, 'content', '')[:500])
    print("-" * 40)

    return {
        "messages": [output],
    }

# 4. Build the agent graph

def build_graph(max_steps: int = 5):
    """
    Compile the assistant/tools LangGraph and wrap it in a step-limited driver.

    Args:
        max_steps: Default cap on graph invocations per question. Becomes the
            default of the returned callable's own ``max_steps`` parameter.

    Returns:
        ``limited_invoke(state, max_steps=..., max_reasoning_steps_after_tool=...)``:
        repeatedly invokes the compiled graph and returns the final state.
    """
    builder = StateGraph(AgentState)

    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(all_tools))

    builder.add_edge(START, "assistant")
    # tools_condition routes to "tools" when the last AIMessage carries
    # tool calls, otherwise to END.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    graph = builder.compile()

    # BUG FIX: the default used to be the literal 5, shadowing build_graph's
    # max_steps argument — BasicAgent(max_steps=N) was silently ignored.
    # Binding the default to the enclosing argument (evaluated at def time)
    # honors the caller while keeping the signature backward-compatible.
    def limited_invoke(state, max_steps: int = max_steps, max_reasoning_steps_after_tool: int = 2):
        steps = 0
        reasoning_steps_since_last_tool = 0

        while steps < max_steps:
            print(f"\U0001f501 Step {steps + 1}")
            state = graph.invoke(state)

            # Debug trace of every assistant message accumulated so far.
            for msg in state["messages"]:
                if isinstance(msg, AIMessage):
                    print("\n๐Ÿค– Assistant says:")
                    print("-" * 40)
                    print(msg.content.strip())
                    print("-" * 40)

            latest_message = state["messages"][-1] if state["messages"] else None

            if isinstance(latest_message, AIMessage):
                if latest_message.tool_calls:
                    print("๐Ÿ”„ Tool call detected โ€” continuing loop.")
                    reasoning_steps_since_last_tool = 0  # reset counter
                else:
                    reasoning_steps_since_last_tool += 1
                    print(f"๐Ÿง  No tool call โ€” reasoning step #{reasoning_steps_since_last_tool}")

                    # ๐Ÿ› ๏ธ Handle reverse_sentence manually: re-feed the last tool
                    # output as a new human message so the model answers it.
                    if "reverse_sentence" in latest_message.content.lower():
                        tool_outputs = [msg for msg in state["messages"] if msg.type == "tool"]
                        if tool_outputs:
                            reversed_text = tool_outputs[-1].content.strip()
                            print(f"๐Ÿ” Re-feeding reversed message:\n{reversed_text}")
                            state["messages"].append(HumanMessage(content=reversed_text))
                            # BUG FIX: count this re-feed as a step; the old
                            # `continue` skipped `steps += 1`, so this branch
                            # could loop forever.
                            steps += 1
                            continue  # loop again with new input

                    if reasoning_steps_since_last_tool >= max_reasoning_steps_after_tool:
                        print("โœ… Final answer assumed after sufficient reasoning.")
                        break

            steps += 1

        return state

    return limited_invoke

# 5. BasicAgent class

# class BasicAgent:
#     def __init__(self, max_steps: int = 5):
#         self.graph = build_graph(max_steps)

#     def __call__(self, question: str) -> str:
#         response = self.graph({"messages": [HumanMessage(content=question)]})
#         if response.get("messages"):
#             final_message = response["messages"][-1]
#             return final_message.content if hasattr(final_message, "content") else "No final message."
#         else:
#             return "No response."

class BasicAgent:
    """Thin wrapper: GAIA question string in, cleaned answer string out."""

    def __init__(self, max_steps: int = 5):
        # build_graph returns a step-limited callable that drives the agent loop.
        self.graph = build_graph(max_steps)

    def __call__(self, question: str) -> str:
        initial_state = {"messages": [HumanMessage(content=question)]}
        result = self.graph(initial_state)

        messages = result.get("messages")
        if not messages:
            return "No response."

        raw_content = getattr(messages[-1], "content", "No final message.")
        return extract_gaia_answer(raw_content)


if __name__ == "__main__":
    # Manual smoke test; requires OPENAI_API_KEY and a working network connection.
    agent = BasicAgent()
    print(agent("What is the capital of France?"))