Spaces:
Runtime error
Runtime error
| """ | |
| Hugging Face Hub tab for Video Model Studio UI. | |
| Handles browsing, searching, and importing datasets from the Hugging Face Hub. | |
| """ | |
| import gradio as gr | |
| import logging | |
| import asyncio | |
| import threading | |
| from pathlib import Path | |
| from typing import Dict, Any, List, Optional, Tuple | |
| from vms.utils import BaseTab | |
| logger = logging.getLogger(__name__) | |
class HubTab(BaseTab):
    """Hub tab for importing datasets from Hugging Face Hub.

    The tab lets the user search datasets, select one to see its details and
    file counts, and download either its video files or its WebDataset (.tar)
    shards through the ``self.app.importing`` service.
    """

    def __init__(self, app_state):
        """Initialize the tab identifiers and the download flag."""
        super().__init__(app_state)
        self.id = "hub_tab"
        self.title = "Import from Hugging Face"
        self.is_downloading = False

    def create(self, parent=None) -> gr.Tab:
        """Create the Hub tab UI components.

        Every component that event handlers need is registered in
        ``self.components`` under a string key.

        Args:
            parent: Unused by this method.

        Returns:
            The ``gr.Tab`` containing the whole Hub UI.
        """
        # NOTE(review): the container nesting below was reconstructed from a
        # copy of this file whose indentation had been stripped -- confirm the
        # layout against the rendered UI.
        with gr.Tab(self.title, id=self.id) as tab:
            with gr.Column():
                with gr.Row():
                    gr.Markdown("## Import a dataset from Hugging Face")

                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            gr.Markdown("You can use any dataset containing video files (.mp4) with optional captions (same names but in .txt format)")
                        with gr.Row():
                            gr.Markdown("You can also use a dataset containing WebDataset shards (.tar files).")

                    with gr.Column():
                        self.components["dataset_search"] = gr.Textbox(
                            label="Search Hugging Face Datasets (MP4, WebDataset)",
                            placeholder="video datasets eg. cakeify, disney, rickroll.."
                        )
                        with gr.Row():
                            self.components["dataset_search_btn"] = gr.Button(
                                "Search Datasets",
                                variant="primary",
                            )

                # Dataset browser results section (hidden until a search runs)
                with gr.Row(visible=False) as dataset_results_row:
                    self.components["dataset_results_row"] = dataset_results_row

                    with gr.Column(scale=3):
                        self.components["dataset_results"] = gr.Dataframe(
                            headers=["Dataset ID"],  # Simplified to show only dataset ID
                            interactive=False,
                            wrap=True,
                            row_count=10,
                            label="Dataset Results"
                        )

                    with gr.Column(scale=3):
                        # Dataset info and per-session state
                        self.components["dataset_info"] = gr.Markdown("Select a dataset to see details")
                        self.components["dataset_id"] = gr.State(value=None)
                        self.components["file_type"] = gr.State(value=None)
                        self.components["download_in_progress"] = gr.State(value=False)

                        # Files section that appears when a dataset is selected
                        with gr.Column(visible=False) as files_section:
                            self.components["files_section"] = files_section

                            # Video files row (appears if videos are present)
                            with gr.Row() as video_files_row:
                                self.components["video_files_row"] = video_files_row
                                self.components["video_count_text"] = gr.Markdown("Contains 0 video files")
                                self.components["download_videos_btn"] = gr.Button("Download", variant="primary")

                            # WebDataset files row (appears if tar files are present)
                            with gr.Row() as webdataset_files_row:
                                self.components["webdataset_files_row"] = webdataset_files_row
                                self.components["webdataset_count_text"] = gr.Markdown("Contains 0 WebDataset (.tar) files")
                                self.components["download_webdataset_btn"] = gr.Button("Download", variant="primary")

                            # Status indicator
                            self.components["status_output"] = gr.Markdown("")

        return tab

    def connect_events(self) -> None:
        """Connect event handlers to UI components."""
        components = self.components

        # Dataset search event
        components["dataset_search_btn"].click(
            fn=self.search_datasets,
            inputs=[components["dataset_search"]],
            outputs=[
                components["dataset_results"],
                components["dataset_results_row"]
            ]
        )

        # Dataset selection event
        components["dataset_results"].select(
            fn=self.display_dataset_info,
            outputs=[
                components["dataset_info"],
                components["dataset_id"],
                components["files_section"],
                components["video_files_row"],
                components["video_count_text"],
                components["webdataset_files_row"],
                components["webdataset_count_text"],
                components["status_output"]  # Reset status output
            ]
        )

        # The tab-switching step needs the project tabs component as an
        # output. A None entry in an outputs list is not a component, so when
        # it is missing the follow-up step is skipped instead of wired to None.
        tabs_component = getattr(self.app, "project_tabs_component", None)
        if tabs_component is None:
            logger.warning("project_tabs_component not found in app, skipping tab switching after import")

        self._connect_download_button(
            "download_videos_btn", self.set_file_type_and_return, tabs_component
        )
        self._connect_download_button(
            "download_webdataset_btn", self.set_file_type_and_return_webdataset, tabs_component
        )

    def _connect_download_button(self, button_key: str, set_file_type, tabs_component) -> None:
        """Wire one download button: set file type, download, then post-import step.

        Args:
            button_key: Key in ``self.components`` of the button to wire.
            set_file_type: Zero-input handler whose return value is stored in
                the ``file_type`` state before the download starts.
            tabs_component: Component updated by the import tab's
                ``on_import_success``; when None the follow-up is not wired.
        """
        components = self.components

        # NOTE(review): "enable_automatic_video_split", "import_status" and
        # "enable_automatic_content_captioning" are not registered by create()
        # in this class -- presumably they are provided elsewhere; confirm.
        download_event = components[button_key].click(
            fn=set_file_type,
            outputs=[components["file_type"]]
        ).then(
            fn=self.download_file_group,
            inputs=[
                components["dataset_id"],
                components["enable_automatic_video_split"],
                components["file_type"]
            ],
            outputs=[
                components["status_output"],
                components["import_status"],
                components["download_videos_btn"],
                components["download_webdataset_btn"],
                components["download_in_progress"]
            ]
        )

        if tabs_component is None:
            return

        # NOTE(review): download_file_group reports failures through its
        # return value rather than by raising, so this step also runs after a
        # failed download -- confirm that is intended.
        download_event.success(
            fn=self.app.tabs["import_tab"].on_import_success,
            inputs=[
                components["enable_automatic_video_split"],
                components["enable_automatic_content_captioning"],
                self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
            ],
            outputs=[
                tabs_component,
                components["status_output"]
            ]
        )

    def set_file_type_and_return(self):
        """Return the file type value used for video downloads."""
        return "video"

    def set_file_type_and_return_webdataset(self):
        """Return the file type value used for WebDataset downloads."""
        return "webdataset"

    def search_datasets(self, query: str):
        """Search datasets on the Hub matching the query.

        Args:
            query: Free-text search string from the search box.

        Returns:
            A pair ``(rows, update)``: ``rows`` is a list of single-cell rows
            holding the first column of each result from the importing
            service (or one row with the error text on failure), and
            ``update`` makes the results row visible.
        """
        try:
            logger.info(f"Searching for datasets with query: '{query}'")
            results_full = self.app.importing.search_datasets(query)
            # Keep only the first column (dataset IDs) for display
            results = [[row[0]] for row in results_full]
            return results, gr.update(visible=True)
        except Exception as e:
            logger.error(f"Error searching datasets: {str(e)}", exc_info=True)
            return [[f"Error: {str(e)}"]], gr.update(visible=True)

    def _hidden_dataset_outputs(self, info_text: str) -> Tuple:
        """Build the 8 selection outputs that show only a message and hide the files UI."""
        return (
            info_text,                 # dataset_info
            None,                      # dataset_id
            gr.update(visible=False),  # files_section
            gr.update(visible=False),  # video_files_row
            "",                        # video_count_text
            gr.update(visible=False),  # webdataset_files_row
            "",                        # webdataset_count_text
            ""                         # status_output
        )

    def display_dataset_info(self, evt: gr.SelectData):
        """Display detailed information about the selected dataset.

        Args:
            evt: Selection event from the results table; ``evt.value`` holds
                the selected cell (a list's first item is used if it is a list).

        Returns:
            An 8-tuple matching the outputs wired in ``connect_events``:
            dataset_info, dataset_id, files_section, video_files_row,
            video_count_text, webdataset_files_row, webdataset_count_text,
            status_output. On a missing selection or on error the files
            section is hidden and only a message is shown.
        """
        try:
            if not evt or not evt.value:
                logger.warning("No dataset selected in display_dataset_info")
                return self._hidden_dataset_outputs("No dataset selected")

            # Extract dataset_id from the simplified (single column) format
            dataset_id = evt.value[0] if isinstance(evt.value, list) else evt.value
            logger.info(f"Getting dataset info for: {dataset_id}")

            # Use the importer service to get dataset info
            info_text, file_counts, _ = self.app.importing.get_dataset_info(dataset_id)

            # Counts of each file type; a missing key counts as zero
            video_count = file_counts.get("video", 0)
            webdataset_count = file_counts.get("webdataset", 0)

            return (
                info_text,                                  # dataset_info
                dataset_id,                                 # dataset_id
                gr.update(visible=True),                    # files_section
                gr.update(visible=video_count > 0),         # video_files_row
                f"Contains {video_count} video file{'s' if video_count != 1 else ''}",  # video_count_text
                gr.update(visible=webdataset_count > 0),    # webdataset_files_row
                f"Contains {webdataset_count} WebDataset (.tar) file{'s' if webdataset_count != 1 else ''}",  # webdataset_count_text
                ""                                          # status_output
            )
        except Exception as e:
            logger.error(f"Error displaying dataset info: {str(e)}", exc_info=True)
            return self._hidden_dataset_outputs(f"Error loading dataset information: {str(e)}")

    async def _download_with_progress(self, dataset_id, file_type, enable_splitting, progress_callback):
        """Run the importing service's download with an adapted progress callback.

        Args:
            dataset_id: Identifier of the dataset to download from.
            file_type: File group to download (the UI passes "video" or
                "webdataset").
            enable_splitting: Forwarded unchanged to the importing service.
            progress_callback: Callable invoked as
                ``progress_callback(fraction, desc=...)``.

        Returns:
            Whatever the importing service returns, or the string
            ``"Error: <message>"`` if it raised.
        """
        def progress_adapter(progress_value, desc=None, total=None):
            # Only numeric values can be turned into a fraction; ignore others
            if not isinstance(progress_value, (int, float)):
                return
            if total is not None and total > 0:
                # With a total, report the completed share
                fraction = progress_value / total
            else:
                # Otherwise the value is assumed to already be in the 0-1 range
                fraction = progress_value
            # Clamp so an out-of-range report never leaves the 0-1 range
            progress_callback(max(0.0, min(1.0, fraction)), desc=desc)

        try:
            return await self.app.importing.download_file_group(
                dataset_id,
                file_type,
                enable_splitting,
                progress_callback=progress_adapter
            )
        except Exception as e:
            logger.error(f"Error in download with progress: {str(e)}", exc_info=True)
            return f"Error: {str(e)}"

    def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str, progress=gr.Progress()) -> Tuple:
        """Handle download of a group of files (videos or WebDatasets) with progress tracking.

        Args:
            dataset_id: Identifier of the selected dataset; falsy means none.
            enable_splitting: Forwarded to the importing service.
            file_type: File group to download.
            progress: Progress tracker called as ``progress(fraction, desc=...)``.

        Returns:
            A 5-tuple matching the outputs wired in ``connect_events``:
            status_output, import_status, download_videos_btn update,
            download_webdataset_btn update, download_in_progress.
        """
        if not dataset_id:
            return (
                "No dataset selected",
                "No dataset selected",
                gr.update(),
                gr.update(),
                False
            )

        try:
            logger.info(f"Starting download of {file_type} files from dataset: {dataset_id}")

            # Initialize progress tracking
            progress(0, desc=f"Starting download of {file_type} files from {dataset_id}")

            # This handler is synchronous, so the coroutine is driven to
            # completion with asyncio.run
            result = asyncio.run(self._download_with_progress(
                dataset_id,
                file_type,
                enable_splitting,
                progress
            ))

            # _download_with_progress reports failures as an "Error: ..."
            # string instead of raising; do not present those as a success.
            if isinstance(result, str) and result.startswith("Error:"):
                return (
                    f"❌ {result}",                  # status_output
                    result,                         # import_status
                    gr.update(interactive=True),    # download_videos_btn
                    gr.update(interactive=True),    # download_webdataset_btn
                    False                           # download_in_progress
                )

            progress(1.0, desc="Download complete!")

            return (
                f"✅ Download complete! {result}",  # status_output
                result,                             # import_status
                gr.update(interactive=True),        # download_videos_btn
                gr.update(interactive=True),        # download_webdataset_btn
                False                               # download_in_progress
            )
        except Exception as e:
            error_msg = f"Error downloading {file_type} files: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return (
                f"❌ Error: {error_msg}",        # status_output
                error_msg,                      # import_status
                gr.update(interactive=True),    # download_videos_btn
                gr.update(interactive=True),    # download_webdataset_btn
                False                           # download_in_progress
            )