File size: 8,857 Bytes
fbec116
86b8466
fbec116
 
 
e04e3db
53f8f7c
 
fbec116
 
d4eadfe
e04e3db
 
fbec116
 
5e77d41
fbec116
 
86b8466
 
fbec116
 
 
53f8f7c
 
 
 
e04e3db
 
 
 
 
53f8f7c
fbec116
 
 
 
 
 
 
 
 
 
 
 
 
 
d4eadfe
d3c7a7f
fbec116
 
 
 
e04e3db
fbec116
e04e3db
fbec116
 
 
e04e3db
fbec116
 
53f8f7c
e04e3db
fbec116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4eadfe
fbec116
 
5e77d41
fbec116
 
e04e3db
 
fbec116
d4eadfe
e04e3db
 
d4eadfe
e04e3db
d4eadfe
 
e04e3db
 
 
 
 
d4eadfe
e04e3db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4eadfe
 
 
5e77d41
 
 
 
fbec116
5e77d41
 
 
 
 
fbec116
5e77d41
fbec116
e04e3db
fbec116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e04e3db
 
 
 
 
fbec116
 
 
 
 
 
 
 
 
d4eadfe
 
 
 
fbec116
 
 
 
 
 
d4eadfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbec116
 
 
e04e3db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbec116
e04e3db
 
 
 
 
 
 
 
 
 
 
 
 
 
53f8f7c
fbec116
 
d4eadfe
 
 
e04e3db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import os
import warnings
from typing import Annotated, TypedDict

from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.cache import SQLiteCache
from langchain_core.globals import set_llm_cache
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langgraph.graph.state import END, START, CompiledStateGraph, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import SecretStr

# Suppress UserWarnings attributed to the langchain_tavily module
# (presumably noisy notices from the Tavily search tool — confirm).
warnings.filterwarnings("ignore", category=UserWarning, module="langchain_tavily")

# Populate os.environ from a local .env file. This must run before the API
# keys are read in GaiaAgent.__init__.
load_dotenv()


# Cache LLM responses on disk in .langchain_cache.db so identical requests are
# served from the cache across runs. For a cache that lives only as long as
# the process, use instead:
#     from langchain_core.caches import InMemoryCache
#     set_llm_cache(InMemoryCache())
set_llm_cache(SQLiteCache(database_path=".langchain_cache.db"))

# RAG vector store of previously solved questions (queried by
# GaiaAgent._retriever_node), persisted under CHROMA_PATH and embedded with
# all-mpnet-base-v2. Both objects are created at import time.
CHROMA_PATH = "./chroma_gaia_db"
EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
VECTOR_STORE = Chroma(persist_directory=CHROMA_PATH, embedding_function=EMBEDDINGS)


class AgentState(TypedDict):
    """State passed between nodes in the graph."""

    # Conversation history. The add_messages reducer merges the messages each
    # node returns into this list instead of overwriting it.
    messages: Annotated[list, add_messages]


def load_system_prompt(path: str = "system_prompt.txt") -> SystemMessage:
    """Read the agent's system prompt from a text file.

    Args:
        path: File to read. A relative path is resolved against the current
            working directory. Defaults to ``"system_prompt.txt"``.

    Returns:
        A ``SystemMessage`` whose content is the file's full text.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    # Decode explicitly as UTF-8: without ``encoding`` open() uses the
    # locale's preferred encoding, which differs across platforms.
    with open(path, "r", encoding="utf-8") as f:
        system_prompt = f.read()
    return SystemMessage(content=system_prompt)


# Read once at import time. GaiaAgent._agent_node (and _retriever_node) put it
# at the front of the messages sent to the model.
SYSTEM_PROMPT: SystemMessage = load_system_prompt()


class GaiaAgent:
    """
    A LangGraph agent for Gaia questions.

    The compiled graph alternates between an LLM node ("agent") and a
    tool-execution node ("tools") until the LLM replies without requesting a
    tool call. ``await agent(question)`` returns the cleaned final answer.
    """

    def __init__(self, model: str, temperature: float):
        """Initialize the agent with a specific model.

        Args:
            model: Chat model name. Names starting with ``"glm"`` are sent to
                the Z.ai OpenAI-compatible endpoint using ``ZAI_API_KEY``;
                any other name uses the default OpenAI endpoint with
                ``OPENAI_API_KEY``.
            temperature: Sampling temperature forwarded to ``ChatOpenAI``.
        """
        import asyncio

        from tools import get_tools

        # get_tools() is a coroutine. asyncio.run() raises RuntimeError when an
        # event loop is already running in this thread, so the agent must be
        # constructed from synchronous code.
        self.tools = asyncio.run(get_tools())

        # A missing API key falls back to an empty string in both branches.
        if model.startswith("glm"):
            api_key = SecretStr(secret_value=os.getenv("ZAI_API_KEY") or "")
            api_base = "https://api.z.ai/api/coding/paas/v4/"
        else:
            api_key = SecretStr(secret_value=os.getenv("OPENAI_API_KEY") or "")
            api_base = None  # None keeps ChatOpenAI's default endpoint

        # bind_tools exposes self.tools to the model so it can emit tool calls.
        self.llm = ChatOpenAI(
            model=model, temperature=temperature, base_url=api_base, api_key=api_key
        ).bind_tools(self.tools)

        self.graph = self._build_graph()

        print(f"Initialized GaiaAgent with model: {model}, temperature: {temperature}")
        print(f"Available tools: {[tool.name for tool in self.tools]}")

    def _build_graph(self) -> CompiledStateGraph:
        """Build and compile the agent/tools loop.

        START -> "agent". After each agent turn, ``tools_condition`` routes to
        "tools" when the last message requests tool calls and to END
        otherwise. "tools" always routes back to "agent".

        Returns:
            The compiled graph, checkpointed with an in-memory ``MemorySaver``.
        """
        graph = StateGraph(AgentState)

        graph.add_node("agent", self._agent_node)
        graph.add_node("tools", ToolNode(self.tools))

        graph.add_edge(START, "agent")
        graph.add_conditional_edges("agent", tools_condition)
        graph.add_edge("tools", "agent")

        # __call__ uses a fresh thread_id per question, so checkpoints never
        # carry context from one question into the next.
        memory = MemorySaver()
        return graph.compile(checkpointer=memory)

    def _retriever_node(self, state: AgentState) -> AgentState:
        """Retrieve similar questions and inject solving strategy into the question.

        Looks up the single nearest document in VECTOR_STORE for the first
        message's content. If one is found, returns the system prompt plus a
        new human message that appends that document's "Steps to solve" text
        and its ``tools`` metadata to the original question. Otherwise
        returns the system prompt followed by the existing messages.

        NOTE(review): this node is not registered in ``_build_graph``. Because
        ``messages`` uses the ``add_messages`` reducer, whatever is returned
        here is merged into the existing history rather than replacing it —
        confirm the intended result before wiring it in. The first message's
        content is assumed to be a plain string.
        """
        original_question = state["messages"][0].content

        similar_docs = VECTOR_STORE.similarity_search(original_question, k=1)

        if similar_docs:
            doc = similar_docs[0]
            # Keep the text between "Steps to solve:" and "Tools needed:".
            # If a marker is missing, split() leaves that side of the text
            # untouched.
            steps = (
                doc.page_content.split("Steps to solve:")[-1]
                .split("Tools needed:")[0]
                .strip()
            )
            tools = doc.metadata.get("tools", "")

            # Build enhanced question with strategy
            enhanced_question = f"""{original_question}

---
Strategy (from similar solved question):
{steps}

Tools needed: {tools}

Follow a similar approach to solve the question above."""

            enhanced_msg = HumanMessage(content=enhanced_question)
            return {"messages": [SYSTEM_PROMPT, enhanced_msg]}

        return {"messages": [SYSTEM_PROMPT] + state["messages"]}

    def _tools_node(self, state: AgentState) -> AgentState:
        """Execute the requested tools and log a preview of each result.

        NOTE(review): not registered in ``_build_graph``, which uses a bare
        ``ToolNode`` for the "tools" node. A new ``ToolNode`` is built on
        every call.
        """
        tool_node = ToolNode(self.tools)
        result = tool_node.invoke(state)

        # Log the first 300 characters of each tool result.
        for msg in result.get("messages", []):
            content = getattr(msg, "content", str(msg))
            name = getattr(msg, "name", "unknown")
            print(f"  Tool result [{name}]: {content[:300]}...")

        return result

    async def __call__(self, question: str) -> str:
        """
        Run the agent on a given question and return the answer.

        Args:
            question (str): The input question to the agent

        Returns:
            str: The agent's cleaned answer, or ``"AGENT ERROR: <exception>"``
            if the graph run raised.
        """
        import uuid

        print(f"\n{'='*60}")
        print(f"Agent received question: {question[:100]}...")
        print(f"{'='*60}\n")

        initial_state = {
            "messages": [HumanMessage(content=question)],
        }

        # A fresh thread_id isolates this question's checkpoints;
        # recursion_limit caps the number of graph steps for one run.
        thread_id = str(uuid.uuid4())
        config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 50}

        try:
            final_state = await self.graph.ainvoke(initial_state, config)

            last_message = final_state["messages"][-1]

            answer = (
                str(last_message.content)
                if hasattr(last_message, "content")
                else str(last_message)
            )

            # Clean up answer - extract from tags if present
            answer = self._clean_answer(answer)

            print(f"Agent final response: {answer[:200]}...\n")

            return answer
        except Exception as e:
            # Top-level boundary: log the failure and return it as the answer
            # string instead of propagating it to the caller.
            print(f"Error during agent execution: {e}")
            return f"AGENT ERROR: {e}"

    def _clean_answer(self, answer: str) -> str:
        """Extract a clean final answer from the model's raw reply.

        Strategies are tried in order; the first match wins:

        1. Text inside ``<solution>...</solution>`` (may span lines).
        2. The rest of the line after a "FINAL ANSWER" marker,
           case-insensitive, in plain or markdown-bold form:
           ``FINAL ANSWER: x``, ``**FINAL ANSWER:** x``,
           ``**FINAL ANSWER**: x`` or ``**FINAL ANSWER** x``.
           The earliest marker in the text is used.
        3. A trailing alphabetic comma-separated list after a colon, e.g.
           ``"...ingredients: cornstarch, sugar"`` -> ``"cornstarch, sugar"``.
        4. A last line under 500 characters that contains commas and whose
           comma-separated items are each under 100 characters.

        If nothing matches, the whole reply is returned stripped.

        Args:
            answer: Raw text of the final model message.

        Returns:
            The extracted answer with surrounding whitespace removed.
        """
        import re

        # Extract from <solution>...</solution>
        match = re.search(r"<solution>(.*?)</solution>", answer, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Extract from "FINAL ANSWER: ..." in plain or bold form, up to the end
        # of the line. A plain "FINAL ANSWER:" pattern on its own would also
        # match inside "**FINAL ANSWER:** x" and return the leftover "** x".
        # So the alternation consumes ":**", "**" (optionally followed by ":")
        # or a bare ":" immediately after the marker.
        match = re.search(
            r"FINAL ANSWER(?::\*\*|\*\*:?|:)\s*(.+?)(?:\n|$)", answer, re.IGNORECASE
        )
        if match:
            return match.group(1).strip()

        # If answer contains a colon followed by a list, extract just the list part
        # e.g., "...ingredients: cornstarch, sugar, ..."
        match = re.search(
            r":\s*\n?\s*([a-z][a-z\s,]+(?:,\s*[a-z][a-z\s]+)+)\s*$",
            answer,
            re.IGNORECASE,
        )
        if match:
            return match.group(1).strip()

        # Last resort: if there's a clear comma-separated list at the end, extract it
        lines = answer.strip().split("\n")
        last_line = lines[-1].strip()
        if "," in last_line and len(last_line) < 500:
            # Check if it looks like a list (multiple comma-separated items)
            items = [i.strip() for i in last_line.split(",")]
            if len(items) >= 2 and all(len(i) < 100 for i in items):
                return last_line

        return answer.strip()

    def _agent_node(self, state: AgentState) -> AgentState:
        """Call the tool-bound LLM on the conversation so far.

        Returns:
            ``{"messages": [response]}``, which the ``add_messages`` reducer
            appends to the history. ``tools_condition`` then decides whether
            the response's tool calls are executed.
        """
        messages = state["messages"]

        # Debug: show message count
        print(f"\n[AGENT] Message count: {len(messages)}")

        # Prepend the system prompt to the local list only. The state is not
        # modified, so the prompt is re-added on every turn unless an earlier
        # node already put a SystemMessage first.
        if not messages or not isinstance(messages[0], SystemMessage):
            messages = [SYSTEM_PROMPT] + messages

        # Print the full prompt/messages (each truncated to 500 characters)
        print("[AGENT] === MESSAGES ===")
        for i, msg in enumerate(messages):
            msg_type = type(msg).__name__
            content = (
                str(msg.content)[:500] if hasattr(msg, "content") else str(msg)[:500]
            )
            print(f"  [{i}] {msg_type}: {content}...")
        print("[AGENT] === END MESSAGES ===\n")

        response = self.llm.invoke(messages)

        # Log what the agent is doing
        if hasattr(response, "tool_calls") and response.tool_calls:
            print(
                f"[AGENT] Calling tools: {[tc['name'] for tc in response.tool_calls]}"
            )
        else:
            content = (
                str(response.content)[:200]
                if hasattr(response, "content")
                else str(response)[:200]
            )
            print(f"[AGENT] Final response: {content}...")

        return {"messages": [response]}


# Model for the module-level agent. Names starting with "glm" go to the Z.ai
# endpoint; any other name (e.g. "o3-mini", used previously) goes to OpenAI.
MODEL = "glm-4.7"

# Agent instance built at import time (loads tools, creates the LLM client and
# compiles the graph). NOTE(review): despite the class-style name this is an
# instance; presumably imported by this name elsewhere — confirm before
# renaming.
BasicAgent = GaiaAgent(model=MODEL, temperature=1.0)