Leonardo committed on
Commit
d66b099
·
verified ·
1 Parent(s): 30b3d70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -122
app.py CHANGED
@@ -112,6 +112,22 @@ ALLOWED_FILE_TYPES = [
112
  "audio/wav",
113
  "audio/ogg",
114
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
 
117
  def setup_environment():
@@ -341,12 +357,17 @@ class GradioUI:
341
  def __init__(self, file_upload_folder: str | None = None):
342
  """Initialize the Gradio UI with optional file upload functionality."""
343
  self.file_upload_folder = file_upload_folder
 
344
 
345
  if self.file_upload_folder is not None:
346
  os.makedirs(self.file_upload_folder, exist_ok=True)
347
 
348
  def interact_with_agent(
349
- self, prompt: str, messages: List[Dict], session_state: Dict
 
 
 
 
350
  ) -> List[Dict]: # Type hints
351
  """Main interaction handler with the agent."""
352
 
@@ -372,12 +393,39 @@ class GradioUI:
372
  messages.append(gr.ChatMessage(role="user", content=prompt))
373
  yield messages
374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  try:
376
  # Check if agent should be reset (e.g., if too many requests)
377
  reset_needed = session_state["request_count"] > 15
378
 
379
  for msg in stream_to_gradio(
380
- session_state["agent"], task=prompt, reset_agent_memory=reset_needed
 
 
381
  ):
382
  messages.append(msg)
383
  yield messages
@@ -395,88 +443,63 @@ class GradioUI:
395
  )
396
  yield messages
397
 
398
- def upload_file(self, file, file_uploads_log):
399
- """Handle file uploads with validation, security, and clear feedback."""
400
- if file is None:
401
- return gr.Textbox("No file uploaded", visible=True), file_uploads_log
402
-
403
- try:
404
- # Get file size and check limit before processing
405
- file_size_mb = os.path.getsize(file.name) / (1024 * 1024) # Size in MB
406
- max_file_size_mb = 50 # Define the limit
407
-
408
- if file_size_mb > max_file_size_mb:
409
- return (
410
- gr.Textbox(
411
- f"❌ File size ({file_size_mb:.1f} MB) exceeds {max_file_size_mb} MB limit.",
412
- visible=True,
413
- ),
414
- file_uploads_log,
415
- )
416
-
417
- # Check MIME type
418
- mime_type, _ = mimetypes.guess_type(file.name) # Correct unpacking
419
- if mime_type not in ALLOWED_FILE_TYPES:
420
- return (
421
- gr.Textbox(
422
- f"❌ File type '{mime_type or 'unknown'}' is not allowed. Supported types: {', '.join(t.split('/')[-1] for t in ALLOWED_FILE_TYPES)}",
423
- visible=True,
424
- ),
425
- file_uploads_log,
426
- )
427
 
428
- # Sanitize file name with better pattern
429
- original_name = os.path.basename(file.name)
430
- sanitized_name = re.sub(r"[^\w\-.]", "", original_name)
 
 
431
 
432
- # Save the uploaded file
433
- file_path = os.path.join(self.file_upload_folder, sanitized_name)
434
- shutil.copy(file.name, file_path)
 
435
 
436
- return gr.Textbox(
437
- f"✓ File uploaded successfully: {os.path.basename(file_path)} ({file_size_mb:.1f} MB)",
438
- visible=True,
439
- ), file_uploads_log + [file_path]
440
 
441
- except Exception as e:
442
- return (
443
- gr.Textbox(f"❌ Upload error: {str(e)}", visible=True),
444
- file_uploads_log,
445
- )
446
-
447
- def log_user_message(self, text_input, file_uploads_log):
448
- """Process user message and handle file references with proper agent types."""
449
- message = text_input
450
 
451
- if len(file_uploads_log) > 0:
452
- # Group files by type for better agent processing
453
- file_info = {}
454
- for file_path in file_uploads_log:
455
- ext = os.path.splitext(file_path)[1].lower()
456
- if ext in [".jpg", ".jpeg", ".png", ".gif", ".webp"]:
457
- category = "images"
458
- elif ext in [".mp3", ".wav", ".ogg"]:
459
- category = "audio"
460
  else:
461
- category = "documents"
462
-
463
- if category not in file_info:
464
- file_info[category] = []
465
- file_info[category].append(os.path.basename(file_path))
466
-
467
- # Format file information for the agent
468
- file_message = "\nYou have been provided with these files:\n"
469
- for category, files in file_info.items():
470
- file_message += f"- {category.capitalize()}: {', '.join(files)}\n"
 
 
 
 
 
 
471
 
472
- message += file_message
473
- message += "\nUse inspect_file_as_text for documents, visualizer for images, and the appropriate tools for audio files."
 
474
 
475
- return (
476
- message,
477
- gr.Textbox(value="", interactive=False, placeholder="Processing..."),
478
- gr.Button(interactive=False),
479
- )
 
 
480
 
481
  def detect_device(self, request: gr.Request):
482
  """Detect whether the user is on mobile or desktop device."""
@@ -550,42 +573,17 @@ class GradioUI:
550
  if self.file_upload_folder is not None:
551
  with gr.Group():
552
  gr.Markdown("📎 Upload Documents")
553
- upload_file = gr.File(
 
554
  label="Upload files for analysis",
555
- file_types=[
556
- "pdf",
557
- "docx",
558
- "txt",
559
- "md",
560
- "csv",
561
- "xlsx",
562
- "jpg",
563
- "png",
564
- ],
565
  file_count="multiple",
566
  )
 
567
  upload_status = gr.Textbox(
568
  label="Upload Status", interactive=False, visible=False
569
  )
570
- file_uploads_log = gr.State([])
571
-
572
- # Show uploaded files list
573
- uploaded_files_display = gr.Markdown("No files uploaded yet")
574
-
575
- upload_file.change(
576
- self.upload_file,
577
- [upload_file, file_uploads_log],
578
- [upload_status, file_uploads_log],
579
- ).then(
580
- lambda files: (
581
- "Uploaded Files:\n"
582
- + "\n".join([f"- {os.path.basename(f)}" for f in files])
583
- if files
584
- else "No files uploaded yet"
585
- ),
586
- [file_uploads_log],
587
- [uploaded_files_display],
588
- )
589
 
590
  gr.HTML("<br><hr><h4><center>Powered by:</center></h4>")
591
  with gr.Row():
@@ -604,8 +602,6 @@ class GradioUI:
604
  # Main chat area with improved styling
605
  session_state = gr.State({})
606
  stored_messages = gr.State([])
607
- if "file_uploads_log" not in locals():
608
- file_uploads_log = gr.State([])
609
 
610
  chatbot = gr.Chatbot(
611
  label="OpenDeepResearch Assistant",
@@ -623,19 +619,26 @@ class GradioUI:
623
 
624
  # Connect clear button
625
  clear_btn.click(
626
- lambda: ([], [], {"agent": session_state.get("agent")}),
627
  None,
628
- [chatbot, stored_messages, session_state],
629
  )
630
 
 
 
 
 
 
 
 
631
  # Connect event handlers
632
  self._connect_event_handlers(
633
  text_input,
634
  launch_research_btn,
635
- file_uploads_log,
636
  stored_messages,
637
  chatbot,
638
  session_state,
 
639
  )
640
 
641
  return sidebar_demo
@@ -647,7 +650,13 @@ class GradioUI:
647
  # Add session state to store session-specific data
648
  session_state = gr.State({})
649
  stored_messages = gr.State([])
650
- file_uploads_log = gr.State([])
 
 
 
 
 
 
651
 
652
  chatbot = gr.Chatbot(
653
  label="open-Deep-Research",
@@ -662,14 +671,13 @@ class GradioUI:
662
 
663
  # If an upload folder is provided, enable the upload feature
664
  if self.file_upload_folder is not None:
665
- upload_file = gr.File(label="Upload a file")
666
  upload_status = gr.Textbox(
667
  label="Upload Status", interactive=False, visible=False
668
  )
669
- upload_file.change(
670
  self.upload_file,
671
- [upload_file, file_uploads_log],
672
- [upload_status, file_uploads_log],
673
  )
674
 
675
  text_input = gr.Textbox(
@@ -682,10 +690,10 @@ class GradioUI:
682
  self._connect_event_handlers(
683
  text_input,
684
  launch_research_btn,
685
- file_uploads_log,
686
  stored_messages,
687
  chatbot,
688
  session_state,
 
689
  )
690
 
691
  return simple_demo
@@ -694,20 +702,20 @@ class GradioUI:
694
  self,
695
  text_input,
696
  launch_research_btn,
697
- file_uploads_log,
698
  stored_messages,
699
  chatbot,
700
  session_state,
 
701
  ):
702
  """Connect the event handlers for input elements."""
703
  # Connect text input submit event
704
  text_input.submit(
705
  self.log_user_message,
706
- [text_input, file_uploads_log],
707
  [stored_messages, text_input, launch_research_btn],
708
  ).then(
709
  self.interact_with_agent,
710
- [stored_messages, chatbot, session_state],
711
  [chatbot],
712
  ).then(
713
  lambda: (
@@ -724,11 +732,11 @@ class GradioUI:
724
  # Connect button click event
725
  launch_research_btn.click(
726
  self.log_user_message,
727
- [text_input, file_uploads_log],
728
  [stored_messages, text_input, launch_research_btn],
729
  ).then(
730
  self.interact_with_agent,
731
- [stored_messages, chatbot, session_state],
732
  [chatbot],
733
  ).then(
734
  lambda: (
 
112
  "audio/wav",
113
  "audio/ogg",
114
  ]
115
# File extensions accepted for upload. GradioUI stores this list as
# `self.allowed_extensions`, passes it to gr.File(file_types=...), re-checks
# it in upload_file, and joins it into the user-facing "Supported types"
# error message, so every entry is a lowercase extension with a leading dot.
ALLOWED_EXTENSIONS = [
    # Documents
    ".pdf",
    ".docx",
    ".txt",
    ".md",
    ".json",
    # Images
    ".png",
    ".webp",
    ".jpeg",
    ".jpg",
    ".gif",
    # Audio / video
    ".mp3",  # interact_with_agent already classifies ".mp3" as audio
    ".mp4",
    ".mpeg",
    ".wav",
    ".ogg",
]
131
 
132
 
133
  def setup_environment():
 
357
  def __init__(self, file_upload_folder: str | None = None):
358
  """Initialize the Gradio UI with optional file upload functionality."""
359
  self.file_upload_folder = file_upload_folder
360
+ self.allowed_extensions = ALLOWED_EXTENSIONS # Use the constant
361
 
362
  if self.file_upload_folder is not None:
363
  os.makedirs(self.file_upload_folder, exist_ok=True)
364
 
365
  def interact_with_agent(
366
+ self,
367
+ prompt: str,
368
+ messages: List[Dict],
369
+ session_state: Dict,
370
+ uploaded_files: List[str],
371
  ) -> List[Dict]: # Type hints
372
  """Main interaction handler with the agent."""
373
 
 
393
  messages.append(gr.ChatMessage(role="user", content=prompt))
394
  yield messages
395
 
396
+ # Process files, linking them to the message
397
+ file_message = ""
398
+ if uploaded_files:
399
+ file_info = {}
400
+ for file_path in uploaded_files:
401
+ ext = os.path.splitext(file_path)[1].lower()
402
+ if ext in [".jpg", ".jpeg", ".png", ".gif", ".webp"]:
403
+ category = "images"
404
+ elif ext in [".mp3", ".wav", ".ogg"]:
405
+ category = "audio"
406
+ else:
407
+ category = "documents"
408
+
409
+ if category not in file_info:
410
+ file_info[category] = []
411
+ file_info[category].append(os.path.basename(file_path))
412
+
413
+ # Format file information for the agent
414
+ file_message = "\nYou have been provided with these files:\n"
415
+ for category, files in file_info.items():
416
+ file_message += f"- {category.capitalize()}: {', '.join(files)}\n"
417
+
418
+ prompt_with_files = prompt + file_message
419
+ prompt_with_files += "\nUse inspect_file_as_text for documents, visualizer for images, and the appropriate tools for audio files."
420
+
421
  try:
422
  # Check if agent should be reset (e.g., if too many requests)
423
  reset_needed = session_state["request_count"] > 15
424
 
425
  for msg in stream_to_gradio(
426
+ session_state["agent"],
427
+ task=prompt_with_files,
428
+ reset_agent_memory=reset_needed,
429
  ):
430
  messages.append(msg)
431
  yield messages
 
443
  )
444
  yield messages
445
 
446
def log_user_message(self, text_input: str) -> Tuple[str, gr.Textbox, gr.Button]:
    """Pass the submitted text through and lock the input controls.

    Args:
        text_input: The text the user submitted.

    Returns:
        A 3-tuple of: the text unchanged, a cleared non-interactive Textbox
        with a "Processing..." placeholder, and a non-interactive Button.
        The event wiring maps these to ``stored_messages``, ``text_input``
        and ``launch_research_btn`` respectively.
    """
    locked_textbox = gr.Textbox(
        value="", interactive=False, placeholder="Processing..."
    )
    locked_button = gr.Button(interactive=False)
    return text_input, locked_textbox, locked_button
454
 
455
def upload_file(self, files: List[str]) -> Tuple[str, List[str]]:
    """Validate a batch of uploaded files and copy them into the upload folder.

    Every file is validated (extension, then size) before anything is
    copied, so a batch that fails validation leaves the upload folder
    untouched. An I/O error raised while copying can still leave the files
    copied before it in place.

    Args:
        files: Paths of the files to import. Empty or ``None`` means that
            nothing was uploaded.

    Returns:
        ``(status_message, saved_paths)``. On success ``saved_paths`` lists
        the destination path of every file. On any failure it is empty and
        ``status_message`` describes the first problem found.
    """
    # NOTE(review): both layouts create `upload_status` with visible=False;
    # returning a plain string presumably only updates its value. Confirm
    # that the status text is actually shown to the user.
    if not files:
        return "No file uploaded", []

    max_file_size_mb = 50  # Per-file size limit

    try:
        # Pass 1: reject the whole batch before writing anything to disk.
        for file_path in files:
            file_extension = os.path.splitext(file_path)[1].lower()
            if file_extension not in self.allowed_extensions:
                return (
                    f"❌ File type '{file_extension}' is not allowed. "
                    f"Supported types: {', '.join(self.allowed_extensions)}",
                    [],
                )

            file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
            if file_size_mb > max_file_size_mb:
                return (
                    f"❌ File size ({file_size_mb:.1f} MB) exceeds {max_file_size_mb} MB limit.",
                    [],
                )

        # Pass 2: the batch is valid, copy it over.
        uploaded_files = []
        for file_path in files:
            # Keep only word characters, '-' and '.' in the stored name.
            sanitized_name = re.sub(r"[^\w\-.]", "", os.path.basename(file_path))
            dest_path = os.path.join(self.file_upload_folder, sanitized_name)
            shutil.copy(file_path, dest_path)
            uploaded_files.append(dest_path)
            print(f"Uploaded {file_path} to {dest_path}")
    except Exception as e:
        # UI boundary: report the failure as a status message instead of
        # letting it propagate out of the event handler.
        return f"❌ Upload error: {str(e)}", []

    return (
        f"✓ Files uploaded successfully: {', '.join([os.path.basename(f) for f in uploaded_files])}",
        uploaded_files,
    )
503
 
504
  def detect_device(self, request: gr.Request):
505
  """Detect whether the user is on mobile or desktop device."""
 
573
  if self.file_upload_folder is not None:
574
  with gr.Group():
575
  gr.Markdown("📎 Upload Documents")
576
+ # File input with multiple enabled with the allowed extensions
577
+ file_upload = gr.File(
578
  label="Upload files for analysis",
579
+ file_types=self.allowed_extensions,
 
 
 
 
 
 
 
 
 
580
  file_count="multiple",
581
  )
582
+
583
  upload_status = gr.Textbox(
584
  label="Upload Status", interactive=False, visible=False
585
  )
586
+ uploaded_files_state = gr.State([])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
 
588
  gr.HTML("<br><hr><h4><center>Powered by:</center></h4>")
589
  with gr.Row():
 
602
  # Main chat area with improved styling
603
  session_state = gr.State({})
604
  stored_messages = gr.State([])
 
 
605
 
606
  chatbot = gr.Chatbot(
607
  label="OpenDeepResearch Assistant",
 
619
 
620
  # Connect clear button
621
  clear_btn.click(
622
+ lambda: ([], [], {"agent": session_state.get("agent")}, []),
623
  None,
624
+ [chatbot, stored_messages, session_state, uploaded_files_state],
625
  )
626
 
627
+ if self.file_upload_folder is not None:
628
+ file_upload.change(
629
+ self.upload_file,
630
+ [file_upload],
631
+ [upload_status, uploaded_files_state],
632
+ )
633
+
634
  # Connect event handlers
635
  self._connect_event_handlers(
636
  text_input,
637
  launch_research_btn,
 
638
  stored_messages,
639
  chatbot,
640
  session_state,
641
+ uploaded_files_state,
642
  )
643
 
644
  return sidebar_demo
 
650
  # Add session state to store session-specific data
651
  session_state = gr.State({})
652
  stored_messages = gr.State([])
653
+ # File input with multiple enabled with the allowed extensions
654
+ file_upload = gr.File(
655
+ label="Upload files for analysis",
656
+ file_types=self.allowed_extensions,
657
+ file_count="multiple",
658
+ )
659
+ uploaded_files_state = gr.State([])
660
 
661
  chatbot = gr.Chatbot(
662
  label="open-Deep-Research",
 
671
 
672
  # If an upload folder is provided, enable the upload feature
673
  if self.file_upload_folder is not None:
 
674
  upload_status = gr.Textbox(
675
  label="Upload Status", interactive=False, visible=False
676
  )
677
+ file_upload.change(
678
  self.upload_file,
679
+ [file_upload],
680
+ [upload_status, uploaded_files_state],
681
  )
682
 
683
  text_input = gr.Textbox(
 
690
  self._connect_event_handlers(
691
  text_input,
692
  launch_research_btn,
 
693
  stored_messages,
694
  chatbot,
695
  session_state,
696
+ uploaded_files_state,
697
  )
698
 
699
  return simple_demo
 
702
  self,
703
  text_input,
704
  launch_research_btn,
 
705
  stored_messages,
706
  chatbot,
707
  session_state,
708
+ uploaded_files_state,
709
  ):
710
  """Connect the event handlers for input elements."""
711
  # Connect text input submit event
712
  text_input.submit(
713
  self.log_user_message,
714
+ [text_input],
715
  [stored_messages, text_input, launch_research_btn],
716
  ).then(
717
  self.interact_with_agent,
718
+ [stored_messages, chatbot, session_state, uploaded_files_state],
719
  [chatbot],
720
  ).then(
721
  lambda: (
 
732
  # Connect button click event
733
  launch_research_btn.click(
734
  self.log_user_message,
735
+ [text_input],
736
  [stored_messages, text_input, launch_research_btn],
737
  ).then(
738
  self.interact_with_agent,
739
+ [stored_messages, chatbot, session_state, uploaded_files_state],
740
  [chatbot],
741
  ).then(
742
  lambda: (