llamasrock committed on
Commit
09d73b3
·
verified ·
1 Parent(s): b7bb69f

Update app.py

Browse files

Changed to LangGraph agent

Files changed (1) hide show
  1. app.py +193 -38
app.py CHANGED
@@ -3,13 +3,20 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
6
- import os
7
- from smolagents import LiteLLMModel, CodeAgent, GoogleSearchTool
8
- from google import genai
9
- from google.genai import types
10
  import asyncio
 
11
  import requests
12
- from utilities import get_file
 
 
 
 
 
 
 
 
 
 
13
 
14
  # (Keep Constants as is)
15
  # --- Constants ---
@@ -21,50 +28,198 @@ SERPER_API_KEY = os.getenv("SERPER_API_KEY")
21
  # Agent capabilities required: Search the web, listen to audio recordings, watch YouTube videos (process the footage, not the transcript), work with Excel spreadsheets
22
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  class BasicAgent:
26
  def __init__(self):
27
- self.llm_model = LiteLLMModel(
28
- model_id="gemini/gemini-2.5-flash", # you can see other model names here: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. It is important to prefix the name with "gemini/"
29
- api_key=GEMINI_API_KEY,
30
- api_base = 'https://generativelanguage.googleapis.com',
31
- max_tokens=8192
32
  )
33
- # self.google_search_tool = Tool(google_search = GoogleSearch())
34
- self.google_search_tool = GoogleSearchTool()
35
- self.get_file_tool = get_file
36
- self.agent = CodeAgent(model = self.llm_model, tools = [self.google_search_tool, self.get_file_tool])
37
- # # Define Google API client with GoogleSearch tool
38
- # self.client = genai.Client(api_key=GEMINI_API_KEY)
39
 
40
- print("BasicAgent initialized.")
41
-
42
- async def __call__(self, question: str, task_id: str) -> str:
43
- print(f"Agent received question (first 50 chars): {question[:50]}...")
44
- fixed_answer = "This is a default answer."
45
- # print(f"Agent returning fixed answer: {fixed_answer}")
46
- # return fixed_answer
47
 
48
- prompt = f'''You are a general AI assistant. I will ask you a question. Only provide YOUR FINAL ANSWER and nothing else.
49
  YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
50
  If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
51
  If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
52
  If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
53
- {question}'''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- await asyncio.sleep(10)
56
- return self.agent.run(prompt, task_id)
57
- # # Use the Google GenAI client to run the question
58
- # answer = self.client.models.generate_content(
59
- # model='gemini-2.0-flash',
60
- # contents=f'''Answer the following question in the format as requested. If the format not specified, then provide a single word/number/name.
61
- # In your response, do not include anything other than your answer. {question}''',
62
- # config=types.GenerateContentConfig(
63
- # tools=[types.Tool(google_search=types.GoogleSearch()),
64
- # types.Tool(code_execution=self.get_file)]
65
- # )
66
- # )
67
- # return answer.text
68
 
69
 
70
  def run_and_submit_all( profile: gr.OAuthProfile | None):
 
3
  import requests
4
  import inspect
5
  import pandas as pd
 
 
 
 
6
  import asyncio
7
+ from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
8
  import requests
9
+ from typing import IO, Dict
10
+ from io import BytesIO
11
+ from langchain_core.messages import HumanMessage, SystemMessage
12
+ from langgraph.graph import MessagesState
13
+ from langgraph.graph import START, StateGraph
14
+ from langgraph.prebuilt import tools_condition
15
+ from langgraph.prebuilt import ToolNode
16
+ from pytube import YouTube
17
+ import base64
18
+ from google.ai.generativelanguage_v1beta.types import Tool as GenAITool
19
+ from google.ai.generativelanguage_v1beta.types import FileData
20
 
21
  # (Keep Constants as is)
22
  # --- Constants ---
 
28
  # Agent capabilities required: Search the web, listen to audio recordings, watch YouTube videos (process the footage, not the transcript), work with Excel spreadsheets
29
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
30
 
31
def get_file(task_id: str) -> IO:
    '''
    Downloads the file associated with the given task_id, if one exists and is mapped.
    If the question mentions an attachment, use this function.
    Args:
        task_id: Id of the question.
    Returns:
        An in-memory binary stream (BytesIO) with the downloaded file contents.
    Raises:
        requests.HTTPError: if the scoring endpoint responds with an error status.
    '''
    # timeout= guards against a hung download stalling the whole agent run
    file_request = requests.get(
        url=f'https://agents-course-unit4-scoring.hf.space/files/{task_id}',
        timeout=30,
    )
    file_request.raise_for_status()

    return BytesIO(file_request.content)
44
+
45
def analyse_excel(task_id: str) -> Dict[str, float]:
    '''
    Analyzes the Excel file associated with the given task_id and returns the sum of each numeric column.
    Args:
        task_id: Id of the question.
    Returns:
        A dictionary with the sum of each numeric column.
    '''
    # Fetch the attachment and load only the first worksheet.
    workbook = get_file(task_id)
    sheet = pd.read_excel(workbook, sheet_name=0)
    # Restrict to numeric columns before summing so text columns don't break the totals.
    numeric_columns = sheet.select_dtypes(include='number')
    return numeric_columns.sum().to_dict()
57
+
58
def add_numbers(a: float, b: float) -> float:
    '''
    Adds two numbers together.
    Args:
        a: First number.
        b: Second number.
    Returns:
        The sum of the two numbers.
    '''
    total = a + b
    return total
68
+
69
def transcribe_audio(task_id: str) -> HumanMessage:
    '''
    Builds a multimodal message asking the model to transcribe the audio
    attachment of the given task.
    Args:
        task_id: Id of the question whose audio attachment should be transcribed.
    Returns:
        A HumanMessage carrying the transcription instruction and the
        base64-encoded audio as an inline media part.
    Raises:
        ValueError: if no audio file is found for the given task_id.
    '''
    audio_file = get_file(task_id)
    if audio_file is None:
        raise ValueError("No audio file found for the given task_id.")
    # Encode the audio file to base64
    audio_file.seek(0)  # Ensure the file pointer is at the beginning
    encoded_audio = base64.b64encode(audio_file.read()).decode("utf-8")

    # NOTE(review): mime_type is hard-coded — assumes every audio attachment
    # is MP3; confirm against the task files.
    return HumanMessage(
        content=[
            {"type": "text", "text": "Transcribe the audio."},
            {
                "type": "media",
                "data": encoded_audio,  # Use base64 string directly
                "mime_type": "audio/mpeg",
            },
        ]
    )
94
+
95
def python_code(task_id: str) -> str:
    '''
    Returns the Python code associated with the given task_id.
    Args:
        task_id: Id of the question.
    Returns:
        The Python code associated with the question, as text.
    Raises:
        requests.HTTPError: if the scoring endpoint responds with an error status.
    '''
    # timeout= prevents a stuck request from blocking the agent indefinitely
    code_request = requests.get(
        url=f'https://agents-course-unit4-scoring.hf.space/files/{task_id}',
        timeout=30,
    )
    code_request.raise_for_status()

    return code_request.text
107
+
108
def open_image(task_id: str) -> str:
    '''
    Opens an image file associated with the given task_id.
    Args:
        task_id: Id of the question.
    Returns:
        The base64 encoded string of the image file.
    '''
    # Defensive check kept from the original contract.
    image_stream = get_file(task_id)
    if image_stream is None:
        raise ValueError("No image file found for the given task_id.")

    raw_bytes = image_stream.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
121
+
122
def open_youtube_video(url: str) -> HumanMessage:
    '''
    Opens a video file from the given URL.
    Args:
        url: The URL of the video file.
    Returns:
        HumanMessage instructions for the video file.
    '''
    # Wrap the remote URL so the model can fetch the footage itself.
    footage = FileData(url=url)

    instruction_part = {"type": "text", "text": "Watch the video and answer the question."}
    media_part = {
        "type": "media",
        "data": footage,
        "mime_type": "video/mp4",
    }
    return HumanMessage(content=[instruction_part, media_part])
142
+
143
def google_search(query: str) -> str:
    '''
    Performs a Google search for the given query.
    Args:
        query: The search query.
    Returns:
        The search results as a string.
    '''
    # Dedicated model instance with the Google Search grounding tool enabled.
    search_model = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash-preview-04-17",
        max_tokens=8192,
        temperature=0,
    )
    result = search_model.invoke(
        query,
        tools=[GenAITool(google_search={})],
    )
    return result.content
161
+
162
 
163
class BasicAgent:
    '''
    ReAct-style LangGraph agent: a Gemini assistant node bound to a set of
    tools, with a conditional edge that loops tool calls back into the
    assistant until a final answer is produced.
    '''

    def __init__(self):
        # Tool-calling chat model; temperature 0 for deterministic answers.
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-preview-04-17",
            max_tokens=8192,
            temperature=0
        )
        self.tools = [get_file, analyse_excel, add_numbers, transcribe_audio,
                      python_code, open_image, open_youtube_video, google_search]

        self.agent = self.llm.bind_tools(self.tools)

        # FIX: the prompt previously told the model to use an 'open_video'
        # tool, but the registered tool is named 'open_youtube_video'.
        self.sys_msg = SystemMessage('''You are a general AI assistant. I will ask you a question. Only provide YOUR FINAL ANSWER and nothing else.
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
You have access to multiple tools and should use as many as you need to answer the question.
If you are asked to analyze an Excel file, use the 'analyse_excel' tool.
If you are asked to download a file, use the 'get_file' tool.
If you are asked to add two numbers, use the 'add_numbers' tool. If you need to add more than two numbers, use the 'add_numbers'
tool multiple times.
If you are asked to transcribe an audio file, use the 'transcribe_audio' tool.
If you are asked to run a Python code, use the 'python_code' tool.
If you are asked to open an image, use the 'open_image' tool.
If you are asked to open a YouTube video, use the 'open_youtube_video' tool.
If the question requires a web search because your internal knowledge doesn't have the information, use the 'google_search' tool.
''')

        # Graph
        self.builder = StateGraph(MessagesState)

        # Define nodes: these do the work
        self.builder.add_node("assistant", self.assistant)
        self.builder.add_node("tools", ToolNode(self.tools))

        # Define edges: these determine how the control flow moves
        self.builder.add_edge(START, "assistant")
        self.builder.add_conditional_edges(
            "assistant",
            # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
            # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END
            tools_condition,
        )
        self.builder.add_edge("tools", "assistant")
        self.react_graph = self.builder.compile()

        print("BasicAgent initialized.")

    def assistant(self, state: MessagesState):
        '''Assistant node: run the tool-bound model on the system prompt plus conversation so far.'''
        return {"messages": [self.agent.invoke([self.sys_msg] + state["messages"])]}

    async def __call__(self, question: str, task_id: str) -> str:
        '''
        Answer one question via the compiled graph.
        Args:
            question: The question text.
            task_id: Id of the question (used by tools to fetch attachments).
        Returns:
            The final answer text, or a default string if the graph produced no messages.
        '''
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        fixed_answer = "This is a default answer."

        # Brief pause between questions — presumably to respect API rate limits; confirm.
        await asyncio.sleep(4)
        messages = self.react_graph.invoke({"messages": f'Task id: {task_id}\n {question}'})
        return messages["messages"][-1].content if messages["messages"] else fixed_answer
 
223
 
224
 
225
  def run_and_submit_all( profile: gr.OAuthProfile | None):