Spaces:
Sleeping
Sleeping
| import time | |
| from langchain.chains import RetrievalQA | |
| from langchain_chroma import Chroma | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
| from langgraph.graph import StateGraph, START | |
| from langgraph.graph.message import MessagesState | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from agent_tools import * | |
# NOTE(review): load_dotenv is not imported by name in this file; presumably it
# is python-dotenv's loader re-exported via `from agent_tools import *` -- confirm.
# It runs at import time, before CUSTOM_AGENT reads OPENAI_API_KEY from the
# environment.
load_dotenv()

# System prompt for the agent. CUSTOM_AGENT._assistant prepends it to the
# conversation whenever no SystemMessage is present. The reply format it
# demands ("FINAL ANSWER: ...") is what CUSTOM_AGENT.extract_after_final_answer
# looks for, so the wording of that marker must stay in sync with the parser.
sys_msg = SystemMessage(
    content=
    """
You are a helpful assistant tasked with answering questions using a set of tools. When given a question, follow these steps:
1. Create a clear, step-by-step plan to solve the question.
2. If a tool is necessary, select the most appropriate tool based on its functionality. If one tool isn't working, use another with similar functionality.
3. If a question depends on external numeric or factual data not provided, automatically use your search tools to find it online before answering.
4. Base your answer on tool outputs and any provided files.
5. Execute your plan and provide the response in the following format:
FINAL ANSWER: [YOUR FINAL ANSWER]
Your final answer should be:
- A number (without commas or units unless explicitly requested),
- A short string (avoid articles, abbreviations, and use plain text for digits unless otherwise specified),
- A comma-separated list (apply the formatting rules above for each element, with exactly one space after each comma).
Ensure that your answer is concise and follows the task instructions strictly. If the answer is more complex, break it down in a way that follows the format.
Begin your response with "FINAL ANSWER: " followed by the answer, and nothing else.
"""
)
class CUSTOM_AGENT:
    """Question-answering agent wired as a LangGraph state graph.

    A question passes through a ``retriever`` node (which adds context from
    the persisted Chroma store), then an ``assistant`` node (the tool-bound
    chat model). ``tools_condition`` routes the assistant's output either to
    the ``tools`` node, whose results go back to the assistant, or out of
    the graph.
    """

    def __init__(self):
        """Create the chat model, bind the tools and open the vector store."""
        # Read the key once; both the chat model and the embeddings use it.
        api_key = os.getenv("OPENAI_API_KEY")

        self.llm = ChatOpenAI(model="gpt-5", api_key=api_key, temperature=0)
        self.tools = TOOLS
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        self.sys_msg = sys_msg

        embeddings = OpenAIEmbeddings(api_key=api_key)
        persist_directory = "chroma_db"
        self.vectorstore = Chroma(
            persist_directory=persist_directory,
            embedding_function=embeddings,
        )
        # The retriever node works with the top 3 matches per question.
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 3})
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.retriever,
            return_source_documents=True,
        )

    def _graph_compile(self):
        """Build and compile the retriever -> assistant <-> tools graph."""
        builder = StateGraph(MessagesState)
        builder.add_node("retriever", self._retriever_node)
        builder.add_node("assistant", self._assistant)
        builder.add_node("tools", ToolNode(self.tools))

        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")
        # tools_condition picks the next hop from the assistant's last message.
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")
        return builder.compile()

    def _retriever_node(self, state: MessagesState):
        """Retriever node: wrap the latest message with retrieved context.

        Returns a state update holding one new ``HumanMessage`` whose content
        is the retrieved context followed by the original question.
        """
        question = state["messages"][-1].content
        docs = self.retriever.invoke(question)
        if docs:
            context = "\n\n".join(d.page_content for d in docs)
        else:
            context = "No relevant documents found"
        combined = f"Context:\n{context}\n\nQuestion:\n{question}"
        return {"messages": [HumanMessage(content=combined)]}

    def _assistant(self, state: MessagesState):
        """Assistant node: call the tool-bound model on the conversation.

        The system prompt is prepended only when the conversation does not
        already contain a ``SystemMessage``.
        """
        messages = state["messages"]
        if not any(isinstance(m, SystemMessage) for m in messages):
            messages = [self.sys_msg] + messages
        llm_response = self.llm_with_tools.invoke(messages)
        return {"messages": [llm_response]}

    @staticmethod
    def extract_after_final_answer(text):
        """Return the text following the first ``FINAL ANSWER:`` marker.

        Declared as a static method so that both
        ``self.extract_after_final_answer(text)`` (as used by ``run``) and
        ``CUSTOM_AGENT.extract_after_final_answer(text)`` work; without the
        decorator the call through ``self`` raised ``TypeError``.

        Args:
            text: Raw model reply.

        Returns:
            The stripped text after the marker, or the whole stripped reply
            when the marker is absent.
        """
        # Match on the colon and strip afterwards, so replies that omit the
        # space after the colon ("FINAL ANSWER:42") are still parsed.
        _, marker, answer = text.partition("FINAL ANSWER:")
        return answer.strip() if marker else text.strip()

    def run(self, task: dict):
        """Answer one task, retrying the graph up to three times.

        Args:
            task: Mapping with ``"task_id"`` and ``"question"``; an optional
                ``"file_name"`` marks tasks that come with an attached file.

        Returns:
            The extracted final answer, or an error string describing the
            last failure once all attempts are used up.
        """
        task_id = task["task_id"]
        question = task["question"]
        # A missing key is treated like None / "": no attached file.
        file_name = task.get("file_name")
        print(f"Agent received question (first 100 chars): {question[:100]}...")

        if file_name:
            # Put the task id into the prompt text for file-backed tasks.
            question_text = f'{question} with TASK-ID: {task_id}'
        else:
            question_text = question

        graph = self._graph_compile()
        max_retries = 3
        base_sleep = 1
        for attempt in range(max_retries):
            try:
                messages: list[HumanMessage] = [HumanMessage(content=question_text)]
                result = graph.invoke({"messages": messages})
                final_text = result["messages"][-1].content
                return self.extract_after_final_answer(final_text)
            except Exception as e:
                # Last attempt: report the failure instead of sleeping again.
                if attempt >= max_retries - 1:
                    return f"Error processing query after {max_retries} attempts: {str(e)}"
                # Linear back-off: 1s, 2s, ...
                sleep_time = base_sleep * (attempt + 1)
                print(str(e))
                print(f"Attempt {attempt + 1} failed. Retrying in {sleep_time} seconds...")
                time.sleep(sleep_time)
        # Only reachable if max_retries is ever set to 0.
        return "This is a default answer."