# Gaia-agent / agent.py
# NZ
# add agent
# 8fc14db
import getpass
import os
import time
from typing import Annotated, Optional
from typing import TypedDict
from dotenv import load_dotenv
from langchain_core.tools.retriever import create_retriever_tool
from langchain_chroma import Chroma
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from tools import get_tools
# Populate os.environ from a .env file before any provider API key is looked up.
load_dotenv()

# Total number of attempts made per agent invocation before the error is re-raised.
MAX_AGENT_INVOKE_RETRIES = 3
# Sleep (in seconds) before the first retry; doubled after every failed attempt.
INITIAL_AGENT_RETRY_BACKOFF = 1.0
# Provider selected by BasicAgent.get_llm: "google", "hugging-face" or "open-ai".
INFERENCE_MODE = "hugging-face"  # Change to "google" or "open-ai" to use those providers instead
class AgentState(TypedDict):
    """State schema passed to ``StateGraph`` when the agent graph is built."""

    # Conversation so far. ``add_messages`` is attached as annotation metadata;
    # presumably LangGraph uses it as the reducer that merges the messages a
    # node returns into this list - confirm against the LangGraph docs.
    messages: Annotated[list[AnyMessage], add_messages]
class BasicAgent:
def __init__(self):
self.sys_msg = self.get_system_prompt()
self.llm = self.get_llm()
self.tools = self._load_tools()
self.chat_with_tools = self.llm.bind_tools(self.tools)
self._graph = self._build_graph()
print("BasicAgent initialized.")
def _load_tools(self):
"""Return tool list, appending a ChromaDB retriever tool if available."""
tools = get_tools()
try:
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-mpnet-base-v2"
)
vector_store = Chroma(
collection_name="gaia_questions",
embedding_function=embeddings,
persist_directory="./chroma_db",
)
retriever_tool = create_retriever_tool(
retriever=vector_store.as_retriever(),
name="question_search",
description=(
"Search for similar past questions. Returns solved examples with the answer "
"and which tools/strategies were used — useful for picking the right approach."
),
)
tools.append(retriever_tool)
except Exception as e:
print(f"Warning: could not initialise ChromaDB retriever: {e}")
return tools
def get_system_prompt(self):
prompt_path = os.path.join(os.path.dirname(__file__), "system_prompt.md")
with open(prompt_path, "r", encoding="utf-8") as f:
system_prompt = f.read()
return SystemMessage(content=system_prompt)
def get_llm(self):
global INFERENCE_MODE
supported_modes = ["google", "hugging-face", "open-ai"]
match INFERENCE_MODE.lower():
case "google":
model = "gemini-2.0-flash"
if "GOOGLE_API_KEY" not in os.environ:
os.environ["GOOGLE_API_KEY"] = getpass.getpass(
"Please enter your Google AI API key: "
)
return ChatGoogleGenerativeAI(model=model, temperature=0)
case "hugging-face":
repo_id = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
return ChatHuggingFace(
llm=HuggingFaceEndpoint(
repo_id=repo_id,
task="text-generation",
temperature=0.01, # HF serverless doesn't support temperature=0
),
verbose=True,
)
case "open-ai":
model = "gpt-4o-mini"
if "OPENAI_API_KEY" not in os.environ:
os.environ["OPENAI_API_KEY"] = getpass.getpass(
"Please enter your OPEN AI API key: "
)
return ChatOpenAI(model=model, temperature=0)
case _:
raise ValueError(
f"Invalid inference mode: {INFERENCE_MODE}. "
f"Please choose from supported modes: {', '.join(supported_modes)}"
)
def assistant(self, state: AgentState):
return {
"messages": [self.chat_with_tools.invoke([self.sys_msg] + state["messages"])]
}
def _build_graph(self):
builder = StateGraph(AgentState)
builder.add_node("assistant", self.assistant)
builder.add_node("tools", ToolNode(self.tools))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
return builder.compile()
@property
def graph(self):
return self._graph
def __call__(
self,
question: str,
file_url: Optional[str] = None,
file_name: Optional[str] = None,
) -> str:
if file_url:
file_ext = os.path.splitext(file_name)[1].lower()
local_file_path = f"./files/{file_name}"
prompt = (
f"{question}\n\n"
f"Attached file url:\n{file_url}\n\n"
f"Attached file extension:\n{file_ext}\n\n"
f"If file doesn't exist at {file_url}, you can access the file locally at {local_file_path}."
)
else:
prompt = question
messages = [HumanMessage(content=prompt)]
response = self.invoke_agent_with_retries(messages)
for m in response["messages"]:
if len(m.content) < 1000:
m.pretty_print()
else:
m.content = m.content[:500] + "..." + m.content[-500:]
m.pretty_print()
answer = response["messages"][-1].content
if "FINAL ANSWER: " in answer:
return answer.split("FINAL ANSWER: ")[1]
return answer
def invoke_agent_with_retries(self, messages: list[AnyMessage]):
backoff = INITIAL_AGENT_RETRY_BACKOFF
for attempt in range(1, MAX_AGENT_INVOKE_RETRIES + 1):
try:
return self.graph.invoke({"messages": messages})
except Exception as exc:
if attempt == MAX_AGENT_INVOKE_RETRIES:
print(f"Agent invocation failed after {attempt} attempts: {exc}")
raise
print(
f"Agent invocation attempt {attempt} failed ({exc}); "
f"retrying in {backoff:.1f}s..."
)
time.sleep(backoff)
backoff *= 2
# Stable runtime graph for LangSmith traceability
# Names exported by a star-import of this module.
__all__ = ["BasicAgent", "get_agent", "get_graph"]
# Shared agent instance; stays None until get_agent() creates it on first use.
_AGENT_SINGLETON: Optional[BasicAgent] = None
def get_agent() -> BasicAgent:
    """Return the shared ``BasicAgent``, constructing it on the first call."""
    global _AGENT_SINGLETON
    agent = _AGENT_SINGLETON
    if agent is not None:
        return agent
    # First use: build the agent once and remember it for later callers.
    agent = BasicAgent()
    _AGENT_SINGLETON = agent
    return agent
def get_graph():
    """Return the compiled graph of the shared agent."""
    shared_agent = get_agent()
    return shared_agent.graph