import getpass import os import time from typing import Annotated, Optional from typing import TypedDict from dotenv import load_dotenv from langchain_core.tools.retriever import create_retriever_tool from langchain_chroma import Chroma from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage from langchain_google_genai import ChatGoogleGenerativeAI from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFaceEndpoint from langchain_openai import ChatOpenAI from langgraph.graph import START, StateGraph from langgraph.graph.message import add_messages from langgraph.prebuilt import ToolNode, tools_condition from tools import get_tools load_dotenv() MAX_AGENT_INVOKE_RETRIES = 3 INITIAL_AGENT_RETRY_BACKOFF = 1.0 INFERENCE_MODE = "hugging-face" # Change to "hugging-face" or "open-ai" to use those providers instead class AgentState(TypedDict): messages: Annotated[list[AnyMessage], add_messages] class BasicAgent: def __init__(self): self.sys_msg = self.get_system_prompt() self.llm = self.get_llm() self.tools = self._load_tools() self.chat_with_tools = self.llm.bind_tools(self.tools) self._graph = self._build_graph() print("BasicAgent initialized.") def _load_tools(self): """Return tool list, appending a ChromaDB retriever tool if available.""" tools = get_tools() try: embeddings = HuggingFaceEmbeddings( model_name="sentence-transformers/all-mpnet-base-v2" ) vector_store = Chroma( collection_name="gaia_questions", embedding_function=embeddings, persist_directory="./chroma_db", ) retriever_tool = create_retriever_tool( retriever=vector_store.as_retriever(), name="question_search", description=( "Search for similar past questions. Returns solved examples with the answer " "and which tools/strategies were used — useful for picking the right approach." ), ) tools.append(retriever_tool) except Exception as e: print(f"Warning: could not initialise ChromaDB retriever: {e}") return tools def get_system_prompt(self): prompt_path = os.path.join(os.path.dirname(__file__), "system_prompt.md") with open(prompt_path, "r", encoding="utf-8") as f: system_prompt = f.read() return SystemMessage(content=system_prompt) def get_llm(self): global INFERENCE_MODE supported_modes = ["google", "hugging-face", "open-ai"] match INFERENCE_MODE.lower(): case "google": model = "gemini-2.0-flash" if "GOOGLE_API_KEY" not in os.environ: os.environ["GOOGLE_API_KEY"] = getpass.getpass( "Please enter your Google AI API key: " ) return ChatGoogleGenerativeAI(model=model, temperature=0) case "hugging-face": repo_id = "Qwen/Qwen3-Coder-30B-A3B-Instruct" return ChatHuggingFace( llm=HuggingFaceEndpoint( repo_id=repo_id, task="text-generation", temperature=0.01, # HF serverless doesn't support temperature=0 ), verbose=True, ) case "open-ai": model = "gpt-4o-mini" if "OPENAI_API_KEY" not in os.environ: os.environ["OPENAI_API_KEY"] = getpass.getpass( "Please enter your OPEN AI API key: " ) return ChatOpenAI(model=model, temperature=0) case _: raise ValueError( f"Invalid inference mode: {INFERENCE_MODE}. " f"Please choose from supported modes: {', '.join(supported_modes)}" ) def assistant(self, state: AgentState): return { "messages": [self.chat_with_tools.invoke([self.sys_msg] + state["messages"])] } def _build_graph(self): builder = StateGraph(AgentState) builder.add_node("assistant", self.assistant) builder.add_node("tools", ToolNode(self.tools)) builder.add_edge(START, "assistant") builder.add_conditional_edges("assistant", tools_condition) builder.add_edge("tools", "assistant") return builder.compile() @property def graph(self): return self._graph def __call__( self, question: str, file_url: Optional[str] = None, file_name: Optional[str] = None, ) -> str: if file_url: file_ext = os.path.splitext(file_name)[1].lower() local_file_path = f"./files/{file_name}" prompt = ( f"{question}\n\n" f"Attached file url:\n{file_url}\n\n" f"Attached file extension:\n{file_ext}\n\n" f"If file doesn't exist at {file_url}, you can access the file locally at {local_file_path}." ) else: prompt = question messages = [HumanMessage(content=prompt)] response = self.invoke_agent_with_retries(messages) for m in response["messages"]: if len(m.content) < 1000: m.pretty_print() else: m.content = m.content[:500] + "..." + m.content[-500:] m.pretty_print() answer = response["messages"][-1].content if "FINAL ANSWER: " in answer: return answer.split("FINAL ANSWER: ")[1] return answer def invoke_agent_with_retries(self, messages: list[AnyMessage]): backoff = INITIAL_AGENT_RETRY_BACKOFF for attempt in range(1, MAX_AGENT_INVOKE_RETRIES + 1): try: return self.graph.invoke({"messages": messages}) except Exception as exc: if attempt == MAX_AGENT_INVOKE_RETRIES: print(f"Agent invocation failed after {attempt} attempts: {exc}") raise print( f"Agent invocation attempt {attempt} failed ({exc}); " f"retrying in {backoff:.1f}s..." ) time.sleep(backoff) backoff *= 2 # Stable runtime graph for LangSmith traceability __all__ = ["BasicAgent", "get_agent", "get_graph"] _AGENT_SINGLETON: Optional[BasicAgent] = None def get_agent() -> BasicAgent: global _AGENT_SINGLETON if _AGENT_SINGLETON is None: _AGENT_SINGLETON = BasicAgent() return _AGENT_SINGLETON def get_graph(): return get_agent().graph