File size: 6,887 Bytes
8fc14db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import getpass
import os
import time
from typing import Annotated, Optional
from typing import TypedDict

from dotenv import load_dotenv
from langchain_core.tools.retriever import create_retriever_tool
from langchain_chroma import Chroma
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from tools import get_tools

# Load variables from a .env file (if one exists) into the environment before
# get_llm() looks for the provider API keys in os.environ.
load_dotenv()

# Total number of times invoke_agent_with_retries() will try the graph before
# re-raising the last error.
MAX_AGENT_INVOKE_RETRIES = 3
# Seconds slept after the first failed attempt; doubled after every further failure.
INITIAL_AGENT_RETRY_BACKOFF = 1.0
INFERENCE_MODE = "hugging-face"  # Change to "google" or "open-ai" to use those providers instead


class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph."""

    # Conversation so far. ``add_messages`` is attached as LangGraph's reducer
    # for this key; per the LangGraph docs it merges messages returned by a
    # node into the existing list instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]


class BasicAgent:
    def __init__(self):
        self.sys_msg = self.get_system_prompt()
        self.llm = self.get_llm()
        self.tools = self._load_tools()
        self.chat_with_tools = self.llm.bind_tools(self.tools)
        self._graph = self._build_graph()
        print("BasicAgent initialized.")

    def _load_tools(self):
        """Return tool list, appending a ChromaDB retriever tool if available."""
        tools = get_tools()
        try:
            embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-mpnet-base-v2"
            )
            vector_store = Chroma(
                collection_name="gaia_questions",
                embedding_function=embeddings,
                persist_directory="./chroma_db",
            )
            retriever_tool = create_retriever_tool(
                retriever=vector_store.as_retriever(),
                name="question_search",
                description=(
                    "Search for similar past questions. Returns solved examples with the answer "
                    "and which tools/strategies were used — useful for picking the right approach."
                ),
            )
            tools.append(retriever_tool)
        except Exception as e:
            print(f"Warning: could not initialise ChromaDB retriever: {e}")
        return tools

    def get_system_prompt(self):
        prompt_path = os.path.join(os.path.dirname(__file__), "system_prompt.md")
        with open(prompt_path, "r", encoding="utf-8") as f:
            system_prompt = f.read()
        return SystemMessage(content=system_prompt)

    def get_llm(self):
        global INFERENCE_MODE
        supported_modes = ["google", "hugging-face", "open-ai"]
        match INFERENCE_MODE.lower():
            case "google":
                model = "gemini-2.0-flash"
                if "GOOGLE_API_KEY" not in os.environ:
                    os.environ["GOOGLE_API_KEY"] = getpass.getpass(
                        "Please enter your Google AI API key: "
                    )
                return ChatGoogleGenerativeAI(model=model, temperature=0)
            case "hugging-face":
                repo_id = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
                return ChatHuggingFace(
                    llm=HuggingFaceEndpoint(
                        repo_id=repo_id,
                        task="text-generation",
                        temperature=0.01,  # HF serverless doesn't support temperature=0
                    ),
                    verbose=True,
                )
            case "open-ai":
                model = "gpt-4o-mini"
                if "OPENAI_API_KEY" not in os.environ:
                    os.environ["OPENAI_API_KEY"] = getpass.getpass(
                        "Please enter your OPEN AI API key: "
                    )
                return ChatOpenAI(model=model, temperature=0)
            case _:
                raise ValueError(
                    f"Invalid inference mode: {INFERENCE_MODE}. "
                    f"Please choose from supported modes: {', '.join(supported_modes)}"
                )

    def assistant(self, state: AgentState):
        return {
            "messages": [self.chat_with_tools.invoke([self.sys_msg] + state["messages"])]
        }

    def _build_graph(self):
        builder = StateGraph(AgentState)
        builder.add_node("assistant", self.assistant)
        builder.add_node("tools", ToolNode(self.tools))

        builder.add_edge(START, "assistant")
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")
        return builder.compile()

    @property
    def graph(self):
        return self._graph

    def __call__(
        self,
        question: str,
        file_url: Optional[str] = None,
        file_name: Optional[str] = None,
    ) -> str:
        if file_url:
            file_ext = os.path.splitext(file_name)[1].lower()
            local_file_path = f"./files/{file_name}"
            prompt = (
                f"{question}\n\n"
                f"Attached file url:\n{file_url}\n\n"
                f"Attached file extension:\n{file_ext}\n\n"
                f"If file doesn't exist at {file_url}, you can access the file locally at {local_file_path}."
            )
        else:
            prompt = question

        messages = [HumanMessage(content=prompt)]
        response = self.invoke_agent_with_retries(messages)

        for m in response["messages"]:
            if len(m.content) < 1000:
                m.pretty_print()
            else:
                m.content = m.content[:500] + "..." + m.content[-500:]
                m.pretty_print()

        answer = response["messages"][-1].content
        if "FINAL ANSWER: " in answer:
            return answer.split("FINAL ANSWER: ")[1]
        return answer

    def invoke_agent_with_retries(self, messages: list[AnyMessage]):
        backoff = INITIAL_AGENT_RETRY_BACKOFF
        for attempt in range(1, MAX_AGENT_INVOKE_RETRIES + 1):
            try:
                return self.graph.invoke({"messages": messages})
            except Exception as exc:
                if attempt == MAX_AGENT_INVOKE_RETRIES:
                    print(f"Agent invocation failed after {attempt} attempts: {exc}")
                    raise
                print(
                    f"Agent invocation attempt {attempt} failed ({exc}); "
                    f"retrying in {backoff:.1f}s..."
                )
                time.sleep(backoff)
                backoff *= 2


# Public API of this module. get_agent() / get_graph() hand out one shared
# agent so the runtime graph stays stable for LangSmith traceability.
__all__ = ["BasicAgent", "get_agent", "get_graph"]

# Module-wide BasicAgent instance; stays None until get_agent() first creates it.
_AGENT_SINGLETON: Optional[BasicAgent] = None


def get_agent() -> BasicAgent:
    """Return the module-wide ``BasicAgent``, constructing it on the first call."""
    global _AGENT_SINGLETON
    agent = _AGENT_SINGLETON
    if agent is None:
        # First use: build the agent once and remember it for later callers.
        agent = BasicAgent()
        _AGENT_SINGLETON = agent
    return agent


def get_graph():
    """Return the compiled graph of the shared agent obtained from ``get_agent()``."""
    shared_agent = get_agent()
    return shared_agent.graph