Leonardo committed on
Commit
baae88b
·
verified ·
1 Parent(s): c502207

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +416 -237
app.py CHANGED
@@ -1,11 +1,14 @@
1
  import argparse
2
  import json
 
3
  import os
 
 
4
  import threading
5
  from concurrent.futures import ThreadPoolExecutor, as_completed
6
  from datetime import datetime
7
  from pathlib import Path
8
- from typing import List, Optional
9
 
10
  import datasets
11
  import pandas as pd
@@ -29,6 +32,7 @@ from scripts.text_web_browser import (
29
  VisitTool,
30
  )
31
  from scripts.visual_qa import visualizer
 
32
  # from scripts.flux_lora_tool import FluxLoRATool
33
  from tqdm import tqdm
34
 
@@ -37,15 +41,18 @@ from smolagents import (
37
  HfApiModel,
38
  LiteLLMModel,
39
  Model,
 
 
40
  ToolCallingAgent,
 
41
  )
42
  from smolagents.agent_types import AgentText, AgentImage, AgentAudio
43
  from smolagents.gradio_ui import pull_messages_from_step, handle_agent_output_types
44
 
45
- from smolagents import Tool
46
-
47
 
48
  class GoogleSearchTool(Tool):
 
 
49
  name = "web_search"
50
  description = """Performs a google web search for your query then returns a string of the top search results."""
51
  inputs = {
@@ -59,17 +66,22 @@ class GoogleSearchTool(Tool):
59
  output_type = "string"
60
 
61
  def __init__(self):
 
62
  super().__init__(self)
63
- import os
64
-
65
  self.serpapi_key = os.getenv("SERPER_API_KEY")
 
 
 
 
 
 
 
 
66
 
67
  def forward(self, query: str, filter_year: Optional[int] = None) -> str:
 
68
  import requests
69
 
70
- if self.serpapi_key is None:
71
- raise ValueError("Missing SerpAPI key. Make sure you have 'SERPER_API_KEY' in your env variables.")
72
-
73
  params = {
74
  "engine": "google",
75
  "q": query,
@@ -77,60 +89,54 @@ class GoogleSearchTool(Tool):
77
  "google_domain": "google.com",
78
  }
79
 
80
- headers = {
81
- 'X-API-KEY': self.serpapi_key,
82
- 'Content-Type': 'application/json'
83
- }
84
 
85
  if filter_year is not None:
86
- params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
87
-
88
- response = requests.request("POST", "https://google.serper.dev/search", headers=headers, data=json.dumps(params))
89
-
 
 
 
 
 
 
90
 
91
  if response.status_code == 200:
92
  results = response.json()
93
  else:
94
  raise ValueError(response.json())
95
 
96
- if "organic" not in results.keys():
97
- print("REZZZ", results.keys())
98
- if filter_year is not None:
99
- raise Exception(
100
- f"No results found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
101
- )
102
- else:
103
- raise Exception(f"No results found for query: '{query}'. Use a less restrictive query.")
104
- if len(results["organic"]) == 0:
105
- year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
106
  return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
107
 
108
- web_snippets = []
109
- if "organic" in results:
110
- for idx, page in enumerate(results["organic"]):
111
- date_published = ""
112
- if "date" in page:
113
- date_published = "\nDate published: " + page["date"]
114
-
115
- source = ""
116
- if "source" in page:
117
- source = "\nSource: " + page["source"]
118
 
119
- snippet = ""
120
- if "snippet" in page:
121
- snippet = "\n" + page["snippet"]
122
 
123
- redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
 
 
 
 
 
124
 
125
- redacted_version = redacted_version.replace("Your browser can't play this video.", "")
126
- web_snippets.append(redacted_version)
 
 
 
127
 
128
  return "## Search Results\n" + "\n\n".join(web_snippets)
129
 
130
- # web_search = GoogleSearchTool()
131
 
132
- # print(web_search(query="Donald Trump news"))
133
- # quit()
134
  AUTHORIZED_IMPORTS = [
135
  "requests",
136
  "zipfile",
@@ -155,18 +161,18 @@ AUTHORIZED_IMPORTS = [
155
  "fractions",
156
  "csv",
157
  ]
158
- load_dotenv(override=True)
159
- login(os.getenv("HF_TOKEN"))
160
-
161
 
162
- print("TOKKKK", os.getenv("HF_TOKEN")[-10:])
163
 
164
- append_answer_lock = threading.Lock()
 
 
 
 
 
165
 
166
- custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}
167
 
 
168
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
169
-
170
  BROWSER_CONFIG = {
171
  "viewport_size": 1024 * 5,
172
  "downloads_folder": "downloads_folder",
@@ -177,43 +183,129 @@ BROWSER_CONFIG = {
177
  "serpapi_key": os.getenv("SERPAPI_API_KEY"),
178
  }
179
 
180
- os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)
 
181
 
182
- model = HfApiModel(
183
- custom_role_conversions=custom_role_conversions,
184
- )
185
 
186
- text_limit = 20000
187
- ti_tool = TextInspectorTool(model, text_limit)
 
 
 
 
 
 
 
188
 
189
- browser = SimpleTextBrowser(**BROWSER_CONFIG)
 
190
 
191
- WEB_TOOLS = [
192
- GoogleSearchTool(),
193
- VisitTool(browser),
194
- PageUpTool(browser),
195
- PageDownTool(browser),
196
- FinderTool(browser),
197
- FindNextTool(browser),
198
- ArchiveSearchTool(browser),
199
- TextInspectorTool(model, text_limit),
200
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
- # flux_tool = FluxLoRATool()
203
 
204
  # Agent creation in a factory function
205
  def create_agent():
206
- """Creates a fresh agent instance for each session"""
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  return CodeAgent(
208
  model=model,
209
- tools=[visualizer] + WEB_TOOLS,
210
  max_steps=10,
211
  verbosity_level=1,
212
  additional_authorized_imports=AUTHORIZED_IMPORTS,
213
  planning_interval=4,
214
  )
215
 
216
- document_inspection_tool = TextInspectorTool(model, 20000)
217
 
218
  def stream_to_gradio(
219
  agent,
@@ -221,13 +313,14 @@ def stream_to_gradio(
221
  reset_agent_memory: bool = False,
222
  additional_args: Optional[dict] = None,
223
  ):
224
- """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
225
- for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
226
- for message in pull_messages_from_step(
227
- step_log,
228
- ):
229
  yield message
230
 
 
231
  final_answer = step_log # Last log is the run's final_answer
232
  final_answer = handle_agent_output_types(final_answer)
233
 
@@ -247,36 +340,42 @@ def stream_to_gradio(
247
  content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
248
  )
249
  else:
250
- yield gr.ChatMessage(role="assistant", content=f"**Final answer:** {str(final_answer)}")
 
 
251
 
252
 
253
  class GradioUI:
254
- """A one-line interface to launch your agent in Gradio"""
255
 
256
  def __init__(self, file_upload_folder: str | None = None):
257
-
258
  self.file_upload_folder = file_upload_folder
 
259
  if self.file_upload_folder is not None:
260
  if not os.path.exists(file_upload_folder):
261
  os.mkdir(file_upload_folder)
262
 
263
  def interact_with_agent(self, prompt, messages, session_state):
 
264
  # Get or create session-specific agent
265
- if 'agent' not in session_state:
266
- session_state['agent'] = create_agent()
267
 
268
  # Adding monitoring
269
  try:
270
- # log the existence of agent memory
271
- has_memory = hasattr(session_state['agent'], 'memory')
272
  print(f"Agent has memory: {has_memory}")
273
  if has_memory:
274
  print(f"Memory type: {type(session_state['agent'].memory)}")
275
-
276
  messages.append(gr.ChatMessage(role="user", content=prompt))
277
  yield messages
278
-
279
- for msg in stream_to_gradio(session_state['agent'], task=prompt, reset_agent_memory=False):
 
 
280
  messages.append(msg)
281
  yield messages
282
  yield messages
@@ -294,9 +393,7 @@ class GradioUI:
294
  "text/plain",
295
  ],
296
  ):
297
- """
298
- Handle file uploads, default allowed types are .pdf, .docx, and .txt
299
- """
300
  if file is None:
301
  return gr.Textbox("No file uploaded", visible=True), file_uploads_log
302
 
@@ -312,191 +409,273 @@ class GradioUI:
312
  original_name = os.path.basename(file.name)
313
  sanitized_name = re.sub(
314
  r"[^\w\-.]", "_", original_name
315
- ) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
316
 
 
317
  type_to_ext = {}
318
  for ext, t in mimetypes.types_map.items():
319
  if t not in type_to_ext:
320
  type_to_ext[t] = ext
321
 
322
- # Ensure the extension correlates to the mime type
323
- sanitized_name = sanitized_name.split(".")[:-1]
324
- sanitized_name.append("" + type_to_ext[mime_type])
325
- sanitized_name = "".join(sanitized_name)
326
 
327
  # Save the uploaded file to the specified folder
328
- file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
329
  shutil.copy(file.name, file_path)
330
 
331
- return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]
 
 
332
 
333
  def log_user_message(self, text_input, file_uploads_log):
 
 
 
 
 
 
334
  return (
335
- text_input
336
- + (
337
- f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
338
- if len(file_uploads_log) > 0
339
- else ""
340
  ),
341
- gr.Textbox(value="", interactive=False, placeholder="Please wait while Steps are getting populated"),
342
- gr.Button(interactive=False)
343
  )
344
 
345
  def detect_device(self, request: gr.Request):
346
- # Check whether the user device is a mobile or a computer
347
-
348
  if not request:
349
  return "Unknown device"
 
350
  # Method 1: Check sec-ch-ua-mobile header
351
- is_mobile_header = request.headers.get('sec-ch-ua-mobile')
352
  if is_mobile_header:
353
- return "Mobile" if '?1' in is_mobile_header else "Desktop"
354
-
355
  # Method 2: Check user-agent string
356
- user_agent = request.headers.get('user-agent', '').lower()
357
- mobile_keywords = ['android', 'iphone', 'ipad', 'mobile', 'phone']
358
-
359
  if any(keyword in user_agent for keyword in mobile_keywords):
360
  return "Mobile"
361
-
362
  # Method 3: Check platform
363
- platform = request.headers.get('sec-ch-ua-platform', '').lower()
364
  if platform:
365
  if platform in ['"android"', '"ios"']:
366
  return "Mobile"
367
  elif platform in ['"windows"', '"macos"', '"linux"']:
368
  return "Desktop"
369
-
370
  # Default case if no clear indicators
371
- return "Desktop"
372
-
373
- def launch(self, **kwargs):
374
 
 
 
375
  with gr.Blocks(theme="ocean", fill_height=True) as demo:
376
  # Different layouts for mobile and computer devices
377
  @gr.render()
378
  def layout(request: gr.Request):
379
  device = self.detect_device(request)
380
  print(f"device - {device}")
381
- # Render layout with sidebar
382
  if device == "Desktop":
383
- with gr.Blocks(fill_height=True,) as sidebar_demo:
384
- with gr.Sidebar():
385
- gr.Markdown("""#OpenDeepResearch - free the AI agents!""")
386
- with gr.Group():
387
- gr.Markdown("**What's on your mind mate?**", container=True)
388
- text_input = gr.Textbox(lines=3, label="Your request", container=False, placeholder="Enter your prompt here and press Shift+Enter or press the button")
389
- launch_research_btn = gr.Button("Run", variant="primary")
390
-
391
- # If an upload folder is provided, enable the upload feature
392
- if self.file_upload_folder is not None:
393
- upload_file = gr.File(label="Upload a file")
394
- upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
395
- upload_file.change(
396
- self.upload_file,
397
- [upload_file, file_uploads_log],
398
- [upload_status, file_uploads_log],
399
- )
400
-
401
- gr.HTML("<br><br><h4><center>Powered by:</center></h4>")
402
- with gr.Row():
403
- gr.HTML("""<div style="display: flex; align-items: center; gap: 8px; font-family: system-ui, -apple-system, sans-serif;">
404
- <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png" style="width: 32px; height: 32px; object-fit: contain;" alt="logo">
405
- <a target="_blank" href="https://github.com/huggingface/smolagents"><b>huggingface/smolagents</b></a>
406
- </div>""")
407
-
408
- # Add session state to store session-specific data
409
- session_state = gr.State({}) # Initialize empty state for each session
410
- stored_messages = gr.State([])
411
- file_uploads_log = gr.State([])
412
- chatbot = gr.Chatbot(
413
- label="open-Deep-Research",
414
- type="messages",
415
- avatar_images=(
416
- None,
417
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
418
- ),
419
- resizeable=False,
420
- scale=1,
421
- elem_id="my-chatbot"
422
- )
423
-
424
- text_input.submit(
425
- self.log_user_message,
426
- [text_input, file_uploads_log],
427
- [stored_messages, text_input, launch_research_btn],
428
- ).then(self.interact_with_agent,
429
- # Include session_state in function calls
430
- [stored_messages, chatbot, session_state],
431
- [chatbot]
432
- ).then(lambda : (gr.Textbox(interactive=True, placeholder="Enter your prompt here and press the button"), gr.Button(interactive=True)),
433
- None,
434
- [text_input, launch_research_btn])
435
- launch_research_btn.click(
436
- self.log_user_message,
437
- [text_input, file_uploads_log],
438
- [stored_messages, text_input, launch_research_btn],
439
- ).then(self.interact_with_agent,
440
- # Include session_state in function calls
441
- [stored_messages, chatbot, session_state],
442
- [chatbot]
443
- ).then(lambda : (gr.Textbox(interactive=True, placeholder="Enter your prompt here and press the button"), gr.Button(interactive=True)),
444
- None,
445
- [text_input, launch_research_btn])
446
-
447
- # Render simple layout
448
  else:
449
- with gr.Blocks(fill_height=True,) as simple_demo:
450
- gr.Markdown("""#OpenDeepResearch - free the AI agents!""")
451
- # Add session state to store session-specific data
452
- session_state = gr.State({}) # Initialize empty state for each session
453
- stored_messages = gr.State([])
454
- file_uploads_log = gr.State([])
455
- chatbot = gr.Chatbot(
456
- label="open-Deep-Research",
457
- type="messages",
458
- avatar_images=(
459
- None,
460
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
461
- ),
462
- resizeable=True,
463
- scale=1,
464
- )
465
- # If an upload folder is provided, enable the upload feature
466
- if self.file_upload_folder is not None:
467
- upload_file = gr.File(label="Upload a file")
468
- upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
469
- upload_file.change(
470
- self.upload_file,
471
- [upload_file, file_uploads_log],
472
- [upload_status, file_uploads_log],
473
- )
474
- text_input = gr.Textbox(lines=1, label="What's on your mind mate?", placeholder="Chuck in a question and we'll take care of the rest")
475
- launch_research_btn = gr.Button("Run", variant="primary",)
476
-
477
- text_input.submit(
478
- self.log_user_message,
479
- [text_input, file_uploads_log],
480
- [stored_messages, text_input, launch_research_btn],
481
- ).then(self.interact_with_agent,
482
- # Include session_state in function calls
483
- [stored_messages, chatbot, session_state],
484
- [chatbot]
485
- ).then(lambda : (gr.Textbox(interactive=True, placeholder="Chuck in a question and we'll take care of the rest"), gr.Button(interactive=True)),
486
- None,
487
- [text_input, launch_research_btn])
488
- launch_research_btn.click(
489
- self.log_user_message,
490
- [text_input, file_uploads_log],
491
- [stored_messages, text_input, launch_research_btn],
492
- ).then(self.interact_with_agent,
493
- # Include session_state in function calls
494
- [stored_messages, chatbot, session_state],
495
- [chatbot]
496
- ).then(lambda : (gr.Textbox(interactive=True, placeholder="Chuck in a question and we'll take care of the rest"), gr.Button(interactive=True)),
497
- None,
498
- [text_input, launch_research_btn])
499
-
500
  demo.launch(debug=True, **kwargs)
501
 
502
- GradioUI().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import argparse
2
  import json
3
+ import mimetypes # Added missing import
4
  import os
5
+ import re # Added missing import
6
+ import shutil # Added missing import
7
  import threading
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
  from datetime import datetime
10
  from pathlib import Path
11
+ from typing import Dict, List, Optional, Any
12
 
13
  import datasets
14
  import pandas as pd
 
32
  VisitTool,
33
  )
34
  from scripts.visual_qa import visualizer
35
+
36
  # from scripts.flux_lora_tool import FluxLoRATool
37
  from tqdm import tqdm
38
 
 
41
  HfApiModel,
42
  LiteLLMModel,
43
  Model,
44
+ OpenAIServerModel, # Added missing model
45
+ TransformersModel, # Added missing model
46
  ToolCallingAgent,
47
+ Tool,
48
  )
49
  from smolagents.agent_types import AgentText, AgentImage, AgentAudio
50
  from smolagents.gradio_ui import pull_messages_from_step, handle_agent_output_types
51
 
 
 
52
 
53
  class GoogleSearchTool(Tool):
54
+ """Performs Google web searches using the Serper API."""
55
+
56
  name = "web_search"
57
  description = """Performs a google web search for your query then returns a string of the top search results."""
58
  inputs = {
 
66
  output_type = "string"
67
 
68
  def __init__(self):
69
+ """Initialize the tool with API key from environment."""
70
  super().__init__(self)
 
 
71
  self.serpapi_key = os.getenv("SERPER_API_KEY")
72
+ self._validate_dependencies()
73
+
74
def _validate_dependencies(self):
    """Fail fast when the Serper API key was not found in the environment."""
    if self.serpapi_key:
        return
    raise ValueError(
        "Missing SerpAPI key. Make sure you have 'SERPER_API_KEY' in your env variables."
    )
80
 
81
  def forward(self, query: str, filter_year: Optional[int] = None) -> str:
82
+ """Execute the search query and return formatted results."""
83
  import requests
84
 
 
 
 
85
  params = {
86
  "engine": "google",
87
  "q": query,
 
89
  "google_domain": "google.com",
90
  }
91
 
92
+ headers = {"X-API-KEY": self.serpapi_key, "Content-Type": "application/json"}
 
 
 
93
 
94
  if filter_year is not None:
95
+ params["tbs"] = (
96
+ f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
97
+ )
98
+
99
+ response = requests.request(
100
+ "POST",
101
+ "https://google.serper.dev/search",
102
+ headers=headers,
103
+ data=json.dumps(params),
104
+ )
105
 
106
  if response.status_code == 200:
107
  results = response.json()
108
  else:
109
  raise ValueError(response.json())
110
 
111
+ if "organic" not in results.keys() or len(results["organic"]) == 0:
112
+ year_filter_message = (
113
+ f" with filter year={filter_year}" if filter_year is not None else ""
114
+ )
 
 
 
 
 
 
115
  return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
116
 
117
+ return self._format_search_results(results["organic"])
 
 
 
 
 
 
 
 
 
118
 
119
def _format_search_results(self, organic_results: List[Dict[str, Any]]) -> str:
    """Render Serper 'organic' results as a numbered markdown list.

    Each entry becomes `N. [title](link)` followed by optional date, source
    and snippet lines; a video-player boilerplate phrase is stripped out.
    """
    snippets = []
    for rank, entry in enumerate(organic_results):
        # Build the entry piecewise; optional fields are skipped when absent.
        parts = [f"{rank}. [{entry['title']}]({entry['link']})"]
        if "date" in entry:
            parts.append(f"\nDate published: {entry['date']}")
        if "source" in entry:
            parts.append(f"\nSource: {entry['source']}")
        parts.append("\n")
        if "snippet" in entry:
            parts.append(f"\n{entry['snippet']}")
        cleaned = "".join(parts).replace("Your browser can't play this video.", "")
        snippets.append(cleaned)
    return "## Search Results\n" + "\n\n".join(snippets)
137
 
 
138
 
139
+ # Constants and configurations
 
140
  AUTHORIZED_IMPORTS = [
141
  "requests",
142
  "zipfile",
 
161
  "fractions",
162
  "csv",
163
  ]
 
 
 
164
 
 
165
 
166
+ # Configuration setup
167
def setup_environment():
    """Load .env configuration and authenticate against the Hugging Face Hub.

    Raises:
        RuntimeError: if HF_TOKEN is not set. The original code would instead
            crash with a confusing TypeError when slicing None below.
    """
    load_dotenv(override=True)
    token = os.getenv("HF_TOKEN")
    if not token:
        # BUG FIX: os.getenv("HF_TOKEN")[-10:] raised TypeError on a missing
        # token; fail with an actionable message instead.
        raise RuntimeError(
            "Missing 'HF_TOKEN' environment variable; set it in .env or the environment."
        )
    login(token)
    # Debug aid: print only the token suffix, never the full secret.
    print("TOKKKK", token[-10:])
172
 
 
173
 
174
+ # Browser configuration
175
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
 
176
  BROWSER_CONFIG = {
177
  "viewport_size": 1024 * 5,
178
  "downloads_folder": "downloads_folder",
 
183
  "serpapi_key": os.getenv("SERPAPI_API_KEY"),
184
  }
185
 
186
+ # Custom role conversions for model response handling
187
+ custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}
188
 
 
 
 
189
 
190
class ModelManager:
    """Manages model loading and initialization."""

    @staticmethod
    def load_model(chosen_inference: str, model_id: str, key_manager=None):
        """Build a model backend for the requested inference mode.

        Args:
            chosen_inference: one of "hf_api", "hf_api_provider", "litellm",
                "ollama", "openai", "transformers".
            model_id: model identifier forwarded to the backend (ignored by
                the "hf_api_provider" and "transformers" branches).
            key_manager: object exposing get_key(name); required for the
                "ollama" and "openai" branches.

        Raises:
            ValueError: on an unknown inference type or a missing key manager.
        """
        try:
            if chosen_inference == "hf_api":
                return HfApiModel(model_id=model_id)

            if chosen_inference == "hf_api_provider":
                # NOTE(review): model_id is not forwarded here — the provider
                # default is used; confirm this is intentional.
                return HfApiModel(provider="together")

            if chosen_inference == "litellm":
                return LiteLLMModel(model_id=model_id)

            if chosen_inference == "ollama":
                if not key_manager:
                    raise ValueError("Key manager required for Ollama model")
                return LiteLLMModel(
                    model_id=model_id,
                    api_base="http://localhost:11434",
                    api_key=key_manager.get_key("ollama_api_key"),
                    num_ctx=8192,
                )

            if chosen_inference == "openai":
                if not key_manager:
                    raise ValueError("Key manager required for OpenAI model")
                return OpenAIServerModel(
                    model_id=model_id, api_key=key_manager.get_key("openai_api_key")
                )

            if chosen_inference == "transformers":
                return TransformersModel(
                    model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
                    device_map="auto",
                    max_new_tokens=1000,
                )

            raise ValueError(f"Invalid inference type: {chosen_inference}")

        except Exception as e:
            # Surface the failure before re-raising so launch logs show it.
            print(f"✗ Couldn't load model: {e}")
            raise
238
+
239
+
240
class ToolRegistry:
    """Central place to construct the agent's tool collections."""

    @staticmethod
    def load_web_tools(model, browser, text_limit=20000):
        """Build the browsing/search toolset bound to *browser*, plus a text inspector."""
        navigation_tools = [
            GoogleSearchTool(),
            VisitTool(browser),
            PageUpTool(browser),
            PageDownTool(browser),
            FinderTool(browser),
            FindNextTool(browser),
            ArchiveSearchTool(browser),
        ]
        # The inspector needs the model to summarize long documents.
        navigation_tools.append(TextInspectorTool(model, text_limit))
        return navigation_tools

    @staticmethod
    def load_vision_tools():
        """Create the Gemma-3 vision tool hosted as a Hugging Face Space."""
        try:
            return Tool.from_space(
                space_id="xkerser/gemma-3-12b-it",
                name="gemma_vision",
                description="Upload an image to extract and analyze text and visual content from images using Gemma 3",
            )
        except Exception as e:
            print(f"✗ Couldn't initialize vision tool: {e}")
            raise

    @staticmethod
    def load_image_generation_tools():
        """Create the FLUX.1-dev text-to-image tool hosted as a Hugging Face Space."""
        try:
            return Tool.from_space(
                space_id="xkerser/FLUX.1-dev",
                name="image_generator",
                description="Generates high-quality images using the FLUX.1-dev model based on text prompts.",
            )
        except Exception as e:
            print(f"✗ Couldn't initialize image generation tool: {e}")
            raise
282
 
 
283
 
284
  # Agent creation in a factory function
285
def create_agent():
    """Creates a fresh agent instance for each session.

    Builds the model, a text browser, the web toolset and the vision tool,
    then wires them into a CodeAgent.

    Returns:
        CodeAgent: agent configured with visualizer + web tools + vision tool.
    """
    # Initialize model
    model = LiteLLMModel(
        custom_role_conversions=custom_role_conversions,
        model_id="openrouter/perplexity/r1-1776",
    )

    # Initialize tools
    text_limit = 20000
    browser = SimpleTextBrowser(**BROWSER_CONFIG)

    web_tools = ToolRegistry.load_web_tools(model, browser, text_limit)
    gemma_vision_tool = ToolRegistry.load_vision_tools()

    return CodeAgent(
        model=model,
        # BUG FIX: the previous value ([visualizer] + web_tools, gemma_vision_tool)
        # was a 2-tuple of (list, tool); CodeAgent expects a flat list of tools,
        # so append the vision tool to the list instead.
        tools=[visualizer] + web_tools + [gemma_vision_tool],
        max_steps=10,
        verbosity_level=1,
        additional_authorized_imports=AUTHORIZED_IMPORTS,
        planning_interval=4,
    )
308
 
 
309
 
310
  def stream_to_gradio(
311
  agent,
 
313
  reset_agent_memory: bool = False,
314
  additional_args: Optional[dict] = None,
315
  ):
316
+ """Runs an agent with the given task and streams messages as gradio ChatMessages."""
317
+ for step_log in agent.run(
318
+ task, stream=True, reset=reset_agent_memory, additional_args=additional_args
319
+ ):
320
+ for message in pull_messages_from_step(step_log):
321
  yield message
322
 
323
+ # Process final answer
324
  final_answer = step_log # Last log is the run's final_answer
325
  final_answer = handle_agent_output_types(final_answer)
326
 
 
340
  content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
341
  )
342
  else:
343
+ yield gr.ChatMessage(
344
+ role="assistant", content=f"**Final answer:** {str(final_answer)}"
345
+ )
346
 
347
 
348
  class GradioUI:
349
+ """A one-line interface to launch your agent in Gradio."""
350
 
351
def __init__(self, file_upload_folder: str | None = None):
    """Initialize the Gradio UI; optionally enable file uploads into *file_upload_folder*."""
    self.file_upload_folder = file_upload_folder
    # Create the upload directory on demand when uploads are enabled.
    if file_upload_folder is not None and not os.path.exists(file_upload_folder):
        os.mkdir(file_upload_folder)
358
 
359
  def interact_with_agent(self, prompt, messages, session_state):
360
+ """Main interaction handler with the agent."""
361
  # Get or create session-specific agent
362
+ if "agent" not in session_state:
363
+ session_state["agent"] = create_agent()
364
 
365
  # Adding monitoring
366
  try:
367
+ # Log the existence of agent memory
368
+ has_memory = hasattr(session_state["agent"], "memory")
369
  print(f"Agent has memory: {has_memory}")
370
  if has_memory:
371
  print(f"Memory type: {type(session_state['agent'].memory)}")
372
+
373
  messages.append(gr.ChatMessage(role="user", content=prompt))
374
  yield messages
375
+
376
+ for msg in stream_to_gradio(
377
+ session_state["agent"], task=prompt, reset_agent_memory=False
378
+ ):
379
  messages.append(msg)
380
  yield messages
381
  yield messages
 
393
  "text/plain",
394
  ],
395
  ):
396
+ """Handle file uploads with proper validation and security."""
 
 
397
  if file is None:
398
  return gr.Textbox("No file uploaded", visible=True), file_uploads_log
399
 
 
409
  original_name = os.path.basename(file.name)
410
  sanitized_name = re.sub(
411
  r"[^\w\-.]", "_", original_name
412
+ ) # Replace invalid chars with underscores
413
 
414
+ # Ensure the extension correlates to the mime type
415
  type_to_ext = {}
416
  for ext, t in mimetypes.types_map.items():
417
  if t not in type_to_ext:
418
  type_to_ext[t] = ext
419
 
420
+ # Build sanitized filename with proper extension
421
+ name_parts = sanitized_name.split(".")[:-1]
422
+ extension = type_to_ext.get(mime_type, "")
423
+ sanitized_name = "".join(name_parts) + extension
424
 
425
  # Save the uploaded file to the specified folder
426
+ file_path = os.path.join(self.file_upload_folder, sanitized_name)
427
  shutil.copy(file.name, file_path)
428
 
429
+ return gr.Textbox(
430
+ f"File uploaded: {file_path}", visible=True
431
+ ), file_uploads_log + [file_path]
432
 
433
def log_user_message(self, text_input, file_uploads_log):
    """Append uploaded-file hints to the prompt and lock the input widgets.

    Returns a 3-tuple consumed by the Gradio event chain: the augmented
    message, a cleared/disabled textbox, and a disabled run button.
    """
    file_hint = (
        f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
        if file_uploads_log
        else ""
    )
    return (
        text_input + file_hint,
        gr.Textbox(
            value="",
            interactive=False,
            placeholder="Please wait while Steps are getting populated",
        ),
        gr.Button(interactive=False),
    )
449
 
450
def detect_device(self, request: "gr.Request"):
    """Classify the requesting client as "Mobile", "Desktop", or "Unknown device".

    Checks, in order of reliability: the sec-ch-ua-mobile client hint, the
    user-agent string, then the sec-ch-ua-platform hint; defaults to Desktop.
    (Annotation is a lazy string so the block loads without gradio in scope.)
    """
    if not request:
        return "Unknown device"

    headers = request.headers

    # 1) Client-hint header is authoritative when present.
    mobile_hint = headers.get("sec-ch-ua-mobile")
    if mobile_hint:
        return "Mobile" if "?1" in mobile_hint else "Desktop"

    # 2) Fall back to sniffing the user-agent string.
    ua = headers.get("user-agent", "").lower()
    if any(kw in ua for kw in ("android", "iphone", "ipad", "mobile", "phone")):
        return "Mobile"

    # 3) Last resort: the platform client hint (values arrive quoted).
    platform = headers.get("sec-ch-ua-platform", "").lower()
    if platform in ('"android"', '"ios"'):
        return "Mobile"
    if platform in ('"windows"', '"macos"', '"linux"'):
        return "Desktop"

    # Default case if no clear indicators
    return "Desktop"
 
 
477
 
478
def launch(self, **kwargs):
    """Launch the Gradio UI, choosing a layout per client device type."""
    with gr.Blocks(theme="ocean", fill_height=True) as demo:
        # Different layouts for mobile and computer devices
        @gr.render()
        def layout(request: gr.Request):
            device = self.detect_device(request)
            print(f"device - {device}")
            # Sidebar layout for desktops, simple stacked layout otherwise.
            if device == "Desktop":
                return self._create_desktop_layout()
            return self._create_mobile_layout()

    demo.launch(debug=True, **kwargs)
493
 
494
def _create_desktop_layout(self):
    """Create the desktop layout with sidebar.

    Returns:
        gr.Blocks: the assembled desktop UI.
    """
    with gr.Blocks(fill_height=True) as sidebar_demo:
        # BUG FIX (robustness): create the upload log once, up front, so it
        # always exists. The previous code created it conditionally inside the
        # sidebar and then patched it with `if not "file_uploads_log" in
        # locals()`, a fragile locals() probe. gr.State renders nothing, so
        # moving it here does not change the UI.
        file_uploads_log = gr.State([])

        with gr.Sidebar():
            gr.Markdown("""#OpenDeepResearch - free the AI agents!""")
            with gr.Group():
                gr.Markdown("**What's on your mind mate?**", container=True)
                text_input = gr.Textbox(
                    lines=3,
                    label="Your request",
                    container=False,
                    placeholder="Enter your prompt here and press Shift+Enter or press the button",
                )
                launch_research_btn = gr.Button("Run", variant="primary")

            # If an upload folder is provided, enable the upload feature
            if self.file_upload_folder is not None:
                upload_file = gr.File(label="Upload a file")
                upload_status = gr.Textbox(
                    label="Upload Status", interactive=False, visible=False
                )
                upload_file.change(
                    self.upload_file,
                    [upload_file, file_uploads_log],
                    [upload_status, file_uploads_log],
                )

            gr.HTML("<br><br><h4><center>Powered by:</center></h4>")
            with gr.Row():
                gr.HTML(
                    """
                    <div style="display: flex; align-items: center; gap: 8px; font-family: system-ui, -apple-system, sans-serif;">
                    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png"
                         style="width: 32px; height: 32px; object-fit: contain;" alt="logo">
                    <a target="_blank" href="https://github.com/huggingface/smolagents">
                        <b>huggingface/smolagents</b>
                    </a>
                    </div>
                    """
                )

        # Add session state to store session-specific data
        session_state = gr.State({})  # Initialize empty state for each session
        stored_messages = gr.State([])

        chatbot = gr.Chatbot(
            label="open-Deep-Research",
            type="messages",
            avatar_images=(
                None,
                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
            ),
            resizeable=False,
            scale=1,
            elem_id="my-chatbot",
        )

        self._connect_event_handlers(
            text_input,
            launch_research_btn,
            file_uploads_log,
            stored_messages,
            chatbot,
            session_state,
        )

    return sidebar_demo
564
+
565
+ def _create_mobile_layout(self):
566
+ """Create the mobile layout (simpler without sidebar)."""
567
+ with gr.Blocks(fill_height=True) as simple_demo:
568
+ gr.Markdown("""#OpenDeepResearch - free the AI agents!""")
569
+ # Add session state to store session-specific data
570
+ session_state = gr.State({})
571
+ stored_messages = gr.State([])
572
+ file_uploads_log = gr.State([])
573
+
574
+ chatbot = gr.Chatbot(
575
+ label="open-Deep-Research",
576
+ type="messages",
577
+ avatar_images=(
578
+ None,
579
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
580
+ ),
581
+ resizeable=True,
582
+ scale=1,
583
+ )
584
+
585
+ # If an upload folder is provided, enable the upload feature
586
+ if self.file_upload_folder is not None:
587
+ upload_file = gr.File(label="Upload a file")
588
+ upload_status = gr.Textbox(
589
+ label="Upload Status", interactive=False, visible=False
590
+ )
591
+ upload_file.change(
592
+ self.upload_file,
593
+ [upload_file, file_uploads_log],
594
+ [upload_status, file_uploads_log],
595
+ )
596
+
597
+ text_input = gr.Textbox(
598
+ lines=1,
599
+ label="What's on your mind mate?",
600
+ placeholder="Chuck in a question and we'll take care of the rest",
601
+ )
602
+ launch_research_btn = gr.Button("Run", variant="primary")
603
+
604
+ self._connect_event_handlers(
605
+ text_input,
606
+ launch_research_btn,
607
+ file_uploads_log,
608
+ stored_messages,
609
+ chatbot,
610
+ session_state,
611
+ )
612
+
613
+ return simple_demo
614
+
615
+ def _connect_event_handlers(
616
+ self,
617
+ text_input,
618
+ launch_research_btn,
619
+ file_uploads_log,
620
+ stored_messages,
621
+ chatbot,
622
+ session_state,
623
+ ):
624
+ """Connect the event handlers for input elements."""
625
+ # Connect text input submit event
626
+ text_input.submit(
627
+ self.log_user_message,
628
+ [text_input, file_uploads_log],
629
+ [stored_messages, text_input, launch_research_btn],
630
+ ).then(
631
+ self.interact_with_agent,
632
+ [stored_messages, chatbot, session_state],
633
+ [chatbot],
634
+ ).then(
635
+ lambda: (
636
+ gr.Textbox(
637
+ interactive=True,
638
+ placeholder="Enter your prompt here and press the button",
639
+ ),
640
+ gr.Button(interactive=True),
641
+ ),
642
+ None,
643
+ [text_input, launch_research_btn],
644
+ )
645
+
646
+ # Connect button click event
647
+ launch_research_btn.click(
648
+ self.log_user_message,
649
+ [text_input, file_uploads_log],
650
+ [stored_messages, text_input, launch_research_btn],
651
+ ).then(
652
+ self.interact_with_agent,
653
+ [stored_messages, chatbot, session_state],
654
+ [chatbot],
655
+ ).then(
656
+ lambda: (
657
+ gr.Textbox(
658
+ interactive=True,
659
+ placeholder="Enter your prompt here and press the button",
660
+ ),
661
+ gr.Button(interactive=True),
662
+ ),
663
+ None,
664
+ [text_input, launch_research_btn],
665
+ )
666
+
667
+
668
def main():
    """Application entry point: prepare the environment and start the UI."""
    setup_environment()

    # Make sure the browser's download directory exists before launching.
    downloads_dir = Path(".") / BROWSER_CONFIG["downloads_folder"]
    downloads_dir.mkdir(parents=True, exist_ok=True)

    # Start the Gradio interface (blocks until the server stops).
    GradioUI().launch()
679
+
680
# Standard script guard: run the app only when executed directly.
if __name__ == "__main__":
    main()