File size: 5,477 Bytes
9032799
1aa9547
387c120
1aa9547
fa7f5a2
9032799
1aa9547
9032799
8989d02
e4167da
9032799
 
 
 
 
 
 
 
 
 
9096e64
 
9032799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9096e64
 
 
 
 
9032799
31be711
8989d02
9032799
e4167da
 
 
 
 
 
 
 
 
 
9032799
 
8989d02
 
9032799
 
8989d02
 
 
 
9032799
 
 
 
 
8989d02
 
 
 
fa7f5a2
387c120
8989d02
 
fa7f5a2
8989d02
fa7f5a2
8989d02
fa7f5a2
 
8989d02
 
 
 
 
 
 
 
 
 
 
 
 
 
9032799
 
 
e4167da
9032799
e4167da
9032799
 
 
 
 
 
 
 
 
e4167da
8989d02
9032799
 
 
 
 
8989d02
 
 
 
 
9032799
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import time
from langchain.chains import RetrievalQA
from langchain_chroma import Chroma
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import StateGraph, START
from langgraph.graph.message import MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

from agent_tools import *

# Populate os.environ before any client is built; CUSTOM_AGENT.__init__ reads
# OPENAI_API_KEY via os.getenv. NOTE(review): load_dotenv is not imported
# explicitly here — presumably python-dotenv's function re-exported by
# `from agent_tools import *`; confirm.
load_dotenv()

# System prompt for the tool-using assistant. The assistant node prepends it to
# the conversation when no SystemMessage is present yet. The required
# "FINAL ANSWER: ..." format is what CUSTOM_AGENT.extract_after_final_answer
# parses out of the model's last message.
# NOTE: the literal's leading indentation is part of the prompt text sent to
# the model; edit with care, since any change alters model input.
sys_msg = SystemMessage(
    content=
    """
    You are a helpful assistant tasked with answering questions using a set of tools. When given a question, follow these steps:
    1. Create a clear, step-by-step plan to solve the question.
    2. If a tool is necessary, select the most appropriate tool based on its functionality. If one tool isn't working, use another with similar functionality.
    3. If a question depends on external numeric or factual data not provided, automatically use your search tools to find it online before answering.
    4. Base your answer on tool outputs and any provided files.
    5. Execute your plan and provide the response in the following format:

    FINAL ANSWER: [YOUR FINAL ANSWER]

    Your final answer should be:

    - A number (without commas or units unless explicitly requested),
    - A short string (avoid articles, abbreviations, and use plain text for digits unless otherwise specified),
    - A comma-separated list (apply the formatting rules above for each element, with exactly one space after each comma).

    Ensure that your answer is concise and follows the task instructions strictly. If the answer is more complex, break it down in a way that follows the format.
    Begin your response with "FINAL ANSWER: " followed by the answer, and nothing else.
    """
)

class CUSTOM_AGENT:
    """Tool-using LLM agent orchestrated as a LangGraph ``StateGraph``.

    Each call to :meth:`run` compiles a graph with three nodes:

    * ``retriever`` - fetches the top-3 chunks from the persisted Chroma store
      in ``chroma_db`` and re-posts the question with that context prepended;
    * ``assistant`` - invokes the tool-bound chat model, prepending the module
      system prompt when the conversation has no ``SystemMessage`` yet;
    * ``tools`` - executes the tool calls the model requested, then hands
      control back to ``assistant`` (routing is decided by ``tools_condition``).

    The text of the last message is reduced to whatever follows the
    ``FINAL ANSWER:`` marker.
    """

    def __init__(self):
        api_key = os.getenv("OPENAI_API_KEY")
        self.llm = ChatOpenAI(model="gpt-5", api_key=api_key, temperature=0)
        self.tools = TOOLS
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        self.sys_msg = sys_msg
        embeddings = OpenAIEmbeddings(api_key=api_key)
        persist_directory = "chroma_db"
        self.vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 3})
        # NOTE(review): qa_chain is never used by run()/the graph; kept only in
        # case external callers rely on the attribute.
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.retriever,
            return_source_documents=True
        )

    def _graph_compile(self):
        """Build and compile the retriever -> assistant <-> tools graph."""
        builder = StateGraph(MessagesState)
        builder.add_node("retriever", self._retriever_node)
        builder.add_node("assistant", self._assistant)
        builder.add_node("tools", ToolNode(self.tools))

        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")

        # tools_condition routes to "tools" when the last AI message requests
        # tool calls, otherwise ends the run.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        builder.add_edge("tools", "assistant")
        return builder.compile()

    def _retriever_node(self, state: "MessagesState"):
        """Retrieve context for the latest message and re-post it with that context.

        Returns a state update holding one ``HumanMessage`` of the form
        ``"Context:\\n<chunks>\\n\\nQuestion:\\n<question>"``.
        """
        question = state["messages"][-1].content
        docs = self.retriever.invoke(question)

        if docs:
            context = "\n\n".join([d.page_content for d in docs])
        else:
            context = "No relevant documents found"

        combined = f"Context:\n{context}\n\nQuestion:\n{question}"
        return {"messages": [HumanMessage(content=combined)]}

    def _assistant(self, state: "MessagesState"):
        """Call the tool-bound model once and append its reply to the state."""
        # The system prompt is prepended only for this call (not written back
        # into state), and only if the conversation doesn't already carry one.
        if not any(isinstance(m, SystemMessage) for m in state["messages"]):
            messages = [self.sys_msg] + state["messages"]
        else:
            messages = state["messages"]

        llm_response = self.llm_with_tools.invoke(messages)

        return {"messages": [llm_response]}

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten LangChain message content into plain text.

        ``content`` may be a ``str``, ``None``, or a list of content blocks
        (plain strings or dicts such as ``{"type": "text", "text": ...}``);
        only text blocks contribute. Anything else is converted with ``str``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(str(block.get("text", "")))
            return "".join(parts)
        return str(content)

    @staticmethod
    def extract_after_final_answer(text):
        """Return the answer that follows the first ``FINAL ANSWER:`` marker.

        Whitespace between the colon and the answer is optional, so
        ``"FINAL ANSWER: 42"``, ``"FINAL ANSWER:42"`` and ``"FINAL ANSWER:\\n42"``
        all yield ``"42"``. ``text`` may also be ``None`` or a content-block
        list. Without a marker, the whole stripped text is returned.
        """
        text = CUSTOM_AGENT._content_to_text(text)
        keyword = "FINAL ANSWER:"
        index = text.find(keyword)
        if index == -1:
            return text.strip()
        return text[index + len(keyword):].strip()

    def run(self, task: dict):
        """Answer one task and return the extracted final answer string.

        :param task: mapping with ``"task_id"`` and ``"question"``; an optional
            ``"file_name"`` marks that the task has an attachment.
        :returns: the answer text, or an ``"Error processing query ..."`` string
            when every retry raised.
        """
        task_id, question = task["task_id"], task["question"]
        file_name = task.get("file_name")
        print(f"Agent received question (first 100 chars): {question[:100]}...")

        # Tasks with an attachment get their id appended.
        # NOTE(review): presumably so a tool in agent_tools can fetch the file
        # by task id — confirm.
        question_text = f'{question} with TASK-ID: {task_id}' if file_name else question

        graph = self._graph_compile()

        max_retries = 3
        base_sleep = 1
        for attempt in range(max_retries):
            try:
                messages = [HumanMessage(content=question_text)]
                result = graph.invoke({"messages": messages})

                final_text = result["messages"][-1].content
                return self.extract_after_final_answer(final_text)
            except Exception as e:  # top-level boundary: retry, then report
                sleep_time = base_sleep * (attempt + 1)  # linear backoff
                if attempt < max_retries - 1:
                    print(str(e))
                    print(f"Attempt {attempt + 1} failed. Retrying in {sleep_time} seconds...")
                    time.sleep(sleep_time)
                    continue
                return f"Error processing query after {max_retries} attempts: {str(e)}"
        # Only reachable if max_retries <= 0.
        return "This is a default answer."