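"""GaiaAgent: a LangGraph-based agent for GAIA questions.

Builds a tool-calling agent graph around an OpenAI-compatible chat model,
caches LLM calls in a local SQLite database, and keeps a Chroma vector store
of previously solved questions for strategy retrieval.
"""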
import os
import warnings
from typing import Annotated, TypedDict
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.cache import SQLiteCache
from langchain_core.globals import set_llm_cache
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langgraph.graph.state import END, START, CompiledStateGraph, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import SecretStr
warnings.filterwarnings("ignore", category=UserWarning, module="langchain_tavily")
load_dotenv()
# from langchain_core.caches import InMemoryCache
# set_llm_cache(InMemoryCache())
set_llm_cache(SQLiteCache(database_path=".langchain_cache.db"))
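# SQLiteCache persists LLM responses in .langchain_cache.db, so identical
# prompts are served from the local cache instead of re-calling the API.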
# Initialize RAG vector store
CHROMA_PATH = "./chroma_gaia_db"
EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
VECTOR_STORE = Chroma(persist_directory=CHROMA_PATH, embedding_function=EMBEDDINGS)
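# The store is expected to contain previously solved questions whose page_content
# includes "Steps to solve:" and "Tools needed:" sections (see _retriever_node).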


class AgentState(TypedDict):
    """State passed between nodes in the graph."""

    messages: Annotated[list, add_messages]


def load_system_prompt() -> SystemMessage:
    """Read the system prompt from system_prompt.txt in the working directory."""
    with open("system_prompt.txt", "r") as f:
        system_prompt = f.read()
    return SystemMessage(content=system_prompt)


SYSTEM_PROMPT: SystemMessage = load_system_prompt()


class GaiaAgent:
    """
    A LangGraph agent for GAIA questions.
    """

    def __init__(self, model: str, temperature: float):
        """Initialize the agent with a specific model."""
        import asyncio
        from tools import get_tools

        self.tools = asyncio.run(get_tools())
        if model.startswith("glm"):
            # GLM models are served through Z.ai's OpenAI-compatible endpoint.
            api_key = SecretStr(secret_value=os.getenv("ZAI_API_KEY", ""))
            api_base = "https://api.z.ai/api/coding/paas/v4/"
        else:
            api_key = SecretStr(secret_value=os.getenv("OPENAI_API_KEY") or "")
            api_base = None
        self.llm = ChatOpenAI(
            model=model, temperature=temperature, base_url=api_base, api_key=api_key
        ).bind_tools(self.tools)
        self.graph = self._build_graph()
        print(f"Initialized GaiaAgent with model: {model}, temperature: {temperature}")
        print(f"Available tools: {[tool.name for tool in self.tools]}")

    def _build_graph(self) -> CompiledStateGraph:
        """Build the state graph for the agent."""
        graph = StateGraph(AgentState)
        graph.add_node("agent", self._agent_node)
        graph.add_node("tools", ToolNode(self.tools))
        graph.add_edge(START, "agent")
        # tools_condition routes to the "tools" node when the last LLM message
        # contains tool calls, otherwise to END.
        graph.add_conditional_edges("agent", tools_condition)
        graph.add_edge("tools", "agent")
        memory = MemorySaver()
        return graph.compile(checkpointer=memory)

    # NOTE: this node is not referenced by _build_graph; when used, it prepends a
    # solving strategy retrieved from similar previously solved questions.
    def _retriever_node(self, state: AgentState) -> AgentState:
        """Retrieve similar questions and inject solving strategy into the question."""
        original_question = state["messages"][0].content
        similar_docs = VECTOR_STORE.similarity_search(original_question, k=1)
        if similar_docs:
            doc = similar_docs[0]
            steps = (
                doc.page_content.split("Steps to solve:")[-1]
                .split("Tools needed:")[0]
                .strip()
            )
            tools = doc.metadata.get("tools", "")
            # Build enhanced question with strategy
            enhanced_question = f"""{original_question}
---
Strategy (from similar solved question):
{steps}
Tools needed: {tools}
Follow a similar approach to solve the question above."""
            enhanced_msg = HumanMessage(content=enhanced_question)
            return {"messages": [SYSTEM_PROMPT, enhanced_msg]}
        return {"messages": [SYSTEM_PROMPT] + state["messages"]}

    # NOTE: also not referenced by _build_graph (which uses ToolNode directly);
    # this variant additionally logs each tool result.
    def _tools_node(self, state: AgentState) -> AgentState:
        """Execute tools and log results."""
        tool_node = ToolNode(self.tools)
        result = tool_node.invoke(state)
        # Log tool results and check for answers
        for msg in result.get("messages", []):
            content = getattr(msg, "content", str(msg))
            name = getattr(msg, "name", "unknown")
            print(f" Tool result [{name}]: {content[:300]}...")
        return result

    async def __call__(self, question: str) -> str:
        """
        Run the agent on a given question and return the answer.

        Args:
            question (str): The input question to the agent

        Returns:
            str: The agent's answer to the question
        """
        print(f"\n{'='*60}")
        print(f"Agent received question: {question[:100]}...")
        print(f"{'='*60}\n")
        initial_state = {
            "messages": [HumanMessage(content=question)],
        }
        try:
            import uuid

            # A fresh thread id per call keeps MemorySaver state isolated between runs.
            thread_id = str(uuid.uuid4())
            config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 50}
            final_state = await self.graph.ainvoke(initial_state, config)
            last_message = final_state["messages"][-1]
            answer = (
                str(last_message.content)
                if hasattr(last_message, "content")
                else str(last_message)
            )
            # Clean up answer - extract from tags if present
            answer = self._clean_answer(answer)
            print(f"Agent final response: {answer[:200]}...\n")
            return answer
        except Exception as e:
            print(f"Error during agent execution: {e}")
            return f"AGENT ERROR: {e}"

    def _clean_answer(self, answer: str) -> str:
        """Extract clean answer from various formats."""
        import re

        # Extract from <solution>...</solution>
        match = re.search(r"<solution>(.*?)</solution>", answer, re.DOTALL)
        if match:
            return match.group(1).strip()
        # Extract from FINAL ANSWER: ... (to end of line or string)
        match = re.search(r"FINAL ANSWER:\s*(.+?)(?:\n|$)", answer, re.IGNORECASE)
        if match:
            return match.group(1).strip()
        # Extract from **FINAL ANSWER:** or similar markdown
        match = re.search(
            r"\*\*FINAL ANSWER:?\*\*:?\s*(.+?)(?:\n|$)", answer, re.IGNORECASE
        )
        if match:
            return match.group(1).strip()
        # If answer contains a colon followed by a list, extract just the list part
        # e.g., "...ingredients: cornstarch, sugar, ..."
        match = re.search(
            r":\s*\n?\s*([a-z][a-z\s,]+(?:,\s*[a-z][a-z\s]+)+)\s*$",
            answer,
            re.IGNORECASE,
        )
        if match:
            return match.group(1).strip()
        # Last resort: if there's a clear comma-separated list at the end, extract it
        lines = answer.strip().split("\n")
        last_line = lines[-1].strip()
        if "," in last_line and len(last_line) < 500:
            # Check if it looks like a list (multiple comma-separated items)
            items = [i.strip() for i in last_line.split(",")]
            if len(items) >= 2 and all(len(i) < 100 for i in items):
                return last_line
        return answer.strip()

    def _agent_node(self, state: AgentState) -> AgentState:
        """The main agent node that processes messages and generates responses."""
        messages = state["messages"]
        # Debug: show message count
        print(f"\n[AGENT] Message count: {len(messages)}")
        # Prepend system prompt if not already there
        if not messages or not isinstance(messages[0], SystemMessage):
            messages = [SYSTEM_PROMPT] + messages
        # Print the full prompt/messages
        print("[AGENT] === MESSAGES ===")
        for i, msg in enumerate(messages):
            msg_type = type(msg).__name__
            content = (
                str(msg.content)[:500] if hasattr(msg, "content") else str(msg)[:500]
            )
            print(f" [{i}] {msg_type}: {content}...")
        print("[AGENT] === END MESSAGES ===\n")
        response = self.llm.invoke(messages)
        # Log what the agent is doing
        if hasattr(response, "tool_calls") and response.tool_calls:
            print(
                f"[AGENT] Calling tools: {[tc['name'] for tc in response.tool_calls]}"
            )
        else:
            content = (
                str(response.content)[:200]
                if hasattr(response, "content")
                else str(response)[:200]
            )
            print(f"[AGENT] Final response: {content}...")
        return {"messages": [response]}


# model="o3-mini"
MODEL = "glm-4.7"
BasicAgent = GaiaAgent(model=MODEL, temperature=1.0)
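

# Minimal usage sketch: GaiaAgent.__call__ is async, so it must be awaited
# (e.g. with asyncio.run). The question below is only an illustrative placeholder.
if __name__ == "__main__":
    import asyncio

    example_question = "What is the capital of France?"  # hypothetical example input
    print(asyncio.run(BasicAgent(example_question)))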