sqfoo committed on
Commit
e1dc6ad
·
verified ·
1 Parent(s): 3a738a7

Upload 2 files

Browse files
Files changed (2) hide show
  1. gemini_agent.py +131 -0
  2. tools.py +327 -0
gemini_agent.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from dotenv import load_dotenv
4
+ from typing import TypedDict, Annotated, Optional
5
+
6
+ from langgraph.prebuilt import ToolNode, tools_condition
7
+ from langgraph.graph import StateGraph, START
8
+ from langgraph.graph.message import add_messages
9
+
10
+ from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
11
+ from langchain_google_genai import ChatGoogleGenerativeAI
12
+
13
+ from tools import *
14
+
15
+ load_dotenv()
16
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
17
+
18
class AgentState(TypedDict):
    """Agent state for the graph."""
    # Optional identifier of a file attached to the task; passed through
    # the graph unchanged (the assistant node does not read it).
    input_file: Optional[str]
    # Conversation history; the add_messages reducer appends new messages
    # instead of overwriting the list on each node update.
    messages: Annotated[list[AnyMessage], add_messages]
22
+
23
+
24
class GEMINI_AGENT:
    """ReAct-style agent: a LangGraph loop between a Gemini chat model and
    a tool-execution node.

    The assistant node calls the tool-bound LLM; whenever the model emits
    tool calls the graph routes to the tool node and loops back, until the
    model produces a plain answer containing "FINAL ANSWER: ...".
    """

    def __init__(self):
        # Temperature 0 for deterministic, format-following answers.
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-lite",
            temperature=0,
            max_tokens=1024,
            google_api_key=os.getenv("GEMINI_API_KEY"),
        )

        self.tools = [
            duckduck_websearch,
            serper_websearch,
            visit_webpage,
            wiki_search,
            youtube_viewer,
            text_splitter,
            read_file,
            excel_read,
            csv_read,
            mp3_listen,
            image_caption,
            run_python,
            multiply,
            add,
            subtract,
            divide,
        ]

        self.llm_with_tools = self.llm.bind_tools(self.tools)
        self.app = self._graph_compile()

    def _graph_compile(self):
        """Build and compile the assistant <-> tools ReAct graph."""
        builder = StateGraph(AgentState)
        # Define nodes: these do the work
        builder.add_node("assistant", self._assistant)
        builder.add_node("tools", ToolNode(self.tools))
        # Define edges: these determine how the control flow moves.
        # tools_condition routes to "tools" when the last message contains
        # tool calls, otherwise to the graph's end.
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        builder.add_edge("tools", "assistant")
        react_graph = builder.compile()
        return react_graph

    def _assistant(self, state: AgentState):
        """LLM node: prepend the system prompt and invoke the tool-bound model.

        Returns a partial state update; add_messages appends the new reply.
        """
        sys_msg = SystemMessage(
            content=
            """
            You are a helpful assistant tasked with answering questions using a set of tools. When given a question, follow these steps:
            1. Create a clear, step-by-step plan to solve the question.
            2. If a tool is necessary, select the most appropriate tool based on its functionality. If one tool isn't working, use another with similar functionality.
            3. Execute your plan and provide the response in the following format:

            FINAL ANSWER: [YOUR FINAL ANSWER]

            Your final answer should be:

            - A number (without commas or units unless explicitly requested),
            - A short string (avoid articles, abbreviations, and use plain text for digits unless otherwise specified),
            - A comma-separated list (apply the formatting rules above for each element, with exactly one space after each comma).

            Ensure that your answer is concise and follows the task instructions strictly. If the answer is more complex, break it down in a way that follows the format.
            Begin your response with "FINAL ANSWER: " followed by the answer, and nothing else.
            """
        )

        return {
            "messages": [self.llm_with_tools.invoke([sys_msg] + state["messages"])],
            "input_file": state["input_file"]
        }

    def extract_after_final_answer(self, text):
        """Return the text following the last-found "FINAL ANSWER: " marker.

        Args:
            text: the model's reply text.

        Returns:
            str: the answer with surrounding whitespace stripped, or "" if
            the marker is absent.
        """
        keyword = "FINAL ANSWER: "
        index = text.find(keyword)
        if index != -1:
            # Strip so a trailing newline from the model doesn't leak into
            # the submitted answer.
            return text[index + len(keyword):].strip()
        else:
            return ""

    def run(self, task: dict):
        """Answer one scoring task; retries transient failures with linear backoff.

        Args:
            task: dict with "task_id", "question" and "file_name" keys.

        Returns:
            str: the extracted final answer, or an error string after all
            retries fail.
        """
        task_id, question, file_name = task["task_id"], task["question"], task["file_name"]
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        if file_name == "" or file_name is None:
            question_text = question
        else:
            # Attach the task id so the file tools can download the attachment.
            question_text = f'{question} with TASK-ID: {task_id}'
        messages = [HumanMessage(content=question_text)]

        max_retries = 5
        base_sleep = 1
        for attempt in range(max_retries):
            try:
                # BUG FIX: the state key was misspelled "messgae", so the
                # graph never received the question.
                response = self.app.invoke({"messages": messages, "input_file": None})
                # The graph returns the whole state dict; the model's final
                # reply is the last message (the old code passed the dict
                # itself to the extractor).
                final_ans = self.extract_after_final_answer(response["messages"][-1].content)
                time.sleep(60)  # avoid rate limit
                return final_ans
            except Exception as e:
                sleep_time = base_sleep * (attempt + 1)
                if attempt < max_retries - 1:
                    print(str(e))
                    print(f"Attempt {attempt + 1} failed. Retrying in {sleep_time} seconds...")
                    time.sleep(sleep_time)
                    continue
                return f"Error processing query after {max_retries} attempts: {str(e)}"
        return "This is a default answer."
tools.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import requests
4
+ import pandas as pd
5
+ from typing import List
6
+ from dotenv import load_dotenv
7
+
8
+ from google import genai
9
+ from google.genai import types
10
+
11
+ from langchain_core.tools import tool
12
+ from langchain.document_loaders import WebBaseLoader
13
+ from langchain_experimental.tools import PythonREPLTool
14
+ from langchain.text_splitter import CharacterTextSplitter
15
+ from langchain_community.tools import DuckDuckGoSearchResults
16
+ from langchain_community.retrievers import WikipediaRetriever
17
+ from langchain_community.utilities import GoogleSerperAPIWrapper
18
+ from langchain_community.document_loaders import ImageCaptionLoader, AssemblyAIAudioTranscriptLoader
19
+
20
+
21
+ load_dotenv()
22
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
23
+
24
+
25
@tool
def duckduck_websearch(query: str) -> str:
    """
    Performs a web search using the given query, downloads the content of two relevant web pages,
    and returns their combined content as a raw string.

    This is useful when the task requires analysis of web page content, such as retrieving poems,
    changelogs, or other textual resources.

    Args:
        query (str): The search query.

    Returns:
        str: The combined raw text content of the two retrieved web pages.
    """
    search_engine = DuckDuckGoSearchResults(output_format="list", num_results=2)
    # .invoke() is the supported tool entry point; calling the tool object
    # directly is deprecated in LangChain.
    page_urls = [result["link"] for result in search_engine.invoke(query)]

    loader = WebBaseLoader(web_paths=page_urls)
    docs = loader.load()

    # Cap each page at 15k chars to keep the LLM context manageable.
    combined_text = "\n\n".join(doc.page_content[:15000] for doc in docs)

    # Clean up excessive newlines and long whitespace runs, then strip
    # leading/trailing whitespace (the old code stripped twice).
    cleaned_text = re.sub(r'\n{3,}', '\n\n', combined_text)
    cleaned_text = re.sub(r'[ \t]{6,}', ' ', cleaned_text)
    return cleaned_text.strip()
55
+
56
+
57
@tool
def serper_websearch(query: str) -> str:
    """
    Runs the query through the SERPER search engine and returns the result.

    Args:
        query (str): The search query.

    Returns:
        str: the search result
    """
    engine = GoogleSerperAPIWrapper(serper_api_key=os.getenv("SERPER_API_KEY"))
    return engine.run(query)
71
+
72
@tool
def visit_webpage(url: str) -> str:
    """
    Fetches raw HTML content of a web page.

    Args:
        url: the webpage url

    Returns:
        str: up to the first 5000 characters of the page body, or an error
        string on failure.
    """
    try:
        response = requests.get(url, timeout=5)
        # Surface HTTP errors (404/500) as an error string instead of
        # silently returning the error page's body as if it were content.
        response.raise_for_status()
        return response.text[:5000]
    except Exception as e:
        return f"[ERROR fetching {url}]: {str(e)}"
88
+
89
@tool
def wiki_search(query: str) -> str:
    """
    Searches Wikipedia for articles matching the query and returns the
    content of the corresponding pages.

    Args:
        query (str): The search term to look up on Wikipedia.

    Returns:
        str: The text content of the Wikipedia articles related to the query.
    """
    pages = WikipediaRetriever().invoke(query)
    return "\n\n".join(page.page_content for page in pages)
104
+
105
@tool
def youtube_viewer(youtube_url: str, question: str) -> str:
    """
    Analyzes a YouTube video from the provided URL and returns an answer
    to the given question based on the analysis results.

    Args:
        youtube_url (str): The URL of the YouTube video, in the format
        "https://www.youtube.com/...".
        question (str): A question related to the content of the video.

    Returns:
        str: An answer to the question based on the video's content.
    """
    gemini = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
    # The video is passed by URI; Gemini fetches and analyzes it server-side.
    video_part = types.Part(file_data=types.FileData(file_uri=youtube_url))
    question_part = types.Part(text=question)
    answer = gemini.models.generate_content(
        model='models/gemini-2.5-flash-preview-04-17',
        contents=types.Content(parts=[video_part, question_part]),
    )
    return answer.text
132
+
133
@tool
def text_splitter(text: str) -> List[str]:
    """
    Splits text into chunks using LangChain's CharacterTextSplitter.

    Args:
        text: A string of text to split.

    Returns:
        List[str]: a list of split text
    """
    # 450-char chunks with a 10-char overlap between consecutive chunks.
    chunker = CharacterTextSplitter(chunk_size=450, chunk_overlap=10)
    return chunker.split_text(text)
146
+
147
@tool
def read_file(task_id: str) -> str:
    """
    Download the task's attached file from the scoring server, then return
    its text content.

    Args:
        task_id: the task id identifying the attachment.

    Returns:
        str: the file content, or an error message on failure.
    """
    try:
        file_url = f'{DEFAULT_API_URL}/files/{task_id}'
        r = requests.get(file_url, timeout=15, allow_redirects=True)
        r.raise_for_status()
        # Decode in memory instead of round-tripping through a temp file;
        # replace undecodable bytes so a binary attachment doesn't crash
        # the tool with a UnicodeDecodeError.
        return r.content.decode("utf-8", errors="replace")
    except Exception as e:
        return f"Error reading file: {str(e)}"
164
+
165
@tool
def excel_read(task_id: str) -> str:
    """
    Download the task's Excel attachment, then summarize its content.

    Args:
        task_id: the task id identifying the attachment.

    Returns:
        str: the content summary of the excel file, or an error message.
    """
    try:
        file_url = f'{DEFAULT_API_URL}/files/{task_id}'
        r = requests.get(file_url, timeout=15, allow_redirects=True)
        # Fail early on HTTP errors so an HTML error page isn't fed to
        # the Excel parser.
        r.raise_for_status()
        with open('temp.xlsx', "wb") as fp:
            fp.write(r.content)
        # Read the Excel file
        df = pd.read_excel('temp.xlsx')
        # Report shape, columns, then summary statistics.
        result = (
            f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        )
        result += f"Columns: {', '.join(df.columns)}\n\n"
        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())
        return result
    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"
194
+
195
@tool
def csv_read(task_id: str) -> str:
    """
    Download the task's CSV attachment, then summarize its content.

    Args:
        task_id: the task id identifying the attachment.

    Returns:
        str: the content summary of the csv file, or an error message.
    """
    try:
        file_url = f'{DEFAULT_API_URL}/files/{task_id}'
        r = requests.get(file_url, timeout=15, allow_redirects=True)
        # Fail early on HTTP errors so an HTML error page isn't fed to
        # the CSV parser.
        r.raise_for_status()
        with open('temp.csv', "wb") as fp:
            fp.write(r.content)
        # Read the CSV file
        df = pd.read_csv('temp.csv')
        # BUG FIX: the message said "Excel file loaded" (copy-paste from
        # excel_read); this is the CSV tool.
        result = (
            f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        )
        result += f"Columns: {', '.join(df.columns)}\n\n"
        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())
        return result
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
224
+
225
+
226
@tool
def mp3_listen(task_id: str) -> str:
    """
    Download the task's mp3 attachment, then transcribe it with AssemblyAI.

    Args:
        task_id: the task id identifying the attachment.

    Returns:
        str: the transcript of the mp3 file.
    """
    file_url = f'{DEFAULT_API_URL}/files/{task_id}'
    r = requests.get(file_url, timeout=15, allow_redirects=True)
    # Fail early on HTTP errors rather than sending an error page to the
    # transcription service.
    r.raise_for_status()
    with open('temp.mp3', "wb") as fp:
        fp.write(r.content)
    loader = AssemblyAIAudioTranscriptLoader(file_path="temp.mp3", api_key=os.getenv("AssemblyAI_API_KEY"))
    docs = loader.load()
    contents = [doc.page_content for doc in docs]
    return "\n".join(contents)
245
+
246
+
247
@tool
def image_caption(dir: str) -> str:
    """
    Understand the content of the provided image

    Args:
        dir: the image url link

    Returns:
        str: the image caption
    """
    # ImageCaptionLoader yields one document per image; we pass exactly one.
    captions = ImageCaptionLoader(images=[dir]).load()
    return captions[0].page_content
261
+
262
+
263
@tool
def run_python(code: str) -> str:
    """Run the given python code in a Python REPL.

    NOTE(security): PythonREPLTool executes arbitrary code with no
    sandboxing — acceptable here only because the code comes from the
    agent's own LLM, not from end users.

    Args:
        code: the python code

    Returns:
        str: the REPL output of executing the code.
    """
    return PythonREPLTool().run(code)
271
+
272
@tool
def multiply(a: float, b: float) -> float:
    """
    Multiply two numbers.

    Args:
        a: first float
        b: second float

    Returns:
        float: the multiplication of a and b
    """
    product = a * b
    return product
285
+
286
@tool
def add(a: float, b: float) -> float:
    """
    Add two numbers.

    Args:
        a: first float
        b: second float

    Returns:
        float: the sum of a and b
    """
    total = a + b
    return total
299
+
300
@tool
def subtract(a: float, b: float) -> float:
    """
    Subtract two numbers.

    Args:
        a: first float
        b: second float

    Returns:
        float: the result after a subtracted by b
    """
    difference = a - b
    return difference
313
+
314
@tool
def divide(a: float, b: float) -> float:
    """Divide two numbers.

    Args:
        a: first float
        b: second float

    Raises:
        ValueError: if b is zero.

    Returns:
        float: the result after a divided by b
    """
    # Guard clause: division by zero raises rather than returning inf/NaN.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")