msanton committed on
Commit
bf9e70e
·
verified ·
1 Parent(s): 81917a3

Add GaiaAgent and tools

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. gaia_agent.py +113 -0
  3. requirements.txt +15 -1
  4. tools.py +223 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ /.env
2
+ /chroma_db
3
+ /__pycache__
gaia_agent.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langchain_core.messages import HumanMessage
4
+ from langchain_chroma import Chroma
5
+ from langchain_litellm import ChatLiteLLM
6
+ from langchain_openai import OpenAIEmbeddings
7
+ from langgraph.graph import START, StateGraph
8
+ from langgraph.graph.message import MessagesState
9
+ from langgraph.prebuilt import ToolNode, tools_condition
10
+ from tools import *
11
+
12
+ load_dotenv()
13
+
14
class GaiaAgent:
    """Tool-calling agent for GAIA benchmark tasks.

    Pipeline (LangGraph): a retriever node prepends context from a local
    Chroma vector store, then an assistant node calls a tool-bound LLM,
    looping through the tool node until the model stops requesting tools.
    """

    def __init__(self):
        # LLM served through a LiteLLM proxy; endpoint/key come from env vars.
        self.llm = ChatLiteLLM(
            model="openai/gemini-2.5-pro",
            api_key=os.getenv("ITP_API_KEY"),
            api_base=os.getenv("TRELLIS_URL"),
            temperature=0.5,
        )
        self.tools = [
            web_search,
            wikipedia_search,
            arxiv_search,
            text_splitter,
            read_file,
            analyze_image,
            analyze_audio,
            analyze_youtube_video,
            multiply,
            add,
            subtract,
            divide,
        ]
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        self.system_message = """
        You are a general AI assistant. I will ask you a question.
        Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
        YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
        If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
        If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
        If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
        """
        # Persistent vector store used by the retriever node for RAG context.
        self.vectorstore = Chroma(
            embedding_function=OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY")),
            persist_directory="chroma_db",
        )
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 3})

    def build_graph(self):
        """Compile the retriever -> assistant -> (tools)* LangGraph."""
        builder = StateGraph(MessagesState)
        builder.add_node("retriever", self.retrieve_node)
        builder.add_node("assistant", self.assistant_node)
        builder.add_node("tools", ToolNode(self.tools))

        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")

        # tools_condition routes to "tools" when the LLM emitted tool calls,
        # otherwise to END.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        builder.add_edge("tools", "assistant")
        return builder.compile()

    def retrieve_node(self, state: MessagesState):
        """Retriever node: prepend top-k vector-store context to the question."""
        question = state["messages"][-1].content
        docs = self.retriever.invoke(question)

        if docs:
            context = "\n\n".join([d.page_content for d in docs])
        else:
            context = "No relevant documents found"

        combined = f"Context:\n{context}\n\nQuestion:\n{question}"
        return {"messages": [HumanMessage(content=combined)]}

    def assistant_node(self, state: MessagesState):
        """Assistant node: invoke the tool-bound LLM with the system prompt.

        BUG FIX: the original checked for an existing HumanMessage, but the
        retriever node always inserts one, so the system prompt was never
        sent; it was also prepended as a bare str instead of a SystemMessage.
        """
        # Local import to keep this fix self-contained in the method.
        from langchain_core.messages import SystemMessage

        if not any(isinstance(m, SystemMessage) for m in state["messages"]):
            messages = [SystemMessage(content=self.system_message)] + state["messages"]
        else:
            messages = state["messages"]

        response = self.llm_with_tools.invoke(messages)
        return {"messages": [response]}

    @staticmethod
    def extract_answer(text: str):
        """Return the text following 'FINAL ANSWER', or `text` unchanged.

        BUG FIX: the template is 'FINAL ANSWER: ...'; the original only
        stripped whitespace and left the leading ':' in the answer.
        """
        keyword = "FINAL ANSWER"
        index = text.find(keyword)
        if index != -1:
            return text[index + len(keyword):].lstrip(" :").strip()
        else:
            return text

    def run(self, task: dict):
        """Run the agent on one GAIA task dict and return the final answer.

        Expects keys: task_id, question, file_name.
        """
        task_id, question, file_name = task["task_id"], task["question"], task["file_name"]

        # BUG FIX: the original used `!= "" or is not None`, which is always
        # true (for "" the second clause holds; for None the first does), so
        # the task_id hint was appended even when no file is attached.
        if file_name:
            question = f"{question} with task_id {task_id}"

        graph = self.build_graph()

        messages: list[HumanMessage] = [HumanMessage(content=question)]
        result = graph.invoke({"messages": messages})

        # BUG FIX: extract_answer was applied twice in the original.
        return self.extract_answer(result["messages"][-1].content)
requirements.txt CHANGED
@@ -1,2 +1,16 @@
1
  gradio
2
- requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  gradio
2
+ requests
3
+ langchain
4
+ langchain-community
5
+ langchain-core
6
+ langchain-text-splitters
7
+ langgraph
8
+ langchain-chroma
9
+ langchain-litellm
10
+ langchain_openai
11
+ wikipedia
12
+ python-dotenv
13
+ openai
14
+ arxiv
15
+ chromadb
16
+ # openai  (duplicate entry — already listed above at line 13)
tools.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import requests
4
+ import openai
5
+ from typing import List
6
+ from dotenv import load_dotenv
7
+ from langchain_core.tools import tool
8
+ from langchain_community.document_loaders import WebBaseLoader, WikipediaLoader, ImageCaptionLoader, ArxivLoader
9
+ from langchain_community.tools import DuckDuckGoSearchResults
10
+ from langchain_text_splitters import CharacterTextSplitter
11
+
12
+ load_dotenv()
13
+
14
@tool
def multiply(a: int, b: int) -> int:
    """
    Multiply two integers and return the result

    Args:
        a: The first integer to multiply
        b: The second integer to multiply

    Returns:
        int: The result of the multiplication
    """
    # Plain integer product; named for readability.
    product = a * b
    return product
27
+
28
@tool
def add(a: int, b: int) -> int:
    """
    Add two integers and return the result

    Args:
        a: The first integer to add
        b: The second integer to add

    Returns:
        int: The result of the addition
    """
    # Plain integer sum; named for readability.
    total = a + b
    return total
41
+
42
@tool
def subtract(a: int, b: int) -> int:
    """
    Subtract two integers and return the result

    Args:
        a: The first integer to subtract
        b: The second integer to subtract

    Returns:
        int: The result of the subtraction
    """
    # Computes a - b; named for readability.
    difference = a - b
    return difference
55
+
56
@tool
def divide(a: int, b: int) -> float:
    """
    Divide the first integer by the second integer and return the result

    Args:
        a: The first integer to divide (dividend)
        b: The second integer to divide (divisor)

    Returns:
        float: The result of the division

    Raises:
        ZeroDivisionError: if b is 0
    """
    # BUG FIX: `/` is true division and yields a float; the return
    # annotation previously claimed int.
    return a / b
69
+
70
# Base URL for downloading task attachments from the GAIA scoring space.
FILE_URL = "https://agents-course-unit4-scoring.hf.space/files/"

@tool
def read_file(task_id: str) -> str:
    """
    Download a file based on the task_id and then read the content of the file

    Args:
        task_id: The id of the task to download the file from

    Returns:
        str: The content of the file

    Raises:
        requests.HTTPError: if the download fails
    """
    file_url = f"{FILE_URL}{task_id}"
    response = requests.get(file_url, timeout=10, allow_redirects=True)
    # BUG FIX: fail loudly on HTTP errors instead of silently returning an
    # error page's body as if it were the file.
    response.raise_for_status()
    # FIX: the original wrote the payload to a 'temp' file (leaked on disk)
    # and immediately re-read it with the locale-default encoding; decode
    # in memory instead.
    return response.text
89
+
90
@tool
def analyze_image(task_id: str) -> str:
    """
    Analyze an image based on the task_id and return a description of the content of the image

    Args:
        task_id: The id of the task to analyze the image from

    Returns:
        str: The description of the content of the image
    """
    # Build the download URL and let the caption loader fetch + describe it.
    loader = ImageCaptionLoader(images=[f"{FILE_URL}{task_id}"])
    docs = loader.load()
    return docs[0].page_content
104
+
105
@tool
def analyze_audio(task_id: str) -> str:
    """
    Analyze an mp3 file based on the task_id and return a description of the content of the audio file

    Args:
        task_id: The id of the task to analyze the audio file from

    Returns:
        str: The description of the content of the audio file

    Raises:
        requests.HTTPError: if the download fails
    """
    file_url = f"{FILE_URL}{task_id}"
    response = requests.get(file_url, timeout=10, allow_redirects=True)
    # BUG FIX: fail loudly on HTTP errors instead of transcribing an error page.
    response.raise_for_status()
    temp_file = 'temp.mp3'
    try:
        with open(temp_file, 'wb') as fp:
            fp.write(response.content)
        # The OpenAI SDK needs a named, seekable file object for Whisper.
        with open(temp_file, "rb") as audio_file:
            transcript = openai.audio.transcriptions.create(
                file=audio_file,
                model="whisper-1"
            )
    finally:
        # FIX: the original leaked temp.mp3 on disk (and on any exception).
        if os.path.exists(temp_file):
            os.remove(temp_file)
    return transcript.text
127
+
128
@tool
def analyze_youtube_video(youtube_url: str, question: str) -> str:
    """
    Analyze a youtube video based on the youtube_url and the question and return the answer to the question

    Args:
        youtube_url: The url of the youtube video to analyze
        question: The question to answer based on the youtube video

    Returns:
        str: The answer to the question
    """
    # BUG FIX: the original body was empty, so the tool implicitly returned
    # None despite the declared `-> str` contract, which breaks the tool
    # node's message handling. Until video analysis is implemented, report
    # the limitation explicitly so the LLM can respond accordingly.
    return (
        f"Video analysis is not implemented yet; cannot answer {question!r} "
        f"for {youtube_url}."
    )
142
@tool
def web_search(query: str) -> str:
    """
    Search the web for the given query and return the results

    Args:
        query: The query to search the web for

    Returns:
        str: The text content of the web search results
    """
    # NOTE(review): recent langchain_community releases name this parameter
    # `output_format`, not `output_type` — confirm against the installed
    # version; a rejected/ignored kwarg would change the result shape.
    search_engine = DuckDuckGoSearchResults(output_type="list", num_results=3)
    results = search_engine.invoke({"query": query})
    page_urls = [entry["link"] for entry in results]

    # FIX: guard against an empty result set — WebBaseLoader([]) would
    # otherwise be invoked for nothing and yield no docs to join.
    if not page_urls:
        return "No web results found"

    loader = WebBaseLoader(web_paths=page_urls)
    docs = loader.load()

    # Cap each page at 15k chars to keep the tool output bounded.
    combined_text = "\n\n".join(doc.page_content[:15000] for doc in docs)

    # Collapse runs of blank lines / oversized whitespace and trim the ends.
    cleaned_text = re.sub(r'\n{3,}', '\n\n', combined_text).strip()
    cleaned_text = re.sub(r'[ \t]{6,}', ' ', cleaned_text)

    # (redundant trailing .strip() from the original removed — the interior
    # whitespace collapse cannot reintroduce leading/trailing whitespace)
    return cleaned_text
169
+
170
@tool
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia articles with the given query and return the pages

    Args:
        query: The query to search Wikipedia for

    Returns:
        str: The text content of the Wikipedia articles related to the query
    """
    print("Searching Wikipedia for the query: ", query)
    docs = WikipediaLoader(query=query, load_max_docs=3).load()
    # Wrap each article in a pseudo-XML envelope carrying its source metadata.
    rendered = (
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in docs
    )
    return "\n\n---\n\n".join(rendered)
189
+
190
@tool
def arxiv_search(query: str) -> str:
    """
    Search arxiv for the given query and return the results

    Args:
        query: The query to search arxiv for

    Returns:
        str: The text content of the arxiv search results

    """
    docs = ArxivLoader(query=query, load_max_docs=3).load()
    # Wrap each paper (truncated to 1000 chars) in a pseudo-XML envelope
    # carrying its source metadata.
    rendered = (
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in docs
    )
    return "\n\n---\n\n".join(rendered)
209
+
210
@tool
def text_splitter(text: str) -> List[str]:
    """
    Split a large text into smaller chunks using Langchain's CharacterTextSplitter

    Args:
        text: The large text to split into smaller chunks

    Returns:
        List[str]: a list containing the smaller chunks of the text
    """
    # DOC FIX: the tool description (read by the LLM at runtime) said
    # "a list container the" — corrected to "containing".
    splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=10)
    return splitter.split_text(text)