Leonardo commited on
Commit
8df2ba2
·
verified ·
1 Parent(s): cb66cbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -135
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import json
2
  import contextlib
3
- import mimetypes # Added missing import
4
  import os
5
- import re # Added missing import
6
- import shutil # Added missing import
7
  from typing import Dict, List, Optional, Any
8
 
9
  from dotenv import load_dotenv
@@ -22,106 +22,18 @@ from scripts.text_web_browser import (
22
  )
23
  from scripts.visual_qa import visualizer
24
 
25
- # from scripts.flux_lora_tool import FluxLoRATool
26
-
27
  from smolagents import (
28
  CodeAgent,
29
  HfApiModel,
30
  LiteLLMModel,
31
- OpenAIServerModel, # Added missing model
32
- TransformersModel, # Added missing model
33
  Tool,
34
  )
35
  from smolagents.agent_types import AgentText, AgentImage, AgentAudio
36
  from smolagents.gradio_ui import pull_messages_from_step, handle_agent_output_types
37
 
38
-
39
- class GoogleSearchTool(Tool):
40
- """Performs Google web searches using the Serper API."""
41
-
42
- name = "web_search"
43
- description = """Performs a google web search for your query then returns a string of the top search results."""
44
- inputs = {
45
- "query": {"type": "string", "description": "The search query to perform."},
46
- "filter_year": {
47
- "type": "integer",
48
- "description": "Optionally restrict results to a certain year",
49
- "nullable": True,
50
- },
51
- }
52
- output_type = "string"
53
-
54
- def __init__(self):
55
- """Initialize the tool with API key from environment."""
56
- super().__init__(self)
57
- self.serpapi_key = os.getenv("SERPER_API_KEY")
58
- self._validate_dependencies()
59
-
60
- def _validate_dependencies(self):
61
- """Ensure API key is available."""
62
- if not self.serpapi_key:
63
- raise ValueError(
64
- "Missing SerpAPI key. Make sure you have 'SERPER_API_KEY' in your env variables."
65
- )
66
-
67
- def forward(self, query: str, filter_year: Optional[int] = None) -> str:
68
- """Execute the search query and return formatted results."""
69
- import requests
70
-
71
- params = {
72
- "engine": "google",
73
- "q": query,
74
- "api_key": self.serpapi_key,
75
- "google_domain": "google.com",
76
- }
77
-
78
- headers = {"X-API-KEY": self.serpapi_key, "Content-Type": "application/json"}
79
-
80
- if filter_year is not None:
81
- params["tbs"] = (
82
- f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
83
- )
84
-
85
- response = requests.request(
86
- "POST",
87
- "https://google.serper.dev/search",
88
- headers=headers,
89
- data=json.dumps(params),
90
- )
91
-
92
- if response.status_code == 200:
93
- results = response.json()
94
- else:
95
- raise ValueError(response.json())
96
-
97
- if "organic" not in results.keys() or len(results["organic"]) == 0:
98
- year_filter_message = (
99
- f" with filter year={filter_year}" if filter_year is not None else ""
100
- )
101
- return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
102
-
103
- return self._format_search_results(results["organic"])
104
-
105
- def _format_search_results(self, organic_results: List[Dict[str, Any]]) -> str:
106
- """Format organic search results into a readable string."""
107
- web_snippets = []
108
-
109
- for idx, page in enumerate(organic_results):
110
- date_published = (
111
- f"\nDate published: {page['date']}" if "date" in page else ""
112
- )
113
- source = f"\nSource: {page['source']}" if "source" in page else ""
114
- snippet = f"\n{page['snippet']}" if "snippet" in page else ""
115
-
116
- formatted_result = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
117
- formatted_result = formatted_result.replace(
118
- "Your browser can't play this video.", ""
119
- )
120
- web_snippets.append(formatted_result)
121
-
122
- return "## Search Results\n" + "\n\n".join(web_snippets)
123
-
124
-
125
  # Constants and configurations
126
  AUTHORIZED_IMPORTS = [
127
  "requests",
@@ -148,16 +60,6 @@ AUTHORIZED_IMPORTS = [
148
  "csv",
149
  ]
150
 
151
-
152
- # Configuration setup
153
- def setup_environment():
154
- """Initialize environment variables and authentication."""
155
- load_dotenv(override=True)
156
- login(os.getenv("HF_TOKEN"))
157
- print("TOKKKK", os.getenv("HF_TOKEN")[-10:])
158
-
159
-
160
- # Browser configuration
161
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
162
  BROWSER_CONFIG = {
163
  "viewport_size": 1024 * 5,
@@ -169,10 +71,32 @@ BROWSER_CONFIG = {
169
  "serpapi_key": os.getenv("SERPAPI_API_KEY"),
170
  }
171
 
172
- # Custom role conversions for model response handling
173
  custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
 
176
  class ModelManager:
177
  """Manages model loading and initialization."""
178
 
@@ -211,7 +135,6 @@ class ModelManager:
211
  print(f"✗ Couldn't load model: {e}")
212
  raise
213
 
214
-
215
  class ToolRegistry:
216
  """Manages tool initialization and organization."""
217
 
@@ -255,8 +178,7 @@ class ToolRegistry:
255
  print(f"✗ Couldn't initialize image generation tool: {e}")
256
  raise
257
 
258
-
259
- # Agent creation in a factory function
260
  def create_agent():
261
  """Creates a fresh agent instance with properly configured tools."""
262
  # Initialize model
@@ -274,7 +196,7 @@ def create_agent():
274
  gemma_vision_tool = ToolRegistry.load_vision_tools()
275
 
276
  # Combine all tools into a single list (not a tuple)
277
- all_tools = [visualizer] + web_tools + [gemma_vision_tool]
278
 
279
  # Validate tools before creating agent
280
  for tool in all_tools:
@@ -292,21 +214,20 @@ def create_agent():
292
  planning_interval=4,
293
  )
294
 
295
-
296
  def stream_to_gradio(
297
  agent,
298
  task: str,
299
  reset_agent_memory: bool = False,
300
  additional_args: Optional[dict] = None,
301
  ):
302
- """Runs an agent with the given task and streams messages as gradio ChatMessages."""
303
  for step_log in agent.run(
304
  task, stream=True, reset=reset_agent_memory, additional_args=additional_args
305
  ):
306
  for message in pull_messages_from_step(step_log):
307
  yield message
308
 
309
- # Process final answer
310
  final_answer = step_log # Last log is the run's final_answer
311
  final_answer = handle_agent_output_types(final_answer)
312
 
@@ -318,19 +239,19 @@ def stream_to_gradio(
318
  elif isinstance(final_answer, AgentImage):
319
  yield gr.ChatMessage(
320
  role="assistant",
321
- content={"path": final_answer.to_string(), "mime_type": "image/png"},
322
- )
323
  elif isinstance(final_answer, AgentAudio):
324
  yield gr.ChatMessage(
325
  role="assistant",
326
- content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
327
- )
328
  else:
329
  yield gr.ChatMessage(
330
  role="assistant", content=f"**Final answer:** {str(final_answer)}"
331
  )
332
 
333
-
334
  class GradioUI:
335
  """A one-line interface to launch your agent in Gradio."""
336
 
@@ -344,6 +265,7 @@ class GradioUI:
344
 
345
  def interact_with_agent(self, prompt, messages, session_state):
346
  """Main interaction handler with the agent."""
 
347
  # Get or create session-specific agent
348
  if "agent" not in session_state:
349
  session_state["agent"] = create_agent()
@@ -363,8 +285,9 @@ class GradioUI:
363
  session_state["agent"], task=prompt, reset_agent_memory=False
364
  ):
365
  messages.append(msg)
366
- yield messages
367
- yield messages
 
368
  except Exception as e:
369
  print(f"Error in interaction: {str(e)}")
370
  raise
@@ -373,13 +296,6 @@ class GradioUI:
373
  self,
374
  file,
375
  file_uploads_log,
376
- allowed_file_types=[
377
- "application/pdf",
378
- "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
379
- "text/plain",
380
- "image/png", # Add PNG MIME type
381
- "video/mp4", # Add MP4 MIME type"
382
- ],
383
  ):
384
  """Handle file uploads with proper validation and security."""
385
  if file is None:
@@ -390,7 +306,7 @@ class GradioUI:
390
  except Exception as e:
391
  return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
392
 
393
- if mime_type not in allowed_file_types:
394
  return gr.Textbox("File type disallowed", visible=True), file_uploads_log
395
 
396
  # Sanitize file name
@@ -410,6 +326,13 @@ class GradioUI:
410
  extension = type_to_ext.get(mime_type, "")
411
  sanitized_name = "".join(name_parts) + extension
412
 
 
 
 
 
 
 
 
413
  # Save the uploaded file to the specified folder
414
  file_path = os.path.join(self.file_upload_folder, sanitized_name)
415
  shutil.copy(file.name, file_path)
@@ -423,14 +346,14 @@ class GradioUI:
423
  message = text_input
424
 
425
  if len(file_uploads_log) > 0:
426
- message += f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
427
 
428
  return (
429
  message,
430
  gr.Textbox(
431
  value="",
432
  interactive=False,
433
- placeholder="Please wait while Steps are getting populated",
434
  ),
435
  gr.Button(interactive=False),
436
  )
@@ -438,7 +361,7 @@ class GradioUI:
438
  def detect_device(self, request: gr.Request):
439
  """Detect whether the user is on mobile or desktop device."""
440
  if not request:
441
- return "Unknown device"
442
 
443
  # Method 1: Check sec-ch-ua-mobile header
444
  is_mobile_header = request.headers.get("sec-ch-ua-mobile")
@@ -477,7 +400,7 @@ class GradioUI:
477
  else:
478
  return self._create_mobile_layout()
479
 
480
- demo.launch(debug=True, **kwargs)
481
 
482
  def _create_desktop_layout(self):
483
  """Create the desktop layout with sidebar."""
@@ -512,7 +435,7 @@ class GradioUI:
512
  gr.HTML(
513
  """
514
  <div style="display: flex; align-items: center; gap: 8px; font-family: system-ui, -apple-system, sans-serif;">
515
- <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png"
516
  style="width: 32px; height: 32px; object-fit: contain;" alt="logo">
517
  <a target="_blank" href="https://github.com/huggingface/smolagents">
518
  <b>huggingface/smolagents</b>
@@ -652,7 +575,7 @@ class GradioUI:
652
  [text_input, launch_research_btn],
653
  )
654
 
655
-
656
  def main():
657
  """Main entry point for the application."""
658
  # Initialize environment
@@ -662,8 +585,7 @@ def main():
662
  os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)
663
 
664
  # Launch UI
665
- GradioUI().launch()
666
-
667
 
668
  if __name__ == "__main__":
669
  main()
 
1
  import json
2
  import contextlib
3
+ import mimetypes
4
  import os
5
+ import re
6
+ import shutil
7
  from typing import Dict, List, Optional, Any
8
 
9
  from dotenv import load_dotenv
 
22
  )
23
  from scripts.visual_qa import visualizer
24
 
 
 
25
  from smolagents import (
26
  CodeAgent,
27
  HfApiModel,
28
  LiteLLMModel,
29
+ OpenAIServerModel,
30
+ TransformersModel,
31
  Tool,
32
  )
33
  from smolagents.agent_types import AgentText, AgentImage, AgentAudio
34
  from smolagents.gradio_ui import pull_messages_from_step, handle_agent_output_types
35
 
36
+ # ------------------------ Configuration and Setup ------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Constants and configurations
38
  AUTHORIZED_IMPORTS = [
39
  "requests",
 
60
  "csv",
61
  ]
62
 
 
 
 
 
 
 
 
 
 
 
63
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
64
  BROWSER_CONFIG = {
65
  "viewport_size": 1024 * 5,
 
71
  "serpapi_key": os.getenv("SERPAPI_API_KEY"),
72
  }
73
 
 
74
  custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}
75
 
76
+ # Multimedia file types supported:
77
+ ALLOWED_FILE_TYPES = [
78
+ "application/pdf",
79
+ "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
80
+ "text/plain",
81
+ "image/png",
82
+ "image/jpeg", # Added JPEG support
83
+ "image/gif", # Added GIF support
84
+ "video/mp4",
85
+ "audio/mpeg", # Added MP3 support
86
+ "audio/wav", # Added WAV support
87
+ "audio/ogg", # Added OGG support
88
+ ]
89
+
90
+ def setup_environment():
91
+ """Initialize environment variables and authentication."""
92
+ load_dotenv(override=True)
93
+ if os.getenv("HF_TOKEN"): # Check if token is actually set
94
+ login(os.getenv("HF_TOKEN"))
95
+ print("HF_TOKEN (last 10 characters):", os.getenv("HF_TOKEN")[-10:])
96
+ else:
97
+ print("HF_TOKEN not found in environment variables.")
98
 
99
+ # ------------------------ Model and Tool Management ------------------------
100
  class ModelManager:
101
  """Manages model loading and initialization."""
102
 
 
135
  print(f"✗ Couldn't load model: {e}")
136
  raise
137
 
 
138
  class ToolRegistry:
139
  """Manages tool initialization and organization."""
140
 
 
178
  print(f"✗ Couldn't initialize image generation tool: {e}")
179
  raise
180
 
181
+ # ------------------------ Agent Creation and Execution ------------------------
 
182
  def create_agent():
183
  """Creates a fresh agent instance with properly configured tools."""
184
  # Initialize model
 
196
  gemma_vision_tool = ToolRegistry.load_vision_tools()
197
 
198
  # Combine all tools into a single list (not a tuple)
199
+ all_tools = [visualizer] + web_tools + [gemma_vision_tool] + [image_generator]
200
 
201
  # Validate tools before creating agent
202
  for tool in all_tools:
 
214
  planning_interval=4,
215
  )
216
 
 
217
  def stream_to_gradio(
218
  agent,
219
  task: str,
220
  reset_agent_memory: bool = False,
221
  additional_args: Optional[dict] = None,
222
  ):
223
+ """Runs an agent with the given task and streams messages as Gradio ChatMessages."""
224
  for step_log in agent.run(
225
  task, stream=True, reset=reset_agent_memory, additional_args=additional_args
226
  ):
227
  for message in pull_messages_from_step(step_log):
228
  yield message
229
 
230
+ # Process final answer : Use a more comprehensive media output
231
  final_answer = step_log # Last log is the run's final_answer
232
  final_answer = handle_agent_output_types(final_answer)
233
 
 
239
  elif isinstance(final_answer, AgentImage):
240
  yield gr.ChatMessage(
241
  role="assistant",
242
+ content= { "image": final_answer.to_string(), "type": "file" },
243
+ ) # Send as Gradio-compatible file object:
244
  elif isinstance(final_answer, AgentAudio):
245
  yield gr.ChatMessage(
246
  role="assistant",
247
+ content={ "audio": final_answer.to_string(), "type": "file" },
248
+ ) # Send as Gradio-compatible file object
249
  else:
250
  yield gr.ChatMessage(
251
  role="assistant", content=f"**Final answer:** {str(final_answer)}"
252
  )
253
 
254
+ # ------------------------ Gradio UI Components ------------------------
255
  class GradioUI:
256
  """A one-line interface to launch your agent in Gradio."""
257
 
 
265
 
266
  def interact_with_agent(self, prompt, messages, session_state):
267
  """Main interaction handler with the agent."""
268
+
269
  # Get or create session-specific agent
270
  if "agent" not in session_state:
271
  session_state["agent"] = create_agent()
 
285
  session_state["agent"], task=prompt, reset_agent_memory=False
286
  ):
287
  messages.append(msg)
288
+ yield messages # Yield messages after each step
289
+ yield messages # Yield messages one last time
290
+
291
  except Exception as e:
292
  print(f"Error in interaction: {str(e)}")
293
  raise
 
296
  self,
297
  file,
298
  file_uploads_log,
 
 
 
 
 
 
 
299
  ):
300
  """Handle file uploads with proper validation and security."""
301
  if file is None:
 
306
  except Exception as e:
307
  return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
308
 
309
+ if mime_type not in ALLOWED_FILE_TYPES:
310
  return gr.Textbox("File type disallowed", visible=True), file_uploads_log
311
 
312
  # Sanitize file name
 
326
  extension = type_to_ext.get(mime_type, "")
327
  sanitized_name = "".join(name_parts) + extension
328
 
329
+ # Limit File Size, and Throw Error
330
+ max_file_size_mb = 50 # Define the limit
331
+ file_size_mb = os.path.getsize(file.name) / (1024 * 1024) # Size in MB
332
+
333
+ if file_size_mb > max_file_size_mb:
334
+ return gr.Textbox(f"File size exceeds {max_file_size_mb} MB limit.", visible=True), file_uploads_log
335
+
336
  # Save the uploaded file to the specified folder
337
  file_path = os.path.join(self.file_upload_folder, sanitized_name)
338
  shutil.copy(file.name, file_path)
 
346
  message = text_input
347
 
348
  if len(file_uploads_log) > 0:
349
+ message += f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}" # Added file list
350
 
351
  return (
352
  message,
353
  gr.Textbox(
354
  value="",
355
  interactive=False,
356
+ placeholder="Processing...", # Changed placeholder.
357
  ),
358
  gr.Button(interactive=False),
359
  )
 
361
  def detect_device(self, request: gr.Request):
362
  """Detect whether the user is on mobile or desktop device."""
363
  if not request:
364
+ return "Unknown device" # Handle case where request is none.
365
 
366
  # Method 1: Check sec-ch-ua-mobile header
367
  is_mobile_header = request.headers.get("sec-ch-ua-mobile")
 
400
  else:
401
  return self._create_mobile_layout()
402
 
403
+ demo.queue(max_size=20).launch(debug=True, **kwargs) # Add queue with reasonable size
404
 
405
  def _create_desktop_layout(self):
406
  """Create the desktop layout with sidebar."""
 
435
  gr.HTML(
436
  """
437
  <div style="display: flex; align-items: center; gap: 8px; font-family: system-ui, -apple-system, sans-serif;">
438
+ <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png"
439
  style="width: 32px; height: 32px; object-fit: contain;" alt="logo">
440
  <a target="_blank" href="https://github.com/huggingface/smolagents">
441
  <b>huggingface/smolagents</b>
 
575
  [text_input, launch_research_btn],
576
  )
577
 
578
+ # ------------------------ Execution ------------------------
579
  def main():
580
  """Main entry point for the application."""
581
  # Initialize environment
 
585
  os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)
586
 
587
  # Launch UI
588
+ GradioUI(file_upload_folder="uploaded_files").launch()
 
589
 
590
  if __name__ == "__main__":
591
  main()