Spaces:
Sleeping
Sleeping
File size: 6,887 Bytes
8fc14db | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 | import getpass
import os
import time
from typing import Annotated, Optional
from typing import TypedDict
from dotenv import load_dotenv
from langchain_core.tools.retriever import create_retriever_tool
from langchain_chroma import Chroma
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from tools import get_tools
# Populate os.environ from a local .env file (if one exists) so that settings and
# API keys such as GOOGLE_API_KEY / OPENAI_API_KEY are visible before the
# module-level configuration below is read and before get_llm() checks for them.
load_dotenv()
# Total number of attempts BasicAgent.invoke_agent_with_retries makes before re-raising.
MAX_AGENT_INVOKE_RETRIES = 3
# Delay in seconds after the first failed attempt; doubled after each further failure.
INITIAL_AGENT_RETRY_BACKOFF = 1.0
# LLM provider used by BasicAgent.get_llm: one of "google", "hugging-face" or "open-ai".
# Can be overridden with the INFERENCE_MODE environment variable (a .env entry works
# too, because load_dotenv() runs before this line); defaults to "hugging-face".
INFERENCE_MODE = os.getenv("INFERENCE_MODE", "hugging-face")
class AgentState(TypedDict):
    """Graph state passed between the ``assistant`` and ``tools`` nodes."""

    # Conversation history. The add_messages reducer merges the messages each node
    # returns into this list instead of replacing it, so nodes return only new messages.
    messages: Annotated[list[AnyMessage], add_messages]
class BasicAgent:
def __init__(self):
self.sys_msg = self.get_system_prompt()
self.llm = self.get_llm()
self.tools = self._load_tools()
self.chat_with_tools = self.llm.bind_tools(self.tools)
self._graph = self._build_graph()
print("BasicAgent initialized.")
def _load_tools(self):
"""Return tool list, appending a ChromaDB retriever tool if available."""
tools = get_tools()
try:
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-mpnet-base-v2"
)
vector_store = Chroma(
collection_name="gaia_questions",
embedding_function=embeddings,
persist_directory="./chroma_db",
)
retriever_tool = create_retriever_tool(
retriever=vector_store.as_retriever(),
name="question_search",
description=(
"Search for similar past questions. Returns solved examples with the answer "
"and which tools/strategies were used — useful for picking the right approach."
),
)
tools.append(retriever_tool)
except Exception as e:
print(f"Warning: could not initialise ChromaDB retriever: {e}")
return tools
def get_system_prompt(self):
prompt_path = os.path.join(os.path.dirname(__file__), "system_prompt.md")
with open(prompt_path, "r", encoding="utf-8") as f:
system_prompt = f.read()
return SystemMessage(content=system_prompt)
def get_llm(self):
global INFERENCE_MODE
supported_modes = ["google", "hugging-face", "open-ai"]
match INFERENCE_MODE.lower():
case "google":
model = "gemini-2.0-flash"
if "GOOGLE_API_KEY" not in os.environ:
os.environ["GOOGLE_API_KEY"] = getpass.getpass(
"Please enter your Google AI API key: "
)
return ChatGoogleGenerativeAI(model=model, temperature=0)
case "hugging-face":
repo_id = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
return ChatHuggingFace(
llm=HuggingFaceEndpoint(
repo_id=repo_id,
task="text-generation",
temperature=0.01, # HF serverless doesn't support temperature=0
),
verbose=True,
)
case "open-ai":
model = "gpt-4o-mini"
if "OPENAI_API_KEY" not in os.environ:
os.environ["OPENAI_API_KEY"] = getpass.getpass(
"Please enter your OPEN AI API key: "
)
return ChatOpenAI(model=model, temperature=0)
case _:
raise ValueError(
f"Invalid inference mode: {INFERENCE_MODE}. "
f"Please choose from supported modes: {', '.join(supported_modes)}"
)
def assistant(self, state: AgentState):
return {
"messages": [self.chat_with_tools.invoke([self.sys_msg] + state["messages"])]
}
def _build_graph(self):
builder = StateGraph(AgentState)
builder.add_node("assistant", self.assistant)
builder.add_node("tools", ToolNode(self.tools))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
return builder.compile()
@property
def graph(self):
return self._graph
def __call__(
self,
question: str,
file_url: Optional[str] = None,
file_name: Optional[str] = None,
) -> str:
if file_url:
file_ext = os.path.splitext(file_name)[1].lower()
local_file_path = f"./files/{file_name}"
prompt = (
f"{question}\n\n"
f"Attached file url:\n{file_url}\n\n"
f"Attached file extension:\n{file_ext}\n\n"
f"If file doesn't exist at {file_url}, you can access the file locally at {local_file_path}."
)
else:
prompt = question
messages = [HumanMessage(content=prompt)]
response = self.invoke_agent_with_retries(messages)
for m in response["messages"]:
if len(m.content) < 1000:
m.pretty_print()
else:
m.content = m.content[:500] + "..." + m.content[-500:]
m.pretty_print()
answer = response["messages"][-1].content
if "FINAL ANSWER: " in answer:
return answer.split("FINAL ANSWER: ")[1]
return answer
def invoke_agent_with_retries(self, messages: list[AnyMessage]):
backoff = INITIAL_AGENT_RETRY_BACKOFF
for attempt in range(1, MAX_AGENT_INVOKE_RETRIES + 1):
try:
return self.graph.invoke({"messages": messages})
except Exception as exc:
if attempt == MAX_AGENT_INVOKE_RETRIES:
print(f"Agent invocation failed after {attempt} attempts: {exc}")
raise
print(
f"Agent invocation attempt {attempt} failed ({exc}); "
f"retrying in {backoff:.1f}s..."
)
time.sleep(backoff)
backoff *= 2
# Public entry points. get_agent()/get_graph() hand out one shared agent and its
# compiled graph, so the runtime graph stays stable (e.g. for LangSmith traceability).
__all__ = ["BasicAgent", "get_agent", "get_graph"]
# Module-level cache for the shared BasicAgent; None until get_agent() first runs.
_AGENT_SINGLETON: Optional[BasicAgent] = None
def get_agent() -> BasicAgent:
    """Return the shared BasicAgent, building it on the first call.

    NOTE: creation is not guarded by a lock, so concurrent first calls could
    each construct an agent (the last assignment wins).
    """
    global _AGENT_SINGLETON
    if _AGENT_SINGLETON is not None:
        return _AGENT_SINGLETON
    _AGENT_SINGLETON = BasicAgent()
    return _AGENT_SINGLETON
def get_graph():
    """Return the compiled LangGraph belonging to the shared agent from get_agent()."""
    shared_agent = get_agent()
    return shared_agent.graph
|