# NOTE(review): the original paste carried stray platform header lines here
# ("Spaces:", "Runtime error") that were not Python code; removed as inert.
import os

from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_litellm import ChatLiteLLM
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import START, StateGraph
from langgraph.graph.message import MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

from tools import *
| load_dotenv() | |
class GaiaAgent:
    """Retrieval-augmented LangGraph agent for answering GAIA benchmark tasks.

    Pipeline: retriever node (Chroma similarity search injects context into the
    question) -> assistant node (tool-bound LLM) -> tool node, looping between
    assistant and tools until the LLM stops emitting tool calls.
    """

    def __init__(self):
        # LLM reached through a LiteLLM-compatible proxy; credentials/endpoint
        # come from the environment (loaded via dotenv at module import).
        self.llm = ChatLiteLLM(
            model="openai/gemini-2.5-pro",
            api_key=os.getenv("ITP_API_KEY"),
            api_base=os.getenv("TRELLIS_URL"),
            temperature=0.5,
        )
        # Tools are imported from the project-local `tools` module (star import).
        self.tools = [
            web_search,
            wikipedia_search,
            arxiv_search,
            text_splitter,
            read_file,
            analyze_image,
            analyze_audio,
            analyze_youtube_video,
            multiply,
            add,
            subtract,
            divide,
        ]
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # GAIA answer-formatting contract; prepended once per conversation.
        self.system_message = """
You are a general AI assistant. I will ask you a question.
Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""
        # Persistent local vector store used for context retrieval.
        self.vectorstore = Chroma(
            embedding_function=OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY")),
            persist_directory="chroma_db",
        )
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 3})

    def build_graph(self):
        """Compile and return the agent's state graph.

        START -> retriever -> assistant -> (tools | END); tools feed back
        into assistant via `tools_condition`.
        """
        builder = StateGraph(MessagesState)
        builder.add_node("retriever", self.retrieve_node)
        builder.add_node("assistant", self.assistant_node)
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        builder.add_edge("tools", "assistant")
        return builder.compile()

    def retrieve_node(self, state: "MessagesState"):
        """Retriever node: prepend top-k similar documents to the question."""
        question = state["messages"][-1].content
        docs = self.retriever.invoke(question)
        if docs:
            context = "\n\n".join([d.page_content for d in docs])
        else:
            context = "No relevant documents found"
        combined = f"Context:\n{context}\n\nQuestion:\n{question}"
        return {"messages": [HumanMessage(content=combined)]}

    def assistant_node(self, state: "MessagesState"):
        """Assistant node: invoke the tool-bound LLM, injecting the system prompt once.

        Bug fix: the original checked for a HumanMessage before prepending the
        system prompt, but the retriever node always emits a HumanMessage first,
        so the prompt was never injected. We now check for (and use) a proper
        SystemMessage instead of a bare string.
        """
        if not any(isinstance(m, SystemMessage) for m in state["messages"]):
            messages = [SystemMessage(content=self.system_message)] + state["messages"]
        else:
            messages = state["messages"]
        response = self.llm_with_tools.invoke(messages)
        return {"messages": [response]}

    @staticmethod
    def extract_answer(text: str) -> str:
        """Return the answer following the 'FINAL ANSWER' marker.

        Returns `text` unchanged when the marker is absent. Declared as a
        staticmethod because callers invoke it as `self.extract_answer(...)`;
        the original definition lacked `self` and raised TypeError at runtime.
        """
        keyword = "FINAL ANSWER"
        index = text.find(keyword)
        if index == -1:
            return text
        # Drop the marker plus the template's trailing ':' and any whitespace.
        return text[index + len(keyword):].strip().lstrip(":").strip()

    def run(self, task: dict):
        """Execute the graph on one GAIA task dict and return the final answer.

        `task` must provide "task_id", "question", and "file_name" keys.
        """
        task_id = task["task_id"]
        question = task["question"]
        file_name = task["file_name"]
        # Bug fix: the original `!= "" or is not None` was always True; only
        # mention the task_id when a file is actually attached to the task.
        if file_name:
            question = f"{question} with task_id {task_id}"
        graph = self.build_graph()
        messages: list[HumanMessage] = [HumanMessage(content=question)]
        result = graph.invoke({"messages": messages})
        # Bug fix: extract once — the original extracted twice (a no-op the
        # second time) and crashed on the self-less extract_answer definition.
        return self.extract_answer(result["messages"][-1].content)