kokluch committed on
Commit
201adcb
·
1 Parent(s): 29b3ab5

Add Tools

Browse files
Files changed (4) hide show
  1. agent.py +281 -0
  2. app.py +31 -9
  3. requirements.txt +19 -2
  4. wikipedia_tool.py +52 -0
agent.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+ import re
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import TypedDict, Annotated, Optional
7
+ import pandas as pd
8
+ import requests
9
+ from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
10
+ from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
11
+ from langchain_core.tools import tool
12
+ from langchain_google_genai import ChatGoogleGenerativeAI
13
+ from langchain_tavily import TavilySearch
14
+ from langgraph.graph import START, StateGraph
15
+ from langgraph.graph.message import add_messages
16
+ from langgraph.prebuilt import ToolNode
17
+ from langgraph.prebuilt import tools_condition
18
+ from mediawikiapi import MediaWikiAPI
19
+ from transformers import pipeline
20
+ from wikipedia_tool import WikipediaTool
21
+
22
@tool
def read_xlsx_file(file_path: str) -> str:
    """
    Read an XLSX file using pandas and return its content.

    Args:
        file_path: Path to the XLSX file

    Returns:
        Content of the XLSX file rendered as a markdown table, or an error message
    """
    try:
        # Load the spreadsheet into a DataFrame (uses openpyxl under the hood).
        df = pd.read_excel(file_path)
        return df.to_markdown()

    except ImportError:
        return "Error: pandas is not installed. Please install it with 'pip install pandas'."
    except Exception as e:
        # Fixed: previous message said "CSV file" although this reads XLSX.
        return f"Error analyzing XLSX file: {str(e)}"
42
+
43
@tool
def addition(a: int, b: int) -> int:
    """
    Add two integers.

    Args:
        a: first operand
        b: second operand

    Returns:
        The sum a + b
    """
    return a + b
56
+
57
@tool
def multiple(a: int, b: int) -> int:
    """
    Multiply two integers.

    Args:
        a: first operand
        b: second operand

    Returns:
        The product a * b
    """
    # Fixed: docstring claimed "float numbers" and the return annotation was
    # `float`, but int * int yields int.
    return a * b
70
+
71
class Agent:
    """LangGraph agent for GAIA-style benchmark tasks.

    Builds a small state graph: optionally download the task's attached file,
    convert it into a message the LLM can consume (image / xlsx / audio / py),
    then loop between the Gemini assistant and its tools until an answer is
    produced. `run()` extracts the text after "FINAL ANSWER:".
    """

    def __init__(self):
        # NOTE(review): assumes GOOGLE_API_KEY (and TAVILY_API_KEY for the
        # search tool) are present in the environment — confirm before use.
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-preview-04-17",
            # model="gemini-2.0-flash",
            # model="gemini-1.5-pro",
            temperature=0
        )

        self.tools = [
            WikipediaTool(api_wrapper=WikipediaAPIWrapper(wiki_client=MediaWikiAPI())),
            TavilySearch(),
            read_xlsx_file,
            addition,
            multiple
        ]

        self.llm_with_tools = llm.bind_tools(self.tools)

        self.graph = self.build_graph()

    def build_graph(self):
        """Assemble and compile the LangGraph state machine."""

        class AgentState(TypedDict):
            # Conversation history; the add_messages reducer appends/merges by id.
            messages: Annotated[list[AnyMessage], add_messages]
            # Task identifier used to fetch the attachment from the scoring API.
            task_id: str
            # Name of the attached file, or None when the task has none.
            file_name: Optional[str]

        def assistant(state: AgentState):
            """Run the tool-enabled LLM on the current message history."""
            try:
                messages = state.get("messages")

                # Invoke the LLM with tools
                response = self.llm_with_tools.invoke(messages)

                # Ensure we return the response in the correct format
                return {
                    "messages": [response]
                }
            except Exception as e:
                # Surface the failure as an AI message instead of crashing the graph.
                error_msg = AIMessage(content=f"Sorry, I encountered an error: {str(e)}")
                return {
                    "messages": [error_msg]
                }

        def download_file_if_any(state: AgentState) -> str:
            """Route to the download node only when the task has an attachment."""
            if state.get("file_name"):
                return "download_file"
            else:
                return "assistant"

        def download_file(state: AgentState):
            """Download the task's attachment into the system temp directory."""
            filename = state.get("file_name")
            task_id = state.get("task_id")
            url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"

            try:
                # Send a GET request to the URL
                response = requests.get(url, stream=True)
                # Ensure the request was successful
                response.raise_for_status()

                # Save under the temp dir using only the attachment's basename.
                temp_dir = tempfile.gettempdir()
                temp_file_path = os.path.join(temp_dir, os.path.basename(filename))

                # Stream the payload to disk in chunks to bound memory usage.
                with open(temp_file_path, 'wb') as file:
                    for chunk in response.iter_content(chunk_size=8192):
                        file.write(chunk)

                return {}

            except requests.exceptions.RequestException as e:
                # NOTE(review): on failure the graph still proceeds to
                # file_condition, which will then try to open a missing file —
                # consider routing straight to the assistant instead.
                error_msg = AIMessage(content=f"Sorry, I encountered an error: {str(e)}")
                return {
                    "messages": [error_msg]
                }

        def file_condition(state: AgentState) -> str:
            """Pick the preprocessing node matching the attachment's extension."""
            filename = state.get("file_name")
            suffix = Path(filename).suffix
            if suffix in [".png", ".jpeg"]:
                return "add_image_message"
            elif suffix in [".xlsx"]:
                return "add_xlsx_message"
            elif suffix in [".mp3"]:
                return "add_audio_message"
            elif suffix in [".py"]:
                return "add_py_message"
            else:
                return "assistant"

        def add_image_message(state: AgentState):
            """Attach the downloaded image to the conversation as a data URL."""
            filename = state.get("file_name")
            temp_dir = tempfile.gettempdir()
            image_path = os.path.join(temp_dir, os.path.basename(filename))
            # Load the image and convert it to base64
            with open(image_path, "rb") as img_file:
                base64_image = base64.b64encode(img_file.read()).decode("utf-8")

            # Fixed: advertise the correct MIME type (previously always
            # image/jpeg, even for .png files).
            mime_type = "image/png" if Path(filename).suffix == ".png" else "image/jpeg"

            # Construct the image message
            image_message = HumanMessage(content=[{
                "type": "image_url",
                "image_url": {
                    "url": f"data:{mime_type};base64,{base64_image}"
                }
            }])

            return {"messages": state.get("messages") + [image_message]}

        def add_xlsx_message(state: AgentState):
            """Tell the LLM where the spreadsheet lives so it can call read_xlsx_file."""
            filename = state.get("file_name")
            temp_dir = tempfile.gettempdir()
            xlsx_path = os.path.join(temp_dir, os.path.basename(filename))

            # Construct the message
            xlsx_message = HumanMessage(content=f"xlsx file is at {xlsx_path}")

            return {"messages": state.get("messages") + [xlsx_message]}

        def add_audio_message(state: AgentState):
            """Transcribe the downloaded audio with Whisper and add it as text."""
            filename = state.get("file_name")
            temp_dir = tempfile.gettempdir()
            audio_path = os.path.join(temp_dir, os.path.basename(filename))

            # NOTE(review): the ASR model is re-loaded on every call; cache the
            # pipeline at Agent level if audio tasks are frequent.
            pipe = pipeline(
                task="automatic-speech-recognition",
                model="openai/whisper-large-v3"
            )

            result = pipe(audio_path)

            audio_message = HumanMessage(result["text"])

            return {"messages": state.get("messages") + [audio_message]}

        def add_py_message(state: AgentState):
            """Inline the downloaded Python source into the conversation."""
            filename = state.get("file_name")
            temp_dir = tempfile.gettempdir()
            file_path = os.path.join(temp_dir, os.path.basename(filename))

            with open(file_path, 'r') as file:
                content = file.read()

            py_message = HumanMessage(content=[{
                "type": "text",
                "text": content
            }])
            return {"messages": state.get("messages") + [py_message]}

        ## The graph
        builder = StateGraph(AgentState)

        # Define nodes: these do the work
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_node("download_file", download_file)
        builder.add_node("add_image_message", add_image_message)
        builder.add_node("add_xlsx_message", add_xlsx_message)
        builder.add_node("add_py_message", add_py_message)
        builder.add_node("add_audio_message", add_audio_message)

        # Define edges: these determine how the control flow moves
        builder.add_conditional_edges(
            START,
            download_file_if_any
        )
        builder.add_conditional_edges(
            "download_file",
            file_condition
        )
        builder.add_edge("add_image_message", "assistant")
        builder.add_edge("add_xlsx_message", "assistant")
        builder.add_edge("add_py_message", "assistant")
        builder.add_edge("add_audio_message", "assistant")
        builder.add_conditional_edges(
            "assistant",
            # If the latest message requires a tool, route to tools
            # Otherwise, provide a direct response
            tools_condition
        )
        builder.add_edge("tools", "assistant")
        return builder.compile()

    def run(self, question: str, task_id: str, file_name: str | None):
        """Answer *question*, returning the text after 'FINAL ANSWER:' if present.

        Args:
            question: the task's natural-language question
            task_id: identifier used to fetch any attached file
            file_name: name of the attachment, or None

        Returns:
            The extracted final answer, or the raw last-message content when no
            'FINAL ANSWER:' marker is found.
        """
        system_prompt = SystemMessage(content="You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, use digit not letter, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.")

        messages = [system_prompt, HumanMessage(content=question)]

        response = self.graph.invoke({"messages": messages, "task_id": task_id, "file_name": file_name}, debug=True)

        answer = response['messages'][-1].content

        for m in response['messages']:
            m.pretty_print()

        # Regex to capture text after "FINAL ANSWER: "
        match = re.search(r'FINAL ANSWER:\s*(.*)', answer)

        if match:
            final_answer = match.group(1)
            print(final_answer)
            return final_answer

        return answer
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import os
 
2
  import gradio as gr
3
  import requests
4
- import inspect
5
  import pandas as pd
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
@@ -12,12 +13,29 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
13
  class BasicAgent:
14
  def __init__(self):
15
- print("BasicAgent initialized.")
16
- def __call__(self, question: str) -> str:
 
 
 
17
  print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
@@ -44,7 +62,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
44
  except Exception as e:
45
  print(f"Error instantiating agent: {e}")
46
  return f"Error initializing agent: {e}", None
47
- # In the case of an app running as a hugging Face space, this link points toward your codebase ( usefull for others so please keep it public)
48
  agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
49
  print(agent_code)
50
 
@@ -54,6 +72,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
54
  response = requests.get(questions_url, timeout=15)
55
  response.raise_for_status()
56
  questions_data = response.json()
 
57
  if not questions_data:
58
  print("Fetched questions list is empty.")
59
  return "Fetched questions list is empty or invalid format.", None
@@ -76,11 +95,12 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
76
  for item in questions_data:
77
  task_id = item.get("task_id")
78
  question_text = item.get("question")
 
79
  if not task_id or question_text is None:
80
  print(f"Skipping item with missing task_id or question: {item}")
81
  continue
82
  try:
83
- submitted_answer = agent(question_text)
84
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
85
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
86
  except Exception as e:
@@ -171,6 +191,8 @@ with gr.Blocks() as demo:
171
  outputs=[status_output, results_table]
172
  )
173
 
 
 
174
  if __name__ == "__main__":
175
  print("\n" + "-"*30 + " App Starting " + "-"*30)
176
  # Check for SPACE_HOST and SPACE_ID at startup for information
@@ -193,4 +215,4 @@ if __name__ == "__main__":
193
  print("-"*(60 + len(" App Starting ")) + "\n")
194
 
195
  print("Launching Gradio Interface for Basic Agent Evaluation...")
196
- demo.launch(debug=True, share=False)
 
1
  import os
2
+ import time
3
  import gradio as gr
4
  import requests
 
5
  import pandas as pd
6
+ from agent import Agent
7
 
8
  # (Keep Constants as is)
9
  # --- Constants ---
 
13
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
14
class BasicAgent:
    """Thin wrapper around Agent that retries transient failures (e.g. API rate limits)."""

    def __init__(self):
        # Initialize the underlying LangGraph agent.
        self.agent = Agent()

        print("Agent initialized successfully")

    def __call__(self, question: str, task_id: str, file_name: str | None = None) -> str:
        """Run the agent on one task, retrying on any exception.

        Args:
            question: the task's question text
            task_id: identifier used to fetch any attached file
            file_name: name of the attachment, or None

        Returns:
            The agent's final answer, or an error string after exhausting retries.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        max_retries = 10
        base_sleep = 60

        for attempt in range(max_retries):
            try:
                final_answer = self.agent.run(question=question, task_id=task_id, file_name=file_name)
                print(f"Agent returning final answer: {final_answer}")
                return final_answer
            except Exception as e:
                print(f"{str(e)}")
                # Fixed comment: linear backoff of 60s, 120s, 180s, ...
                # (the old comment wrongly claimed 1s, 2s, 3s).
                sleep_time = base_sleep * (attempt + 1)
                if attempt < max_retries - 1:
                    print(f"Attempt {attempt + 1} failed. Retrying in {sleep_time} seconds...")
                    time.sleep(sleep_time)
                    continue
                return f"Error processing query after {max_retries} attempts: {str(e)}"
39
 
40
  def run_and_submit_all( profile: gr.OAuthProfile | None):
41
  """
 
62
  except Exception as e:
63
  print(f"Error instantiating agent: {e}")
64
  return f"Error initializing agent: {e}", None
65
+ # In the case of an app running as a hugging Face space, this link points toward your codebase (useful for others so please keep it public)
66
  agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
67
  print(agent_code)
68
 
 
72
  response = requests.get(questions_url, timeout=15)
73
  response.raise_for_status()
74
  questions_data = response.json()
75
+ print(f"{questions_data}")
76
  if not questions_data:
77
  print("Fetched questions list is empty.")
78
  return "Fetched questions list is empty or invalid format.", None
 
95
  for item in questions_data:
96
  task_id = item.get("task_id")
97
  question_text = item.get("question")
98
+ file_name = item.get("file_name")
99
  if not task_id or question_text is None:
100
  print(f"Skipping item with missing task_id or question: {item}")
101
  continue
102
  try:
103
+ submitted_answer = agent(question=question_text, task_id=task_id, file_name=file_name)
104
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
105
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
106
  except Exception as e:
 
191
  outputs=[status_output, results_table]
192
  )
193
 
194
+
195
+
196
  if __name__ == "__main__":
197
  print("\n" + "-"*30 + " App Starting " + "-"*30)
198
  # Check for SPACE_HOST and SPACE_ID at startup for information
 
215
  print("-"*(60 + len(" App Starting ")) + "\n")
216
 
217
  print("Launching Gradio Interface for Basic Agent Evaluation...")
218
+ demo.launch(debug=True, share=False)
requirements.txt CHANGED
@@ -1,2 +1,19 @@
1
- gradio
2
- requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio~=5.33.2
2
+ requests~=2.32.4
3
+ itsdangerous
4
+ langchain~=0.3.24
5
+ langgraph~=0.3.34
6
+ pandas~=2.2.3
7
+ langchain-core~=0.3.56
8
+ langchain-google-genai~=2.1.3
9
+ langchain-community~=0.3.22
10
+ langchain-tavily
11
+ mediawikiapi~=1.3
12
+ wikipedia
13
+ pydantic~=2.11.3
14
+ beautifulsoup4~=4.13.4
15
+ openpyxl
16
+ protobuf~=5.29.4
17
+ genai~=2.1.0
18
+ transformers~=4.52.4
19
+ torch
wikipedia_tool.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tool for the Wikipedia API."""
2
+
3
+ from typing import Optional, Type
4
+ from langchain_core.callbacks import CallbackManagerForToolRun
5
+ from langchain_core.tools import BaseTool
6
+ from pydantic import BaseModel, Field
7
+ import pandas as pd
8
+ from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
9
+
10
class WikipediaQueryInput(BaseModel):
    """Input for the WikipediaQuery tool."""

    # Free-text search string forwarded to the Wikipedia API.
    query: str = Field(description="query to look up on wikipedia")
14
+
15
class WikipediaTool(BaseTool):  # type: ignore[override]
    """Tool that searches the Wikipedia API and inlines any wikitable data."""

    name: str = "wikipedia"
    description: str = (
        "A wrapper around Wikipedia. "
        "Useful for when you need to answer general questions about "
        "people, places, companies, facts, historical events, or other subjects. "
        "Input should be a search query."
    )
    # Wrapper providing the underlying Wikipedia client (injected by the caller).
    api_wrapper: WikipediaAPIWrapper

    args_schema: Type[BaseModel] = WikipediaQueryInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the Wikipedia tool.

        Loads the matching pages, enriches each page's metadata with any
        class="wikitable" tables parsed from the source article (tabular data
        is otherwise lost in plain-text extraction), and returns all pages
        joined by a separator.
        """
        pages = self.api_wrapper.load(query)

        for page in pages:
            try:
                wikitables = pd.read_html(page.metadata["source"], attrs={"class": "wikitable"})
                page.metadata["wikitable"] = "\n---\n".join(
                    f'{table}'
                    for table in wikitables
                )
            # Fixed: was a bare `except:`, which also swallows
            # KeyboardInterrupt/SystemExit. Pages without parseable
            # wikitables are kept as-is (best-effort enrichment).
            except Exception:
                continue

        res = "\n---\n".join(
            f'{page}'
            for page in pages
        )

        return res