zeerafle commited on
Commit
16c8f15
·
1 Parent(s): e0922af

Add YouTube and multimodal file input support

Browse files

The additions enable handling videos, images, PDFs and other files as
input, plus improved error handling and helper methods for MIME type
detection and file downloads.

Files changed (2) hide show
  1. agents/base_agent.py +5 -2
  2. app.py +169 -10
agents/base_agent.py CHANGED
@@ -8,20 +8,23 @@ from tools.code_execution import CodeExecutionTool
8
  # from tools.google_search import GoogleSearchTool
9
  from tools.web_search import tavily_search_tool
10
  from tools.wikipedia_search import wikipedia_search_tool
 
11
  import os
12
 
13
- SYSTEM_PROMPT = "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
14
  model = get_default_model()
15
 
16
  code_execution_tool = CodeExecutionTool(api_key=os.environ['GOOGLE_API_KEY'])
17
  # google_search_tool = GoogleSearchTool(api_key=os.environ['GOOGLE_API_KEY'])
 
18
  tools = [
19
  arxiv_tool,
20
  wikipedia_search_tool,
21
  calculator_tool,
22
  code_execution_tool,
23
  # google_search_tool,
24
- tavily_search_tool
 
25
  ]
26
 
27
  agent_executor = create_react_agent(model, tools, prompt=SYSTEM_PROMPT)
 
8
  # from tools.google_search import GoogleSearchTool
9
  from tools.web_search import tavily_search_tool
10
  from tools.wikipedia_search import wikipedia_search_tool
11
+ from tools.youtube_understanding import YoutubeUnderstandingTool
12
  import os
13
 
14
+ SYSTEM_PROMPT = "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with only the answer to the question, nothing else. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
15
  model = get_default_model()
16
 
17
  code_execution_tool = CodeExecutionTool(api_key=os.environ['GOOGLE_API_KEY'])
18
  # google_search_tool = GoogleSearchTool(api_key=os.environ['GOOGLE_API_KEY'])
19
+ youtube_understanding_tool = YoutubeUnderstandingTool(api_key=os.environ['GOOGLE_API_KEY'])
20
  tools = [
21
  arxiv_tool,
22
  wikipedia_search_tool,
23
  calculator_tool,
24
  code_execution_tool,
25
  # google_search_tool,
26
+ tavily_search_tool,
27
+ youtube_understanding_tool
28
  ]
29
 
30
  agent_executor = create_react_agent(model, tools, prompt=SYSTEM_PROMPT)
app.py CHANGED
@@ -1,9 +1,12 @@
1
  import os
 
2
  import gradio as gr
3
  from langchain_core.messages import HumanMessage
4
  import requests
5
  import pandas as pd
6
  from agents.base_agent import agent_executor
 
 
7
 
8
  # (Keep Constants as is)
9
  # --- Constants ---
@@ -12,14 +15,160 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
12
  # --- Basic Agent Definition ---
13
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
14
  class BasicAgent:
15
- def __init__(self):
 
16
  print("BasicAgent initialized.")
17
- def __call__(self, question: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  print(f"Agent received question (first 50 chars): {question[:50]}...")
19
- messages = [HumanMessage(content=question)]
20
- response = agent_executor.invoke({"messages": messages})
21
- answer = response['messages'][-1].content
22
- return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  def run_and_submit_all( profile: gr.OAuthProfile | None):
25
  """
@@ -74,7 +223,8 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
74
  # 3. Run your Agent
75
  results_log = []
76
  answers_payload = []
77
- print(f"Running agent on {len(questions_data)} questions...")
 
78
  for item in questions_data:
79
  task_id = item.get("task_id")
80
  question_text = item.get("question")
@@ -82,12 +232,21 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
82
  print(f"Skipping item with missing task_id or question: {item}")
83
  continue
84
  try:
85
- submitted_answer = agent(question_text)
 
86
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
87
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
 
 
 
 
88
  except Exception as e:
89
  print(f"Error running agent on task {task_id}: {e}")
90
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
 
 
 
 
91
 
92
  if not answers_payload:
93
  print("Agent did not produce any answers to submit.")
 
1
  import os
2
+ from typing import Optional, List, Dict, Any
3
  import gradio as gr
4
  from langchain_core.messages import HumanMessage
5
  import requests
6
  import pandas as pd
7
  from agents.base_agent import agent_executor
8
+ import mimetypes
9
+ import base64
10
 
11
  # (Keep Constants as is)
12
  # --- Constants ---
 
15
  # --- Basic Agent Definition ---
16
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
17
  class BasicAgent:
18
+ def __init__(self, api_url: str = DEFAULT_API_URL):
19
+ self.api_url = api_url
20
  print("BasicAgent initialized.")
21
+
22
+ def _get_mime_type(self, file_content: bytes, filename: str) -> str:
23
+ """Determine MIME type from file content and filename"""
24
+ # Try to guess from filename first
25
+ mime_type, _ = mimetypes.guess_type(filename)
26
+ if mime_type:
27
+ return mime_type
28
+
29
+ # Fallback: check file headers for common types
30
+ if file_content.startswith(b'\xff\xd8\xff'):
31
+ return 'image/jpeg'
32
+ elif file_content.startswith(b'\x89PNG\r\n\x1a\n'):
33
+ return 'image/png'
34
+ elif file_content.startswith(b'GIF8'):
35
+ return 'image/gif'
36
+ elif file_content.startswith(b'%PDF'):
37
+ return 'application/pdf'
38
+ elif file_content.startswith(b'RIFF') and b'WEBP' in file_content[:12]:
39
+ return 'image/webp'
40
+ else:
41
+ return 'application/octet-stream'
42
+
43
+ def _download_file(self, task_id: str) -> Optional[tuple]:
44
+ """Download task's associated file"""
45
+ try:
46
+ files_url = f"{self.api_url}/files/{task_id}"
47
+ print(f"Attempting to download file from {files_url}")
48
+
49
+ response = requests.get(files_url, timeout=30)
50
+ if response.status_code == 404:
51
+ print('File not found for task ID:', task_id)
52
+ return None
53
+
54
+ response.raise_for_status()
55
+
56
+ # try to get filename from Content-Disposition header
57
+ filename = "file"
58
+ if 'content-disposition' in response.headers:
59
+ content_disposition = response.headers['content-disposition']
60
+ if 'filename=' in content_disposition:
61
+ filename = content_disposition.split('filename=')[1].strip('"')
62
+
63
+ file_content = response.content
64
+ mime_type = self._get_mime_type(file_content, filename)
65
+
66
+ print(f"Downloaded file: {filename} ({len(file_content)} bytes, {mime_type})")
67
+ return file_content, filename, mime_type
68
+
69
+ except requests.exceptions.RequestException as e:
70
+ print(f"Error downloading file for task {task_id}: {e}")
71
+ return None
72
+ except Exception as e:
73
+ print(f"Unexpected error downloading file for task {task_id}: {e}")
74
+ return None
75
+
76
+
77
+ def _create_multimodal_content(self, question: str, task_id: str) -> List[Dict[str, Any]]:
78
+ """Create content blocks for multimodal input."""
79
+ content_blocks = [{"type": "text", "text": question}]
80
+
81
+ # Try to download associated file
82
+ file_data = self._download_file(task_id)
83
+ if file_data:
84
+ file_content, filename, mime_type = file_data
85
+
86
+ # Convert file content to base64
87
+ base64_content = base64.b64encode(file_content).decode('utf-8')
88
+
89
+ # Create appropriate content block based on file type
90
+ if mime_type.startswith('image/'):
91
+ content_blocks.append({
92
+ "type": "image",
93
+ "source_type": "base64",
94
+ "data": base64_content,
95
+ "mime_type": mime_type
96
+ })
97
+ print(f"Added image content block: {filename}")
98
+
99
+ elif mime_type == 'application/pdf':
100
+ content_blocks.append({
101
+ "type": "file",
102
+ "source_type": "base64",
103
+ "data": base64_content,
104
+ "mime_type": mime_type
105
+ })
106
+ print(f"Added PDF content block: {filename}")
107
+
108
+ elif mime_type.startswith('audio/'):
109
+ content_blocks.append({
110
+ "type": "audio",
111
+ "source_type": "base64",
112
+ "data": base64_content,
113
+ "mime_type": mime_type
114
+ })
115
+ print(f"Added audio content block: {filename}")
116
+
117
+ elif mime_type.startswith('video/'):
118
+ content_blocks.append({
119
+ "type": "video",
120
+ "source_type": "base64",
121
+ "data": base64_content,
122
+ "mime_type": mime_type
123
+ })
124
+ print(f"Added video content block: {filename}")
125
+
126
+ else:
127
+ # For other file types, add as generic file
128
+ content_blocks.append({
129
+ "type": "file",
130
+ "source_type": "base64",
131
+ "data": base64_content,
132
+ "mime_type": mime_type
133
+ })
134
+ print(f"Added generic file content block: {filename} ({mime_type})")
135
+
136
+ # Add context about the file to the text prompt
137
+ content_blocks[0]["text"] += f"\n\nNote: I have attached a file named '{filename}' of type '{mime_type}'. Please analyze this file in the context of the question above."
138
+
139
+ return content_blocks
140
+
141
+ def __call__(self, question: str, task_id: str = "") -> str:
142
  print(f"Agent received question (first 50 chars): {question[:50]}...")
143
+ if task_id:
144
+ print(f"Processing task_id: {task_id}")
145
+
146
+ try:
147
+ # Create multimodal content if task_id is provided
148
+ if task_id:
149
+ content = self._create_multimodal_content(question, task_id)
150
+ message = HumanMessage(content=content)
151
+ else:
152
+ # Fallback to text-only
153
+ message = HumanMessage(content=question)
154
+
155
+ # Invoke the agent
156
+ response = agent_executor.invoke({"messages": [message]})
157
+ answer = response['messages'][-1].content
158
+
159
+ return answer
160
+
161
+ except Exception as e:
162
+ print(f"Error in agent execution: {e}")
163
+ # Fallback to text-only if multimodal fails
164
+ try:
165
+ message = HumanMessage(content=question)
166
+ response = agent_executor.invoke({"messages": [message]})
167
+ answer = response['messages'][-1].content
168
+ return answer
169
+ except Exception as fallback_error:
170
+ print(f"Fallback also failed: {fallback_error}")
171
+ return f"Error processing question: {e}"
172
 
173
  def run_and_submit_all( profile: gr.OAuthProfile | None):
174
  """
 
223
  # 3. Run your Agent
224
  results_log = []
225
  answers_payload = []
226
+ print(f"Running agent with multimodal support on {len(questions_data)} questions...")
227
+
228
  for item in questions_data:
229
  task_id = item.get("task_id")
230
  question_text = item.get("question")
 
232
  print(f"Skipping item with missing task_id or question: {item}")
233
  continue
234
  try:
235
+ # Pass both question and task_id to enable multimodal processing
236
+ submitted_answer = agent(question_text, task_id)
237
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
238
+ results_log.append({
239
+ "Task ID": task_id,
240
+ "Question": question_text,
241
+ "Submitted Answer": submitted_answer
242
+ })
243
  except Exception as e:
244
  print(f"Error running agent on task {task_id}: {e}")
245
+ results_log.append({
246
+ "Task ID": task_id,
247
+ "Question": question_text,
248
+ "Submitted Answer": f"AGENT ERROR: {e}"
249
+ })
250
 
251
  if not answers_payload:
252
  print("Agent did not produce any answers to submit.")