giulia-fontanella committed on
Commit
a741f9e
·
unverified ·
1 Parent(s): 7a40d3a

add presentation notebook

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. notebooks/presentation.ipynb +218 -0
  3. src/agent.py +36 -13
  4. src/tools.py +49 -5
.gitignore CHANGED
@@ -135,3 +135,5 @@ cython_debug/
135
 
136
  # configurations for VS Code
137
  .vscode
 
 
 
135
 
136
  # configurations for VS Code
137
  .vscode
138
+
139
+ data/
notebooks/presentation.ipynb ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "64b2c237",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "from langchain_openai import ChatOpenAI\n",
12
+ "from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace\n",
13
+ "\n",
14
+ "import sys\n",
15
+ "\n",
16
+ "sys.path.append(os.path.abspath(\"../src\"))\n",
17
+ "\n",
18
+ "from agent import SmartAgent"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "id": "8a1ece26",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "HUGGINGFACEHUB_API_TOKEN = os.getenv(\"HUGGINGFACEHUB_API_TOKEN\")\n",
29
+ "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
30
+ "TAVILY_API_KEY = os.getenv(\"TAVILY_API_KEY\")"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 3,
36
+ "id": "1d5bd941",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "MODEL_ID = \"gpt-4o\"\n",
41
+ "PROVIDER_TYPE = \"openai\" # \"openai\" or \"huggingface\""
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 4,
47
+ "id": "711e347b",
48
+ "metadata": {},
49
+ "outputs": [
50
+ {
51
+ "name": "stdout",
52
+ "output_type": "stream",
53
+ "text": [
54
+ "Agent initialized.\n",
55
+ "Telemetry initialized.\n"
56
+ ]
57
+ }
58
+ ],
59
+ "source": [
60
+ "# Instantiate Agent\n",
61
+ "try:\n",
62
+ " if PROVIDER_TYPE == \"huggingface\":\n",
63
+ " llm = HuggingFaceEndpoint(\n",
64
+ " repo_id=MODEL_ID,\n",
65
+ " huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,\n",
66
+ " )\n",
67
+ " chat = ChatHuggingFace(llm=llm, verbose=True)\n",
68
+ " elif PROVIDER_TYPE == \"openai\":\n",
69
+ " chat = ChatOpenAI(model=MODEL_ID, temperature=0.2)\n",
70
+ " else:\n",
71
+ " print(f\"Provider {PROVIDER_TYPE} not supported.\")\n",
72
+ "\n",
73
+ " agent = SmartAgent(chat)\n",
74
+ "\n",
75
+ "except Exception as e:\n",
76
+ " print(f\"Error instantiating agent: {e}\")"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 5,
82
+ "id": "57656b10",
83
+ "metadata": {},
84
+ "outputs": [
85
+ {
86
+ "name": "stdout",
87
+ "output_type": "stream",
88
+ "text": [
89
+ "Agent received question: The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places..\n",
90
+ "Provided file: ../data/sales.xlsx.\n",
91
+ "Agent returning answer: 29623.00\n"
92
+ ]
93
+ }
94
+ ],
95
+ "source": [
96
+ "# Run Agent\n",
97
+ "\n",
98
+ "question = \"The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.\"\n",
99
+ "filename = \"../data/sales.xlsx\"\n",
100
+ "answer = agent(question, filename)"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": 6,
106
+ "id": "0f82fc20",
107
+ "metadata": {},
108
+ "outputs": [
109
+ {
110
+ "name": "stdout",
111
+ "output_type": "stream",
112
+ "text": [
113
+ "Agent received question: Here is a scanned invoice. Please extract the vendor name, invoice number, and total amount..\n",
114
+ "Provided file: ../data/invoice.png.\n",
115
+ "Agent returning answer: Adeline Palmerston, 01234, 440\n"
116
+ ]
117
+ }
118
+ ],
119
+ "source": [
120
+ "# Run Agent\n",
121
+ "\n",
122
+ "question = \"Here is a scanned invoice. Please extract the vendor name, invoice number, and total amount.\"\n",
123
+ "filename = \"../data/invoice.png\"\n",
124
+ "answer = agent(question, filename)"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": 11,
130
+ "id": "833cb42b",
131
+ "metadata": {},
132
+ "outputs": [
133
+ {
134
+ "name": "stdout",
135
+ "output_type": "stream",
136
+ "text": [
137
+ "Agent received question: Explain in detail what the conclusion of the paper Attention is all you need.\n",
138
+ "Agent returning answer: The conclusion of the paper \"Attention is All You Need\" is that the Transformer model, which relies entirely on attention mechanisms and does not use recurrence or convolution, achieves state-of-the-art results on translation tasks and is highly efficient in terms of parallelization, making it suitable for training on large datasets. The paper demonstrates that attention mechanisms alone are sufficient for achieving high performance in sequence transduction tasks.\n"
139
+ ]
140
+ }
141
+ ],
142
+ "source": [
143
+ "# Run Agent\n",
144
+ "\n",
145
+ "question = (\n",
146
+ " \"Explain in detail what the conclusion of the paper Attention is all you need\"\n",
147
+ ")\n",
148
+ "answer = agent(question)"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": 8,
154
+ "id": "9077c2b0",
155
+ "metadata": {},
156
+ "outputs": [
157
+ {
158
+ "name": "stdout",
159
+ "output_type": "stream",
160
+ "text": [
161
+ "Agent received question: Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec. What does Teal say in response to the question Isn't that hot?.\n",
162
+ " \r"
163
+ ]
164
+ },
165
+ {
166
+ "name": "stderr",
167
+ "output_type": "stream",
168
+ "text": [
169
+ "/home/giulia/Progetti/Agent_Course_Final_Assignment/.venv/lib/python3.10/site-packages/whisper/transcribe.py:126: UserWarning: FP16 is not supported on CPU; using FP32 instead\n",
170
+ " warnings.warn(\"FP16 is not supported on CPU; using FP32 instead\")\n"
171
+ ]
172
+ },
173
+ {
174
+ "name": "stdout",
175
+ "output_type": "stream",
176
+ "text": [
177
+ "Agent returning answer: Extremely\n"
178
+ ]
179
+ }
180
+ ],
181
+ "source": [
182
+ "# Run Agent\n",
183
+ "\n",
184
+ "question = \"Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec. What does Teal say in response to the question Isn't that hot?\"\n",
185
+ "answer = agent(question)"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "id": "7ee42f12",
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": []
195
+ }
196
+ ],
197
+ "metadata": {
198
+ "kernelspec": {
199
+ "display_name": ".venv",
200
+ "language": "python",
201
+ "name": "python3"
202
+ },
203
+ "language_info": {
204
+ "codemirror_mode": {
205
+ "name": "ipython",
206
+ "version": 3
207
+ },
208
+ "file_extension": ".py",
209
+ "mimetype": "text/x-python",
210
+ "name": "python",
211
+ "nbconvert_exporter": "python",
212
+ "pygments_lexer": "ipython3",
213
+ "version": "3.10.12"
214
+ }
215
+ },
216
+ "nbformat": 4,
217
+ "nbformat_minor": 5
218
+ }
src/agent.py CHANGED
@@ -10,18 +10,38 @@ from langgraph.graph import START, StateGraph
10
  from langgraph.graph.message import add_messages
11
  from langgraph.prebuilt import ToolNode, tools_condition
12
 
13
- from .tools import (
14
- DescribeImage,
15
- ExtractTextFromImage,
16
- arxiv_search,
17
- download_youtube_video,
18
- extract_audio_from_video,
19
- read_excel,
20
- read_python,
21
- transcribe_audio,
22
- web_search,
23
- wiki_search,
24
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  class AgentState(TypedDict):
@@ -53,6 +73,9 @@ class SmartAgent:
53
  arxiv_search,
54
  download_youtube_video,
55
  extract_audio_from_video,
 
 
 
56
  ]
57
  self.chat_with_tools = chat.bind_tools(self.tools)
58
  self._initialize_graph()
@@ -91,7 +114,7 @@ class SmartAgent:
91
  self.langfuse_handler = CallbackHandler()
92
  print("Telemetry initialized.")
93
 
94
- def __call__(self, question: str, file_name: str) -> str:
95
  """Call the agent, passing system prompt and eventual file name."""
96
  sys_msg = SystemMessage(
97
  content="""You are a general AI assistant. You will be asked a factual question.
 
10
  from langgraph.graph.message import add_messages
11
  from langgraph.prebuilt import ToolNode, tools_condition
12
 
13
# Dual-path tool import: the absolute form works when src/ itself is on
# sys.path (the presentation notebook does sys.path.append("../src")); the
# relative form works when agent.py is loaded as part of the package.
# NOTE: narrowed from a bare `except:` — a bare except also swallows
# KeyboardInterrupt/SystemExit and any real error raised *inside* tools.py,
# masking genuine bugs behind a misleading secondary ImportError.
try:
    from tools import (
        DescribeImage,
        ExtractTextFromImage,
        add,
        arxiv_search,
        divide,
        download_youtube_video,
        extract_audio_from_video,
        multiply,
        read_excel,
        read_python,
        transcribe_audio,
        web_search,
        wiki_search,
    )
except ImportError:
    from .tools import (
        DescribeImage,
        ExtractTextFromImage,
        add,
        arxiv_search,
        divide,
        download_youtube_video,
        extract_audio_from_video,
        multiply,
        read_excel,
        read_python,
        transcribe_audio,
        web_search,
        wiki_search,
    )
45
 
46
 
47
  class AgentState(TypedDict):
 
73
  arxiv_search,
74
  download_youtube_video,
75
  extract_audio_from_video,
76
+ add,
77
+ divide,
78
+ multiply,
79
  ]
80
  self.chat_with_tools = chat.bind_tools(self.tools)
81
  self._initialize_graph()
 
114
  self.langfuse_handler = CallbackHandler()
115
  print("Telemetry initialized.")
116
 
117
+ def __call__(self, question: str, file_name: str | None = None) -> str:
118
  """Call the agent, passing system prompt and eventual file name."""
119
  sys_msg = SystemMessage(
120
  content="""You are a general AI assistant. You will be asked a factual question.
src/tools.py CHANGED
@@ -8,6 +8,11 @@ from langchain.tools import tool
8
  from langchain.tools.tavily_search import TavilySearchResults
9
  from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
10
  from langchain_core.messages import HumanMessage
 
 
 
 
 
11
 
12
 
13
  @tool
@@ -204,13 +209,14 @@ def download_youtube_video(youtube_url: str, output_path: str) -> str:
204
  Path to the saved video file.
205
  """
206
  ydl_opts = {
207
- "format": "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
208
  "outtmpl": output_path,
209
  "merge_output_format": "mp4",
210
  "quiet": True,
211
  }
212
- with yt_dlp.YoutubeDL(ydl_opts) as ydl:
213
- ydl.download([youtube_url])
 
214
  return output_path
215
 
216
 
@@ -277,10 +283,11 @@ def web_search(query: str) -> str:
277
 
278
  @tool
279
  def arxiv_search(query: str) -> str:
280
- """Search Arxiv for a query and return maximum 2 result.
281
 
282
  Args:
283
- query: The search query.
 
284
  """
285
  search_docs = ArxivLoader(query=query, load_max_docs=2).load()
286
  formatted_search_docs = "\n\n---\n\n".join(
@@ -297,3 +304,40 @@ def arxiv_search(query: str) -> str:
297
  ]
298
  )
299
  return {"arvix_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from langchain.tools.tavily_search import TavilySearchResults
9
  from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
10
  from langchain_core.messages import HumanMessage
11
+ from typing import List
12
+ from functools import reduce
13
+ import operator
14
+ import contextlib
15
+ import os
16
 
17
 
18
  @tool
 
209
  Path to the saved video file.
210
  """
211
  ydl_opts = {
212
+ "format": "bestvideo+bestaudio/best",
213
  "outtmpl": output_path,
214
  "merge_output_format": "mp4",
215
  "quiet": True,
216
  }
217
+ with contextlib.redirect_stderr(open(os.devnull, "w")):
218
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
219
+ ydl.download([youtube_url])
220
  return output_path
221
 
222
 
 
283
 
284
  @tool
285
  def arxiv_search(query: str) -> str:
286
+ """Search Arxiv for a paper.
287
 
288
  Args:
289
+ query: The search query to retrieve a specific paper, consisting
290
+ of title and/or authors name and/or year of publication.
291
  """
292
  search_docs = ArxivLoader(query=query, load_max_docs=2).load()
293
  formatted_search_docs = "\n\n---\n\n".join(
 
304
  ]
305
  )
306
  return {"arvix_results": formatted_search_docs}
307
+
308
+
309
@tool
def add(numbers: List[float]) -> float:
    """Compute the sum of a list of numbers.

    Args:
        numbers: Numeric values to be added together.

    Returns:
        The total of all values in the list (0 for an empty list).
    """
    # Left fold with a 0 start — identical semantics to sum(numbers).
    return reduce(operator.add, numbers, 0)
320
+
321
+
322
@tool
def multiply(numbers: List[float]) -> float:
    """Compute the product of a list of numbers.

    Args:
        numbers: Numeric values to be multiplied together.

    Returns:
        The product of all values in the list (1.0 for an empty list).
    """
    # Explicit accumulator loop — same left fold from 1.0 as
    # reduce(operator.mul, numbers, 1.0).
    product = 1.0
    for value in numbers:
        product *= value
    return product
333
+
334
+
335
@tool
def divide(a: float, b: float) -> float:
    """Divide one number by another.

    Annotations widened from ``int`` to ``float`` for consistency with the
    sibling ``add``/``multiply`` tools (which take ``List[float]``) and so the
    LLM's tool schema accepts non-integer operands; behavior is unchanged.

    Args:
        a: The dividend (numerator).
        b: The divisor (denominator).

    Returns:
        The quotient ``a / b`` as a float.

    Raises:
        ZeroDivisionError: If ``b`` is 0.
    """
    return a / b