Leonardo commited on
Commit
d2dfe27
·
verified ·
1 Parent(s): aa8f759

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +325 -474
app.py CHANGED
@@ -1,8 +1,8 @@
 
1
  import os
 
2
  import shutil
3
- import threading
4
- import tempfile
5
- from typing import Optional, List, Dict, Any
6
 
7
  from dotenv import load_dotenv
8
  from huggingface_hub import login
@@ -19,7 +19,7 @@ from scripts.text_web_browser import (
19
  VisitTool,
20
  )
21
  from scripts.visual_qa import visualizer
22
- from scripts.legal_document_tool import LegalDocumentTool
23
  from smolagents import (
24
  CodeAgent,
25
  HfApiModel,
@@ -58,8 +58,6 @@ AUTHORIZED_IMPORTS = [
58
  "fractions",
59
  "csv",
60
  "clean-text",
61
- "langchain",
62
- "llama_index",
63
  ]
64
 
65
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
@@ -75,20 +73,19 @@ BROWSER_CONFIG = {
75
 
76
  custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}
77
 
78
- # Multimedia file types supported (using Gradio-compatible format)
79
  ALLOWED_FILE_TYPES = [
80
- ".pdf", # application/pdf
81
- ".docx", # application/vnd.openxmlformats-officedocument.wordprocessingml.document
82
- ".txt", # text/plain
83
- ".png", # image/png
84
- ".webp", # image/webp
85
- ".jpeg", # image/jpeg
86
- ".jpg", # image/jpeg
87
- ".gif", # image/gif
88
- ".mp4", # video/mp4
89
- ".mp3", # audio/mpeg
90
- ".wav", # audio/wav
91
- ".ogg", # audio/ogg
92
  ]
93
 
94
 
@@ -104,52 +101,31 @@ def setup_environment():
104
 
105
  # ------------------------ Model and Tool Management ------------------------
106
  class ModelManager:
107
- """Manages model loading and initialization with Zhou Protocol patterns."""
108
-
109
- _instance = None
110
- _lock = threading.Lock()
111
-
112
- @classmethod
113
- def get_instance(cls):
114
- """Thread-safe singleton access to model manager."""
115
- if cls._instance is None:
116
- with cls._lock:
117
- if cls._instance is None:
118
- cls._instance = cls()
119
- return cls._instance
120
-
121
- def __init__(self):
122
- """Initialize with model cache."""
123
- self.model_cache = {}
124
-
125
- def load_model(self, chosen_inference: str, model_id: str, key_manager=None):
126
- """Load the specified model with appropriate configuration and caching."""
127
- cache_key = f"{chosen_inference}:{model_id}"
128
-
129
- # Return cached model if available
130
- if cache_key in self.model_cache:
131
- return self.model_cache[cache_key]
132
 
 
 
 
133
  try:
134
  if chosen_inference == "hf_api":
135
- model = HfApiModel(model_id=model_id)
136
 
137
  elif chosen_inference == "hf_api_provider":
138
- model = HfApiModel(provider="together")
139
 
140
  elif chosen_inference == "litellm":
141
- model = LiteLLMModel(model_id=model_id)
142
 
143
  elif chosen_inference == "openai":
144
  if not key_manager:
145
  raise ValueError("Key manager required for OpenAI model")
146
 
147
- model = OpenAIServerModel(
148
  model_id=model_id, api_key=key_manager.get_key("openai_api_key")
149
  )
150
 
151
  elif chosen_inference == "transformers":
152
- model = TransformersModel(
153
  model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
154
  device_map="auto",
155
  max_new_tokens=1000,
@@ -158,22 +134,13 @@ class ModelManager:
158
  else:
159
  raise ValueError(f"Invalid inference type: {chosen_inference}")
160
 
161
- # Cache the model for future use
162
- self.model_cache[cache_key] = model
163
- return model
164
-
165
  except Exception as e:
166
  print(f"✗ Couldn't load model: {e}")
167
  raise
168
 
169
 
170
  class ToolRegistry:
171
- """Manages tool initialization and organization with validation."""
172
-
173
- @staticmethod
174
- def validate_tools(tools: List[Tool]) -> List[Tool]:
175
- """Validate tools and filter out any None values."""
176
- return [tool for tool in tools if isinstance(tool, Tool)]
177
 
178
  @staticmethod
179
  def load_web_tools(model, browser, text_limit=20000):
@@ -200,208 +167,51 @@ class ToolRegistry:
200
  )
201
  except Exception as e:
202
  print(f"✗ Couldn't initialize image generation tool: {e}")
203
- return None
204
-
205
- @staticmethod
206
- def load_legal_document_tool():
207
- """Initialize and return the legal document processing tool."""
208
- try:
209
- # Create a simple instance with default parameters
210
- return LegalDocumentTool()
211
- except Exception as e:
212
- print(f"✗ Couldn't initialize legal document tool: {e}")
213
- # Return None instead of raising to make this tool optional
214
- return None
215
-
216
-
217
- # ------------------------ Session Management ------------------------
218
- class SessionManager:
219
- """Manages agent sessions with proper cleanup and lifecycle management."""
220
-
221
- _instance = None
222
- _lock = threading.Lock()
223
-
224
- @classmethod
225
- def get_instance(cls):
226
- """Thread-safe singleton access."""
227
- if cls._instance is None:
228
- with cls._lock:
229
- if cls._instance is None:
230
- cls._instance = cls()
231
- return cls._instance
232
-
233
- def __init__(self):
234
- """Initialize session management structures."""
235
- self.sessions = {}
236
- self.temp_files = {}
237
- self.last_activity = {}
238
- self._cleanup_thread = None
239
- self._running = False
240
- self._start_cleanup_thread()
241
-
242
- def _start_cleanup_thread(self):
243
- """Start a background thread for session cleanup."""
244
- if self._cleanup_thread is None:
245
- self._running = True
246
- self._cleanup_thread = threading.Thread(
247
- target=self._cleanup_inactive_sessions, daemon=True
248
- )
249
- self._cleanup_thread.start()
250
-
251
- def _cleanup_inactive_sessions(self):
252
- """Periodically clean up inactive sessions."""
253
- import time
254
-
255
- # Session timeout in seconds (30 minutes)
256
- SESSION_TIMEOUT = 30 * 60
257
-
258
- while self._running:
259
- current_time = time.time()
260
-
261
- # Find inactive sessions
262
- inactive_sessions = [
263
- session_id
264
- for session_id, last_time in self.last_activity.items()
265
- if (current_time - last_time) > SESSION_TIMEOUT
266
- ]
267
-
268
- # Clean up each inactive session
269
- for session_id in inactive_sessions:
270
- self.cleanup_session(session_id)
271
-
272
- # Sleep for a minute before next check
273
- time.sleep(60)
274
-
275
- def register_session(self, session_id):
276
- """Register a new session."""
277
- if session_id not in self.sessions:
278
- self.sessions[session_id] = {}
279
- self.temp_files[session_id] = []
280
-
281
- # Update activity timestamp
282
- self.last_activity[session_id] = time.time()
283
-
284
- def update_activity(self, session_id):
285
- """Update the last activity timestamp for a session."""
286
- self.last_activity[session_id] = time.time()
287
-
288
- def register_temp_file(self, session_id, file_path):
289
- """Register a temporary file with a session for later cleanup."""
290
- if session_id not in self.temp_files:
291
- self.temp_files[session_id] = []
292
-
293
- self.temp_files[session_id].append(file_path)
294
-
295
- def cleanup_session(self, session_id):
296
- """Clean up resources for a session."""
297
- # Remove temporary files
298
- if session_id in self.temp_files:
299
- for file_path in self.temp_files[session_id]:
300
- try:
301
- if os.path.exists(file_path):
302
- os.remove(file_path)
303
- except Exception as e:
304
- print(f"Error removing temp file {file_path}: {e}")
305
-
306
- del self.temp_files[session_id]
307
-
308
- # Clean up session data
309
- if session_id in self.sessions:
310
- del self.sessions[session_id]
311
-
312
- # Clean up activity record
313
- if session_id in self.last_activity:
314
- del self.last_activity[session_id]
315
-
316
- def __del__(self):
317
- """Clean up all sessions when the manager is destroyed."""
318
- self._running = False
319
- if self._cleanup_thread and self._cleanup_thread.is_alive():
320
- self._cleanup_thread.join(timeout=1.0)
321
-
322
- # Clean up all remaining sessions
323
- for session_id in list(self.sessions.keys()):
324
- self.cleanup_session(session_id)
325
 
326
 
327
  # ------------------------ Agent Creation and Execution ------------------------
328
- class AgentFactory:
329
- """Factory for creating and managing agent instances with Zhou Protocol patterns."""
330
-
331
- _instance = None
332
- _lock = threading.Lock()
333
-
334
- @classmethod
335
- def get_instance(cls):
336
- """Thread-safe singleton access."""
337
- if cls._instance is None:
338
- with cls._lock:
339
- if cls._instance is None:
340
- cls._instance = cls()
341
- return cls._instance
342
-
343
- def __init__(self):
344
- """Initialize with agent cache."""
345
- self.agent_cache = {}
346
-
347
- def create_agent(self, session_id: str = "default") -> CodeAgent:
348
- """Creates a fresh agent instance with properly configured tools."""
349
- # Return cached agent if available for this session
350
- if session_id in self.agent_cache:
351
- return self.agent_cache[session_id]
352
-
353
- # Initialize model
354
- model = LiteLLMModel(
355
- custom_role_conversions=custom_role_conversions,
356
- model_id="openrouter/perplexity/r1-1776", # currently serving:
357
- ) # DEEPSEEK = openrouter/perplexity/r1-1776 <--- boss model
358
-
359
- # Initialize tools
360
- text_limit = 30000
361
- browser = SimpleTextBrowser(**BROWSER_CONFIG)
362
-
363
- # Collect all tools in a single list
364
- web_tools = ToolRegistry.load_web_tools(model, browser, text_limit)
365
- image_generator = ToolRegistry.load_image_generation_tools()
366
- legal_tool = ToolRegistry.load_legal_document_tool()
367
-
368
- # Combine and validate all tools
369
- all_tools = [visualizer] + web_tools
370
-
371
- # Only add tools that are properly initialized (not None)
372
- if image_generator:
373
- all_tools.append(image_generator)
374
-
375
- if legal_tool:
376
- all_tools.append(legal_tool)
377
-
378
- # Final validation to ensure all tools are valid
379
- all_tools = ToolRegistry.validate_tools(all_tools)
380
-
381
- agent = CodeAgent(
382
- model=model,
383
- tools=all_tools, # Pass a single list containing all tools
384
- max_steps=10,
385
- verbosity_level=1,
386
- additional_authorized_imports=AUTHORIZED_IMPORTS,
387
- planning_interval=4,
388
- )
389
-
390
- # Cache the agent for future use
391
- self.agent_cache[session_id] = agent
392
- return agent
393
 
394
- def clear_agent(self, session_id: str):
395
- """Remove an agent from the cache."""
396
- if session_id in self.agent_cache:
397
- del self.agent_cache[session_id]
 
 
 
 
398
 
399
 
400
  def stream_to_gradio(
401
  agent,
402
  task: str,
403
  reset_agent_memory: bool = False,
404
- additional_args: Optional[Dict[str, Any]] = None,
405
  ):
406
  """Runs an agent with the given task and streams messages as Gradio ChatMessages."""
407
  for step_log in agent.run(
@@ -437,296 +247,337 @@ def stream_to_gradio(
437
 
438
  # ------------------------ Gradio UI Components ------------------------
439
  class GradioUI:
440
- """A Gradio-compliant interface to launch your agent with proper resource management."""
441
 
442
- def __init__(self):
443
- """Initialize the Gradio UI with proper session management."""
444
- self.session_manager = SessionManager.get_instance()
445
- self.agent_factory = AgentFactory.get_instance()
446
- self.temp_dir = tempfile.mkdtemp(prefix="gradio_")
447
 
448
- def __del__(self):
449
- """Clean up resources when the UI is destroyed."""
450
- try:
451
- # Clean up the temporary directory
452
- if os.path.exists(self.temp_dir):
453
- shutil.rmtree(self.temp_dir, ignore_errors=True)
454
- except Exception as e:
455
- print(f"Error cleaning up temporary directory: {e}")
456
-
457
- @staticmethod
458
- def _get_session_id(session_state):
459
- """Generate or retrieve a session ID."""
460
- if "session_id" not in session_state:
461
- session_state["session_id"] = f"session_{id(session_state)}"
462
- return session_state["session_id"]
463
 
464
  def interact_with_agent(self, prompt, messages, session_state):
465
  """Main interaction handler with the agent."""
466
- # Get or create session ID
467
- session_id = self._get_session_id(session_state)
468
-
469
- # Register/update the session
470
- self.session_manager.register_session(session_id)
471
- self.session_manager.update_activity(session_id)
472
 
473
  # Get or create session-specific agent
474
- agent = self.agent_factory.create_agent(session_id)
 
475
 
 
476
  try:
477
  # Log the existence of agent memory
478
- has_memory = hasattr(agent, "memory")
479
  print(f"Agent has memory: {has_memory}")
480
  if has_memory:
481
- print(f"Memory type: {type(agent.memory)}")
482
 
483
  messages.append(gr.ChatMessage(role="user", content=prompt))
484
  yield messages
485
 
486
- for msg in stream_to_gradio(agent, task=prompt, reset_agent_memory=False):
 
 
487
  messages.append(msg)
488
- self.session_manager.update_activity(session_id)
489
  yield messages # Yield messages after each step
490
-
491
  yield messages # Yield messages one last time
492
 
493
- except gr.Error as e:
494
- # Handle Gradio-specific errors
495
- messages.append(
496
- gr.ChatMessage(role="assistant", content=f"Error: {str(e)}")
497
- )
498
- yield messages
499
  except Exception as e:
500
- # Log the error but present a user-friendly message
501
  print(f"Error in interaction: {str(e)}")
502
- messages.append(
503
- gr.ChatMessage(
504
- role="assistant",
505
- content="I encountered an error processing your request. Please try again with a different query.",
506
- )
507
- )
508
- yield messages
509
-
510
- @gr.validate_input(file="file")
511
- def upload_file(self, file, file_uploads_log, session_state):
512
- """Handle file uploads with Gradio-compliant temporary file handling."""
513
- session_id = self._get_session_id(session_state)
514
- self.session_manager.update_activity(session_id)
515
 
 
 
 
 
 
 
516
  if file is None:
517
  return gr.Textbox("No file uploaded", visible=True), file_uploads_log
518
 
519
  try:
520
- # Create a temporary file with a secure random name
521
- temp_file_path = ""
 
522
 
523
- with tempfile.NamedTemporaryFile(delete=False, dir=self.temp_dir) as tmp:
524
- # Copy the uploaded file to the temporary file
525
- shutil.copyfileobj(open(file.name, "rb"), tmp)
526
- temp_file_path = tmp.name
527
 
528
- # Register the temporary file with the session manager
529
- self.session_manager.register_temp_file(session_id, temp_file_path)
 
 
 
530
 
531
- # Store the original filename for reference
532
- orig_filename = os.path.basename(file.name)
 
 
 
533
 
534
- return (
535
- gr.Textbox(f"File uploaded: {orig_filename}", visible=True),
536
- file_uploads_log + [(temp_file_path, orig_filename)],
537
- )
538
 
539
- except Exception as e:
540
- print(f"Error handling file upload: {e}")
 
 
 
541
  return (
542
- gr.Textbox(f"Error uploading file: {str(e)}", visible=True),
 
 
543
  file_uploads_log,
544
  )
545
 
546
- def log_user_message(self, text_input, file_uploads_log, session_state):
547
- """Process user message and handle file references."""
548
- session_id = self._get_session_id(session_state)
549
- self.session_manager.update_activity(session_id)
550
-
551
- message = text_input
552
 
553
- if file_uploads_log and len(file_uploads_log) > 0:
554
- # Include only the original filenames in the message
555
- filenames = [f[1] for f in file_uploads_log]
556
- message += f"\nYou have been provided with these files: {filenames}"
557
 
558
- # Include the actual file paths in the additional_args
559
- if "additional_args" not in session_state:
560
- session_state["additional_args"] = {}
561
 
562
- session_state["additional_args"]["file_paths"] = [
563
- f[0] for f in file_uploads_log
564
- ]
565
 
566
  return (
567
  message,
568
  gr.Textbox(
569
  value="",
570
  interactive=False,
571
- placeholder="Processing...",
572
  ),
573
  gr.Button(interactive=False),
574
- session_state,
575
  )
576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  def launch(self, **kwargs):
578
  """Launch the Gradio UI with responsive layout."""
579
- with gr.Blocks(theme="soft", css=self._get_responsive_css()) as demo:
580
- with gr.Row(equal_height=True) as main_row:
581
- # Sidebar (adapts to screen size via CSS)
582
- with gr.Column(scale=1, min_width=100) as sidebar:
583
- gr.Markdown(
584
- """# OpenDeepResearch
585
- ## Powered by Smolagents"""
586
- )
587
- with gr.Group():
588
- gr.Markdown("**What's on your mind?**", container=True)
589
- text_input = gr.Textbox(
590
- lines=3,
591
- label="Your request",
592
- container=False,
593
- placeholder="Enter your prompt here and press Shift+Enter or press the button",
594
- )
595
- launch_research_btn = gr.Button("Run", variant="primary")
596
-
597
- # Clean file upload with Gradio-compliant file_types
598
- upload_file = gr.File(
599
- label="Upload a file",
600
- file_types=ALLOWED_FILE_TYPES,
601
- type="file",
 
 
 
 
 
 
 
 
602
  )
 
 
 
 
 
603
  upload_status = gr.Textbox(
604
  label="Upload Status", interactive=False, visible=False
605
  )
606
-
607
- # Footer with proper responsive behavior
608
- with gr.Row(visible=True) as footer:
609
- gr.HTML(
610
- """
611
- <div style="display: flex; align-items: center; gap: 8px; font-family: system-ui, -apple-system, sans-serif;">
612
- <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png"
613
- style="width: 32px; height: 32px; object-fit: contain;" alt="logo">
614
- <a target="_blank" href="https://github.com/huggingface/smolagents">
615
- <b>huggingface/smolagents</b>
616
- </a>
617
- </div>
618
- """
619
- )
620
-
621
- # Main content area
622
- with gr.Column(scale=4, min_width=400) as content:
623
- # Add session state to store session-specific data
624
- session_state = gr.State({})
625
- stored_messages = gr.State([])
626
  file_uploads_log = gr.State([])
 
 
 
 
 
627
 
628
- chatbot = gr.Chatbot(
629
- label="OpenDeepResearch",
630
- show_label=True,
631
- type="messages",
632
- avatar_images=(
633
- None,
634
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
635
- ),
636
- height=600,
637
- elem_id="research-chatbot",
 
 
638
  )
639
 
640
- # Connect event handlers
641
- text_input.submit(
642
- self.log_user_message,
643
- [text_input, file_uploads_log, session_state],
644
- [stored_messages, text_input, launch_research_btn, session_state],
645
- ).then(
646
- self.interact_with_agent,
647
- [stored_messages, chatbot, session_state],
648
- [chatbot],
649
- ).then(
650
- lambda: (
651
- gr.Textbox(
652
- interactive=True,
653
- placeholder="Enter your prompt here and press the button",
654
- ),
655
- gr.Button(interactive=True),
656
  ),
657
- None,
658
- [text_input, launch_research_btn],
 
659
  )
660
 
661
- launch_research_btn.click(
662
- self.log_user_message,
663
- [text_input, file_uploads_log, session_state],
664
- [stored_messages, text_input, launch_research_btn, session_state],
665
- ).then(
666
- self.interact_with_agent,
667
- [stored_messages, chatbot, session_state],
668
- [chatbot],
669
- ).then(
670
- lambda: (
671
- gr.Textbox(
672
- interactive=True,
673
- placeholder="Enter your prompt here and press the button",
674
- ),
675
- gr.Button(interactive=True),
 
 
 
 
 
 
 
 
 
 
 
676
  ),
677
- None,
678
- [text_input, launch_research_btn],
679
  )
680
 
681
- upload_file.change(
682
- self.upload_file,
683
- [upload_file, file_uploads_log, session_state],
684
- [upload_status, file_uploads_log],
 
 
 
 
 
 
 
 
 
 
 
 
685
  )
 
686
 
687
- # Clean up session on page unload
688
- demo.load(
689
- lambda: None,
690
- None,
691
- None,
692
- _js="""
693
- () => {
694
- window.addEventListener('beforeunload', function() {
695
- // Notify backend about session end (would require additional endpoint)
696
- console.log('Cleaning up session');
697
- });
698
- }
699
- """,
700
  )
701
 
702
- demo.queue(max_size=20).launch(debug=True, **kwargs)
703
-
704
- def _get_responsive_css(self):
705
- """Get CSS for responsive layout."""
706
- return """
707
- /* Responsive layout */
708
- @media (max-width: 768px) {
709
- #research-chatbot {
710
- height: 400px !important;
711
- }
712
-
713
- /* Stack columns on small screens */
714
- .gradio-row {
715
- flex-direction: column;
716
- }
717
-
718
- /* Adjust column widths */
719
- .gradio-column {
720
- min-width: 100% !important;
721
- width: 100% !important;
722
- }
723
- }
724
-
725
- /* Base styling */
726
- .gradio-container {
727
- max-width: 100% !important;
728
- }
729
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
 
731
 
732
  # ------------------------ Execution ------------------------
@@ -735,11 +586,11 @@ def main():
735
  # Initialize environment
736
  setup_environment()
737
 
738
- # Ensure downloads folder exists for browser
739
  os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)
740
 
741
  # Launch UI
742
- GradioUI().launch(share=True)
743
 
744
 
745
  if __name__ == "__main__":
 
1
+ import mimetypes
2
  import os
3
+ import re
4
  import shutil
5
+ from typing import Optional
 
 
6
 
7
  from dotenv import load_dotenv
8
  from huggingface_hub import login
 
19
  VisitTool,
20
  )
21
  from scripts.visual_qa import visualizer
22
+
23
  from smolagents import (
24
  CodeAgent,
25
  HfApiModel,
 
58
  "fractions",
59
  "csv",
60
  "clean-text",
 
 
61
  ]
62
 
63
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
 
73
 
74
  custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}
75
 
76
+ # Multimedia file types supported:
77
  ALLOWED_FILE_TYPES = [
78
+ "application/pdf",
79
+ "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
80
+ "text/plain",
81
+ "image/png",
82
+ "image/webp",
83
+ "image/jpeg", # Added JPEG support
84
+ "image/gif", # Added GIF support
85
+ "video/mp4",
86
+ "audio/mpeg", # Added MP3 support
87
+ "audio/wav", # Added WAV support
88
+ "audio/ogg", # Added OGG support
 
89
  ]
90
 
91
 
 
101
 
102
  # ------------------------ Model and Tool Management ------------------------
103
  class ModelManager:
104
+ """Manages model loading and initialization."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ @staticmethod
107
+ def load_model(chosen_inference: str, model_id: str, key_manager=None):
108
+ """Load the specified model with appropriate configuration."""
109
  try:
110
  if chosen_inference == "hf_api":
111
+ return HfApiModel(model_id=model_id)
112
 
113
  elif chosen_inference == "hf_api_provider":
114
+ return HfApiModel(provider="together")
115
 
116
  elif chosen_inference == "litellm":
117
+ return LiteLLMModel(model_id=model_id)
118
 
119
  elif chosen_inference == "openai":
120
  if not key_manager:
121
  raise ValueError("Key manager required for OpenAI model")
122
 
123
+ return OpenAIServerModel(
124
  model_id=model_id, api_key=key_manager.get_key("openai_api_key")
125
  )
126
 
127
  elif chosen_inference == "transformers":
128
+ return TransformersModel(
129
  model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
130
  device_map="auto",
131
  max_new_tokens=1000,
 
134
  else:
135
  raise ValueError(f"Invalid inference type: {chosen_inference}")
136
 
 
 
 
 
137
  except Exception as e:
138
  print(f"✗ Couldn't load model: {e}")
139
  raise
140
 
141
 
142
  class ToolRegistry:
143
+ """Manages tool initialization and organization."""
 
 
 
 
 
144
 
145
  @staticmethod
146
  def load_web_tools(model, browser, text_limit=20000):
 
167
  )
168
  except Exception as e:
169
  print(f"✗ Couldn't initialize image generation tool: {e}")
170
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
 
173
  # ------------------------ Agent Creation and Execution ------------------------
174
+ def create_agent():
175
+ """Creates a fresh agent instance with properly configured tools."""
176
+ # Initialize model
177
+ model = LiteLLMModel(
178
+ custom_role_conversions=custom_role_conversions,
179
+ model_id="openrouter/google/gemini-2.0-flash-001", # currently serving:
180
+ ) # DEEPSEEK = openrouter/perplexity/r1-1776 <--- boss model
181
+
182
+ # Initialize tools
183
+ text_limit = 30000
184
+ browser = SimpleTextBrowser(**BROWSER_CONFIG)
185
+
186
+ # Collect all tools in a single list
187
+ web_tools = ToolRegistry.load_web_tools(model, browser, text_limit)
188
+ image_generator = ToolRegistry.load_image_generation_tools()
189
+
190
+ # Combine all tools into a single list (not a tuple)
191
+ all_tools = [visualizer] + web_tools + [image_generator]
192
+
193
+ # Validate tools before creating agent
194
+ for tool in all_tools:
195
+ if not isinstance(tool, Tool):
196
+ raise ValueError(
197
+ f"Invalid tool type: {type(tool)}. All tools must be instances of Tool class."
198
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ return CodeAgent(
201
+ model=model,
202
+ tools=all_tools, # Pass a single list containing all tools
203
+ max_steps=10,
204
+ verbosity_level=1,
205
+ additional_authorized_imports=AUTHORIZED_IMPORTS,
206
+ planning_interval=4,
207
+ )
208
 
209
 
210
  def stream_to_gradio(
211
  agent,
212
  task: str,
213
  reset_agent_memory: bool = False,
214
+ additional_args: Optional[dict] = None,
215
  ):
216
  """Runs an agent with the given task and streams messages as Gradio ChatMessages."""
217
  for step_log in agent.run(
 
247
 
248
  # ------------------------ Gradio UI Components ------------------------
249
  class GradioUI:
250
+ """A one-line interface to launch your agent in Gradio."""
251
 
252
+ def __init__(self, file_upload_folder: str | None = None):
253
+ """Initialize the Gradio UI with optional file upload functionality."""
254
+ self.file_upload_folder = file_upload_folder
 
 
255
 
256
+ if self.file_upload_folder is not None:
257
+ if not os.path.exists(file_upload_folder):
258
+ os.mkdir(file_upload_folder)
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  def interact_with_agent(self, prompt, messages, session_state):
261
  """Main interaction handler with the agent."""
 
 
 
 
 
 
262
 
263
  # Get or create session-specific agent
264
+ if "agent" not in session_state:
265
+ session_state["agent"] = create_agent()
266
 
267
+ # Adding monitoring
268
  try:
269
  # Log the existence of agent memory
270
+ has_memory = hasattr(session_state["agent"], "memory")
271
  print(f"Agent has memory: {has_memory}")
272
  if has_memory:
273
+ print(f"Memory type: {type(session_state['agent'].memory)}")
274
 
275
  messages.append(gr.ChatMessage(role="user", content=prompt))
276
  yield messages
277
 
278
+ for msg in stream_to_gradio(
279
+ session_state["agent"], task=prompt, reset_agent_memory=False
280
+ ):
281
  messages.append(msg)
 
282
  yield messages # Yield messages after each step
 
283
  yield messages # Yield messages one last time
284
 
 
 
 
 
 
 
285
  except Exception as e:
 
286
  print(f"Error in interaction: {str(e)}")
287
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
+ def upload_file(
290
+ self,
291
+ file,
292
+ file_uploads_log,
293
+ ):
294
+ """Handle file uploads with proper validation and security."""
295
  if file is None:
296
  return gr.Textbox("No file uploaded", visible=True), file_uploads_log
297
 
298
  try:
299
+ mime_type, _ = mimetypes.guess_type(file.name)
300
+ except Exception as e:
301
+ return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
302
 
303
+ if mime_type not in ALLOWED_FILE_TYPES:
304
+ return gr.Textbox("File type disallowed", visible=True), file_uploads_log
 
 
305
 
306
+ # Sanitize file name
307
+ original_name = os.path.basename(file.name)
308
+ sanitized_name = re.sub(
309
+ r"[^\w\-.]", "_", original_name
310
+ ) # Replace invalid chars with underscores
311
 
312
+ # Ensure the extension correlates to the mime type
313
+ type_to_ext = {}
314
+ for ext, t in mimetypes.types_map.items():
315
+ if t not in type_to_ext:
316
+ type_to_ext[t] = ext
317
 
318
+ # Build sanitized filename with proper extension
319
+ name_parts = sanitized_name.split(".")[:-1]
320
+ extension = type_to_ext.get(mime_type, "")
321
+ sanitized_name = "".join(name_parts) + extension
322
 
323
+ # Limit File Size, and Throw Error
324
+ max_file_size_mb = 50 # Define the limit
325
+ file_size_mb = os.path.getsize(file.name) / (1024 * 1024) # Size in MB
326
+
327
+ if file_size_mb > max_file_size_mb:
328
  return (
329
+ gr.Textbox(
330
+ f"File size exceeds {max_file_size_mb} MB limit.", visible=True
331
+ ),
332
  file_uploads_log,
333
  )
334
 
335
+ # Save the uploaded file to the specified folder
336
+ file_path = os.path.join(self.file_upload_folder, sanitized_name)
337
+ shutil.copy(file.name, file_path)
 
 
 
338
 
339
+ return gr.Textbox(
340
+ f"File uploaded: {file_path}", visible=True
341
+ ), file_uploads_log + [file_path]
 
342
 
343
+ def log_user_message(self, text_input, file_uploads_log):
344
+ """Process user message and handle file references."""
345
+ message = text_input
346
 
347
+ if len(file_uploads_log) > 0:
348
+ message += f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}" # Added file list
 
349
 
350
  return (
351
  message,
352
  gr.Textbox(
353
  value="",
354
  interactive=False,
355
+ placeholder="Processing...", # Changed placeholder.
356
  ),
357
  gr.Button(interactive=False),
 
358
  )
359
 
360
+ def detect_device(self, request: gr.Request):
361
+ """Detect whether the user is on mobile or desktop device."""
362
+ if not request:
363
+ return "Unknown device" # Handle case where request is none.
364
+
365
+ # Method 1: Check sec-ch-ua-mobile header
366
+ is_mobile_header = request.headers.get("sec-ch-ua-mobile")
367
+ if is_mobile_header:
368
+ return "Mobile" if "?1" in is_mobile_header else "Desktop"
369
+
370
+ # Method 2: Check user-agent string
371
+ user_agent = request.headers.get("user-agent", "").lower()
372
+ mobile_keywords = ["android", "iphone", "ipad", "mobile", "phone"]
373
+
374
+ if any(keyword in user_agent for keyword in mobile_keywords):
375
+ return "Mobile"
376
+
377
+ # Method 3: Check platform
378
+ platform = request.headers.get("sec-ch-ua-platform", "").lower()
379
+ if platform:
380
+ if platform in ['"android"', '"ios"']:
381
+ return "Mobile"
382
+ elif platform in ['"windows"', '"macos"', '"linux"']:
383
+ return "Desktop"
384
+
385
+ # Default case if no clear indicators
386
+ return "Desktop"
387
+
388
  def launch(self, **kwargs):
389
  """Launch the Gradio UI with responsive layout."""
390
+ with gr.Blocks(theme="ocean", fill_height=True) as demo:
391
+ # Different layouts for mobile and computer devices
392
+ @gr.render()
393
+ def layout(request: gr.Request):
394
+ device = self.detect_device(request)
395
+ print(f"device - {device}")
396
+ # Render layout with sidebar
397
+ if device == "Desktop":
398
+ return self._create_desktop_layout()
399
+ else:
400
+ return self._create_mobile_layout()
401
+
402
+ demo.queue(max_size=20).launch(
403
+ debug=True, **kwargs
404
+ ) # Add queue with reasonable size
405
+
406
+ def _create_desktop_layout(self):
407
+ """Create the desktop layout with sidebar."""
408
+ with gr.Blocks(fill_height=True) as sidebar_demo:
409
+ with gr.Sidebar():
410
+ gr.Markdown(
411
+ """#OpenDeepResearch - 3theSmolagents!
412
+ Model_id: google/gemini-2.0-flash-001"""
413
+ )
414
+ with gr.Group():
415
+ gr.Markdown("**What's on your mind mate?**", container=True)
416
+ text_input = gr.Textbox(
417
+ lines=3,
418
+ label="Your request",
419
+ container=False,
420
+ placeholder="Enter your prompt here and press Shift+Enter or press the button",
421
  )
422
+ launch_research_btn = gr.Button("Run", variant="primary")
423
+
424
+ # If an upload folder is provided, enable the upload feature
425
+ if self.file_upload_folder is not None:
426
+ upload_file = gr.File(label="Upload a file")
427
  upload_status = gr.Textbox(
428
  label="Upload Status", interactive=False, visible=False
429
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  file_uploads_log = gr.State([])
431
+ upload_file.change(
432
+ self.upload_file,
433
+ [upload_file, file_uploads_log],
434
+ [upload_status, file_uploads_log],
435
+ )
436
 
437
+ gr.HTML("<br><br><h4><center>Powered by:</center></h4>")
438
+ with gr.Row():
439
+ gr.HTML(
440
+ """
441
+ <div style="display: flex; align-items: center; gap: 8px; font-family: system-ui, -apple-system, sans-serif;">
442
+ <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png"
443
+ style="width: 32px; height: 32px; object-fit: contain;" alt="logo">
444
+ <a target="_blank" href="https://github.com/huggingface/smolagents">
445
+ <b>huggingface/smolagents</b>
446
+ </a>
447
+ </div>
448
+ """
449
  )
450
 
451
+ # Add session state to store session-specific data
452
+ session_state = gr.State({}) # Initialize empty state for each session
453
+ stored_messages = gr.State([])
454
+ if "file_uploads_log" not in locals():
455
+ file_uploads_log = gr.State([])
456
+
457
+ chatbot = gr.Chatbot(
458
+ label="open-Deep-Research",
459
+ type="messages",
460
+ avatar_images=(
461
+ None,
462
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
 
 
 
 
463
  ),
464
+ resizeable=False,
465
+ scale=1,
466
+ elem_id="my-chatbot",
467
  )
468
 
469
+ self._connect_event_handlers(
470
+ text_input,
471
+ launch_research_btn,
472
+ file_uploads_log,
473
+ stored_messages,
474
+ chatbot,
475
+ session_state,
476
+ )
477
+
478
+ return sidebar_demo
479
+
480
+ def _create_mobile_layout(self):
481
+ """Create the mobile layout (simpler without sidebar)."""
482
+ with gr.Blocks(fill_height=True) as simple_demo:
483
+ gr.Markdown("""#OpenDeepResearch - free the AI agents!""")
484
+ # Add session state to store session-specific data
485
+ session_state = gr.State({})
486
+ stored_messages = gr.State([])
487
+ file_uploads_log = gr.State([])
488
+
489
+ chatbot = gr.Chatbot(
490
+ label="open-Deep-Research",
491
+ type="messages",
492
+ avatar_images=(
493
+ None,
494
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
495
  ),
496
+ resizeable=True,
497
+ scale=1,
498
  )
499
 
500
+ # If an upload folder is provided, enable the upload feature
501
+ if self.file_upload_folder is not None:
502
+ upload_file = gr.File(label="Upload a file")
503
+ upload_status = gr.Textbox(
504
+ label="Upload Status", interactive=False, visible=False
505
+ )
506
+ upload_file.change(
507
+ self.upload_file,
508
+ [upload_file, file_uploads_log],
509
+ [upload_status, file_uploads_log],
510
+ )
511
+
512
+ text_input = gr.Textbox(
513
+ lines=1,
514
+ label="What's on your mind mate?",
515
+ placeholder="Chuck in a question and we'll take care of the rest",
516
  )
517
+ launch_research_btn = gr.Button("Run", variant="primary")
518
 
519
+ self._connect_event_handlers(
520
+ text_input,
521
+ launch_research_btn,
522
+ file_uploads_log,
523
+ stored_messages,
524
+ chatbot,
525
+ session_state,
 
 
 
 
 
 
526
  )
527
 
528
+ return simple_demo
529
+
530
+ def _connect_event_handlers(
531
+ self,
532
+ text_input,
533
+ launch_research_btn,
534
+ file_uploads_log,
535
+ stored_messages,
536
+ chatbot,
537
+ session_state,
538
+ ):
539
+ """Connect the event handlers for input elements."""
540
+ # Connect text input submit event
541
+ text_input.submit(
542
+ self.log_user_message,
543
+ [text_input, file_uploads_log],
544
+ [stored_messages, text_input, launch_research_btn],
545
+ ).then(
546
+ self.interact_with_agent,
547
+ [stored_messages, chatbot, session_state],
548
+ [chatbot],
549
+ ).then(
550
+ lambda: (
551
+ gr.Textbox(
552
+ interactive=True,
553
+ placeholder="Enter your prompt here and press the button",
554
+ ),
555
+ gr.Button(interactive=True),
556
+ ),
557
+ None,
558
+ [text_input, launch_research_btn],
559
+ )
560
+
561
+ # Connect button click event
562
+ launch_research_btn.click(
563
+ self.log_user_message,
564
+ [text_input, file_uploads_log],
565
+ [stored_messages, text_input, launch_research_btn],
566
+ ).then(
567
+ self.interact_with_agent,
568
+ [stored_messages, chatbot, session_state],
569
+ [chatbot],
570
+ ).then(
571
+ lambda: (
572
+ gr.Textbox(
573
+ interactive=True,
574
+ placeholder="Enter your prompt here and press the button",
575
+ ),
576
+ gr.Button(interactive=True),
577
+ ),
578
+ None,
579
+ [text_input, launch_research_btn],
580
+ )
581
 
582
 
583
# ------------------------ Execution ------------------------

# Initialize environment (loads .env secrets and logs into the HF Hub).
# NOTE(review): these statements execute at import time, *before* the
# `if __name__ == "__main__":` guard below — confirm whether the UI is
# meant to launch on import (e.g. for HF Spaces) or only as a script.
setup_environment()

# Ensure downloads folder exists for the browser tools.
os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)

# Launch UI; uploaded files are persisted under ./uploaded_files.
GradioUI(file_upload_folder="uploaded_files").launch()
594
 
595
 
596
  if __name__ == "__main__":