Spaces:
Sleeping
Sleeping
| import getpass | |
| import os | |
| import time | |
| from typing import Annotated, Optional | |
| from typing import TypedDict | |
| from dotenv import load_dotenv | |
| from langchain_core.tools.retriever import create_retriever_tool | |
| from langchain_chroma import Chroma | |
| from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFaceEndpoint | |
| from langchain_openai import ChatOpenAI | |
| from langgraph.graph import START, StateGraph | |
| from langgraph.graph.message import add_messages | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from tools import get_tools | |
load_dotenv()

# Retry policy used by BasicAgent.invoke_agent_with_retries: maximum number of
# attempts, and the initial delay in seconds, which doubles after each failure.
MAX_AGENT_INVOKE_RETRIES = 3
INITIAL_AGENT_RETRY_BACKOFF = 1.0

# LLM provider used by BasicAgent.get_llm: "google", "hugging-face" or "open-ai".
# Defaults to "hugging-face"; set the INFERENCE_MODE environment variable (or a
# .env entry, loaded above) to switch providers without editing this file.
INFERENCE_MODE = os.environ.get("INFERENCE_MODE", "hugging-face")
class AgentState(TypedDict):
    """State carried through the agent graph.

    ``messages`` holds the conversation. It is annotated with LangGraph's
    ``add_messages`` reducer, so message lists returned by the graph nodes are
    merged into it rather than replacing it.
    """

    messages: Annotated[list[AnyMessage], add_messages]
| class BasicAgent: | |
| def __init__(self): | |
| self.sys_msg = self.get_system_prompt() | |
| self.llm = self.get_llm() | |
| self.tools = self._load_tools() | |
| self.chat_with_tools = self.llm.bind_tools(self.tools) | |
| self._graph = self._build_graph() | |
| print("BasicAgent initialized.") | |
| def _load_tools(self): | |
| """Return tool list, appending a ChromaDB retriever tool if available.""" | |
| tools = get_tools() | |
| try: | |
| embeddings = HuggingFaceEmbeddings( | |
| model_name="sentence-transformers/all-mpnet-base-v2" | |
| ) | |
| vector_store = Chroma( | |
| collection_name="gaia_questions", | |
| embedding_function=embeddings, | |
| persist_directory="./chroma_db", | |
| ) | |
| retriever_tool = create_retriever_tool( | |
| retriever=vector_store.as_retriever(), | |
| name="question_search", | |
| description=( | |
| "Search for similar past questions. Returns solved examples with the answer " | |
| "and which tools/strategies were used — useful for picking the right approach." | |
| ), | |
| ) | |
| tools.append(retriever_tool) | |
| except Exception as e: | |
| print(f"Warning: could not initialise ChromaDB retriever: {e}") | |
| return tools | |
| def get_system_prompt(self): | |
| prompt_path = os.path.join(os.path.dirname(__file__), "system_prompt.md") | |
| with open(prompt_path, "r", encoding="utf-8") as f: | |
| system_prompt = f.read() | |
| return SystemMessage(content=system_prompt) | |
| def get_llm(self): | |
| global INFERENCE_MODE | |
| supported_modes = ["google", "hugging-face", "open-ai"] | |
| match INFERENCE_MODE.lower(): | |
| case "google": | |
| model = "gemini-2.0-flash" | |
| if "GOOGLE_API_KEY" not in os.environ: | |
| os.environ["GOOGLE_API_KEY"] = getpass.getpass( | |
| "Please enter your Google AI API key: " | |
| ) | |
| return ChatGoogleGenerativeAI(model=model, temperature=0) | |
| case "hugging-face": | |
| repo_id = "Qwen/Qwen3-Coder-30B-A3B-Instruct" | |
| return ChatHuggingFace( | |
| llm=HuggingFaceEndpoint( | |
| repo_id=repo_id, | |
| task="text-generation", | |
| temperature=0.01, # HF serverless doesn't support temperature=0 | |
| ), | |
| verbose=True, | |
| ) | |
| case "open-ai": | |
| model = "gpt-4o-mini" | |
| if "OPENAI_API_KEY" not in os.environ: | |
| os.environ["OPENAI_API_KEY"] = getpass.getpass( | |
| "Please enter your OPEN AI API key: " | |
| ) | |
| return ChatOpenAI(model=model, temperature=0) | |
| case _: | |
| raise ValueError( | |
| f"Invalid inference mode: {INFERENCE_MODE}. " | |
| f"Please choose from supported modes: {', '.join(supported_modes)}" | |
| ) | |
| def assistant(self, state: AgentState): | |
| return { | |
| "messages": [self.chat_with_tools.invoke([self.sys_msg] + state["messages"])] | |
| } | |
| def _build_graph(self): | |
| builder = StateGraph(AgentState) | |
| builder.add_node("assistant", self.assistant) | |
| builder.add_node("tools", ToolNode(self.tools)) | |
| builder.add_edge(START, "assistant") | |
| builder.add_conditional_edges("assistant", tools_condition) | |
| builder.add_edge("tools", "assistant") | |
| return builder.compile() | |
| def graph(self): | |
| return self._graph | |
| def __call__( | |
| self, | |
| question: str, | |
| file_url: Optional[str] = None, | |
| file_name: Optional[str] = None, | |
| ) -> str: | |
| if file_url: | |
| file_ext = os.path.splitext(file_name)[1].lower() | |
| local_file_path = f"./files/{file_name}" | |
| prompt = ( | |
| f"{question}\n\n" | |
| f"Attached file url:\n{file_url}\n\n" | |
| f"Attached file extension:\n{file_ext}\n\n" | |
| f"If file doesn't exist at {file_url}, you can access the file locally at {local_file_path}." | |
| ) | |
| else: | |
| prompt = question | |
| messages = [HumanMessage(content=prompt)] | |
| response = self.invoke_agent_with_retries(messages) | |
| for m in response["messages"]: | |
| if len(m.content) < 1000: | |
| m.pretty_print() | |
| else: | |
| m.content = m.content[:500] + "..." + m.content[-500:] | |
| m.pretty_print() | |
| answer = response["messages"][-1].content | |
| if "FINAL ANSWER: " in answer: | |
| return answer.split("FINAL ANSWER: ")[1] | |
| return answer | |
| def invoke_agent_with_retries(self, messages: list[AnyMessage]): | |
| backoff = INITIAL_AGENT_RETRY_BACKOFF | |
| for attempt in range(1, MAX_AGENT_INVOKE_RETRIES + 1): | |
| try: | |
| return self.graph.invoke({"messages": messages}) | |
| except Exception as exc: | |
| if attempt == MAX_AGENT_INVOKE_RETRIES: | |
| print(f"Agent invocation failed after {attempt} attempts: {exc}") | |
| raise | |
| print( | |
| f"Agent invocation attempt {attempt} failed ({exc}); " | |
| f"retrying in {backoff:.1f}s..." | |
| ) | |
| time.sleep(backoff) | |
| backoff *= 2 | |
# Public API. The module-level accessors below hand out one shared agent (and
# its graph), giving tooling such as LangSmith a stable runtime graph.
__all__ = ["BasicAgent", "get_agent", "get_graph"]

# Shared agent instance; stays None until get_agent() first builds it.
_AGENT_SINGLETON: Optional[BasicAgent] = None
def get_agent() -> BasicAgent:
    """Return the shared BasicAgent, constructing it on the first call."""
    global _AGENT_SINGLETON
    if _AGENT_SINGLETON is not None:
        return _AGENT_SINGLETON
    _AGENT_SINGLETON = BasicAgent()
    return _AGENT_SINGLETON
def get_graph():
    """Return the ``graph`` attribute of the shared agent from get_agent()."""
    agent = get_agent()
    return agent.graph