import platform
import subprocess

# Platform-specific setup hook (the Linux branch is currently a no-op)
if platform.system() == "Linux":
    pass
|
| import gradio as gr |
| from pathlib import Path |
| import logging |
| import mimetypes |
| import shutil |
| import os |
| import traceback |
| import asyncio |
| import tempfile |
| import zipfile |
from typing import Any, Optional, Dict, List, Union, Tuple, AsyncGenerator
|
|
| from vms.training_service import TrainingService |
| from vms.captioning_service import CaptioningService |
| from vms.splitting_service import SplittingService |
| from vms.import_service import ImportService |
from vms.config import (
    STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_PATH,
    TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, LOG_FILE_PATH,
    TRAINING_PRESETS, MODEL_TYPES, SMALL_TRAINING_BUCKETS,
    DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX,
    HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE
)
from vms.utils import (
    make_archive, count_media_files, format_media_title,
    is_image_file, is_video_file, validate_model_repo, format_time
)
| from vms.finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset |
| from vms.training_log_parser import TrainingLogParser |
|
|
| logger = logging.getLogger(__name__) |
| logger.setLevel(logging.INFO) |
|
|
| httpx_logger = logging.getLogger('httpx') |
httpx_logger.setLevel(logging.WARNING)
|
|
|
|
| class VideoTrainerUI: |
| def __init__(self): |
| self.trainer = TrainingService() |
| self.splitter = SplittingService() |
| self.importer = ImportService() |
| self.captioner = CaptioningService() |
| self._should_stop_captioning = False |
| self.log_parser = TrainingLogParser() |
| |
| |
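        # Check whether an interrupted training session can be recovered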
| recovery_result = self.trainer.recover_interrupted_training() |
| |
| self.recovery_status = recovery_result.get("status", "unknown") |
| self.ui_updates = recovery_result.get("ui_updates", {}) |
| |
| if recovery_result["status"] == "recovered": |
| logger.info(f"Training recovery: {recovery_result['message']}") |
| |
| elif recovery_result["status"] == "running": |
| logger.info("Training process is already running") |
| |
| elif recovery_result["status"] in ["error", "idle"]: |
| logger.warning(f"Training status: {recovery_result['message']}") |
| |
| |
|
|
| def initialize_app_state(self): |
| """Initialize all app state in one function to ensure correct output count""" |
| |
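        # Load the current listings for the Split and Caption tabs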
| video_list, training_dataset = self.refresh_dataset() |
| |
| |
| button_states = self.get_initial_button_states() |
| start_btn = button_states[0] |
| stop_btn = button_states[1] |
| pause_resume_btn = button_states[2] |
| |
| |
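        # Restore saved form values, falling back to defaults when missing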
| ui_state = self.load_ui_values() |
| training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]) |
| model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0]) |
| lora_rank_val = ui_state.get("lora_rank", "128") |
| lora_alpha_val = ui_state.get("lora_alpha", "128") |
| num_epochs_val = int(ui_state.get("num_epochs", 70)) |
| batch_size_val = int(ui_state.get("batch_size", 1)) |
| learning_rate_val = float(ui_state.get("learning_rate", 3e-5)) |
| save_iterations_val = int(ui_state.get("save_iterations", 500)) |
| |
| |
| return ( |
| video_list, |
| training_dataset, |
| start_btn, |
| stop_btn, |
| pause_resume_btn, |
| training_preset, |
| model_type_val, |
| lora_rank_val, |
| lora_alpha_val, |
| num_epochs_val, |
| batch_size_val, |
| learning_rate_val, |
| save_iterations_val |
| ) |
|
|
| def initialize_ui_from_state(self): |
| """Initialize UI components from saved state""" |
| ui_state = self.load_ui_values() |
| |
| |
| return ( |
| ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]), |
| ui_state.get("model_type", list(MODEL_TYPES.keys())[0]), |
| ui_state.get("lora_rank", "128"), |
| ui_state.get("lora_alpha", "128"), |
| ui_state.get("num_epochs", 70), |
| ui_state.get("batch_size", 1), |
| ui_state.get("learning_rate", 3e-5), |
| ui_state.get("save_iterations", 500) |
| ) |
|
|
| def update_ui_state(self, **kwargs): |
| """Update UI state with new values""" |
| current_state = self.trainer.load_ui_state() |
| current_state.update(kwargs) |
| self.trainer.save_ui_state(current_state) |
| |
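        # Nothing to return: the .change() handlers are wired with outputs=[]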
| return None |
|
|
| def load_ui_values(self): |
| """Load UI state values for initializing form fields""" |
| ui_state = self.trainer.load_ui_state() |
| |
| |
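        # Coerce numeric fields, which may have been persisted as strings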
| ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70)) |
| ui_state["batch_size"] = int(ui_state.get("batch_size", 1)) |
| ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5)) |
| ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500)) |
| |
| return ui_state |
| |
| def update_captioning_buttons_start(self): |
| """Return individual button values instead of a dictionary""" |
| return ( |
| gr.Button( |
| interactive=False, |
| variant="secondary", |
| ), |
| gr.Button( |
| interactive=True, |
| variant="stop", |
| ), |
| gr.Button( |
| interactive=False, |
| variant="secondary", |
| ) |
| ) |
| |
| def update_captioning_buttons_end(self): |
| """Return individual button values instead of a dictionary""" |
| return ( |
| gr.Button( |
| interactive=True, |
| variant="primary", |
| ), |
| gr.Button( |
| interactive=False, |
| variant="secondary", |
| ), |
| gr.Button( |
| interactive=True, |
| variant="primary", |
| ) |
| ) |
|
|
| |
| def get_initial_button_states(self): |
| """Get the initial states for training buttons based on recovery status""" |
| recovery_result = self.trainer.recover_interrupted_training() |
| ui_updates = recovery_result.get("ui_updates", {}) |
| |
| |
| return ( |
| gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})), |
| gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})), |
| gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"})) |
| ) |
|
|
| def show_refreshing_status(self) -> List[List[str]]: |
| """Show a 'Refreshing...' status in the dataframe""" |
| return [["Refreshing...", "please wait"]] |
|
|
| def stop_captioning(self): |
| """Stop ongoing captioning process and reset UI state""" |
| try: |
| |
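            # Signal any in-flight captioning loop to stop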
| self._should_stop_captioning = True |
| |
| |
| if self.captioner: |
| self.captioner.stop_captioning() |
| |
| |
| updated_list = self.list_training_files_to_caption() |
| |
| |
            # Return values in the same order as the click handler's outputs:
            # (training_dataset, run_autocaption_btn, stop_autocaption_btn,
            #  copy_files_to_training_dir_btn)
            return (
                gr.update(value=updated_list),
                gr.Button(interactive=True, variant="primary"),
                gr.Button(interactive=False, variant="secondary"),
                gr.Button(interactive=True, variant="primary"),
            )
        except Exception as e:
            logger.error(f"Error stopping captioning: {str(e)}")
            return (
                gr.update(value=[[f"Error stopping captioning: {str(e)}", "error"]]),
                gr.Button(interactive=True, variant="primary"),
                gr.Button(interactive=False, variant="secondary"),
                gr.Button(interactive=True, variant="primary"),
            )
|
|
| def update_training_ui(self, training_state: Dict[str, Any]): |
| """Update UI components based on training state""" |
| updates = {} |
| |
| |
|
|
| |
| status_text = [] |
| if training_state["status"] != "idle": |
| status_text.extend([ |
| f"Status: {training_state['status']}", |
| f"Progress: {training_state['progress']}", |
| f"Step: {training_state['current_step']}/{training_state['total_steps']}", |
| f"Time elapsed: {training_state['elapsed']}", |
| f"Estimated remaining: {training_state['remaining']}", |
| "", |
| f"Current loss: {training_state['step_loss']}", |
| f"Learning rate: {training_state['learning_rate']}", |
| f"Gradient norm: {training_state['grad_norm']}", |
| f"Memory usage: {training_state['memory']}" |
| ]) |
| |
| if training_state["error_message"]: |
| status_text.append(f"\nError: {training_state['error_message']}") |
| |
| updates["status_box"] = "\n".join(status_text) |
| |
| |
| updates["start_btn"] = gr.Button( |
| "Start training", |
| interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]), |
| variant="primary" if training_state["status"] == "idle" else "secondary" |
| ) |
| |
| updates["stop_btn"] = gr.Button( |
| "Stop training", |
| interactive=(training_state["status"] in ["training", "initializing"]), |
| variant="stop" |
| ) |
| |
| return updates |
| |
    def stop_all_and_clear(self) -> Dict[str, Any]:
| """Stop all running processes and clear data |
| |
| Returns: |
| Dict with status messages for different components |
| """ |
| status_messages = {} |
| |
| try: |
| |
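            # Stop the training process if one is running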
| if self.trainer.is_training_running(): |
| training_result = self.trainer.stop_training() |
| status_messages["training"] = training_result["status"] |
| |
| |
| if self.captioner: |
| self.captioner.stop_captioning() |
| status_messages["captioning"] = "Captioning stopped" |
| |
| |
| if self.splitter.is_processing(): |
| self.splitter.processing = False |
| status_messages["splitting"] = "Scene detection stopped" |
| |
| |
| if self.trainer.file_handler: |
| self.trainer.file_handler.close() |
| logger.removeHandler(self.trainer.file_handler) |
| self.trainer.file_handler = None |
| |
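            # Remove the old log file so a fresh one is created on restart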
| if LOG_FILE_PATH.exists(): |
| LOG_FILE_PATH.unlink() |
| |
| |
            # Clear every data directory, recreating each one empty
            for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH,
                         TRAINING_PATH, MODEL_PATH, OUTPUT_PATH]:
                if path.exists():
                    try:
                        shutil.rmtree(path)
                        path.mkdir(parents=True, exist_ok=True)
                    except Exception as e:
                        status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}"
                    else:
                        status_messages[f"clear_{path.name}"] = f"Cleared {path.name}"
| |
| |
| self._should_stop_captioning = True |
| self.splitter.processing = False |
| |
| |
| self.trainer.setup_logging() |
| |
| return { |
| "status": "All processes stopped and data cleared", |
| "details": status_messages |
| } |
| |
| except Exception as e: |
| return { |
| "status": f"Error during cleanup: {str(e)}", |
| "details": status_messages |
| } |
| |
    def update_titles(self) -> Tuple[Any, Any, Any]:
        """Update all dynamic titles with current counts

        Returns:
            Tuple of Gradio Markdown updates for the split, caption and train titles
        """
| |
| split_videos, _, split_size = count_media_files(VIDEOS_TO_SPLIT_PATH) |
| split_title = format_media_title( |
| "split", split_videos, 0, split_size |
| ) |
| |
| |
| caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH) |
| caption_title = format_media_title( |
| "caption", caption_videos, caption_images, caption_size |
| ) |
| |
| |
| train_videos, train_images, train_size = count_media_files(TRAINING_VIDEOS_PATH) |
| train_title = format_media_title( |
| "train", train_videos, train_images, train_size |
| ) |
| |
| return ( |
| gr.Markdown(value=split_title), |
| gr.Markdown(value=caption_title), |
| gr.Markdown(value=f"{train_title} available for training") |
| ) |
|
|
    def copy_files_to_training_dir(self, prompt_prefix: str):
        """Copy the staged, captioned assets into the training directory"""
|
|
| |
| self._should_stop_captioning = False |
|
|
| try: |
| copy_files_to_training_dir(prompt_prefix) |
|
|
| except Exception as e: |
| traceback.print_exc() |
| raise gr.Error(f"Error copying assets to training dir: {str(e)}") |
|
|
    async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[Any, None]:
| """Run auto-captioning process""" |
| try: |
| |
| self._should_stop_captioning = False |
|
|
| |
| yield gr.update( |
| value=[["Starting captioning service...", "initializing"]], |
| headers=["name", "status"] |
| ) |
|
|
| |
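            # Track the latest status of each file across generator updates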
| file_statuses = {} |
| |
| |
| async for rows in self.captioner.start_caption_generation(captioning_bot_instructions, prompt_prefix): |
| |
| for name, status in rows: |
| file_statuses[name] = status |
| |
| |
| status_rows = [[name, status] for name, status in file_statuses.items()] |
| |
| |
| status_rows.sort(key=lambda x: x[0]) |
| |
| |
| yield gr.update( |
| value=status_rows, |
| headers=["name", "status"] |
| ) |
|
|
| |
| yield gr.update( |
| value=self.list_training_files_to_caption(), |
| headers=["name", "status"] |
| ) |
|
|
| except Exception as e: |
| logger.error(f"Error in captioning: {str(e)}") |
| yield gr.update( |
| value=[[f"Error: {str(e)}", "error"]], |
| headers=["name", "status"] |
| ) |
|
|
| def list_training_files_to_caption(self) -> List[List[str]]: |
| """List all clips and images - both pending and captioned""" |
| files = [] |
| already_listed = {} |
|
|
| |
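        # First list files in the staging directory, captioned or not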
| for file in STAGING_PATH.glob("*.*"): |
| if is_video_file(file) or is_image_file(file): |
| txt_file = file.with_suffix('.txt') |
| |
| |
| has_caption = txt_file.exists() and txt_file.stat().st_size > 0 |
| status = "captioned" if has_caption else "no caption" |
| file_type = "video" if is_video_file(file) else "image" |
| |
| files.append([file.name, f"{status} ({file_type})", str(file)]) |
| already_listed[file.name] = True |
| |
| |
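        # Then add captioned files from the training directory that were not already listed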
| for file in TRAINING_VIDEOS_PATH.glob("*.*"): |
| if (is_video_file(file) or is_image_file(file)) and file.name not in already_listed: |
| txt_file = file.with_suffix('.txt') |
| |
| |
| if txt_file.exists() and txt_file.stat().st_size > 0: |
| file_type = "video" if is_video_file(file) else "image" |
| files.append([file.name, f"captioned ({file_type})", str(file)]) |
| already_listed[file.name] = True |
| |
| |
| files.sort(key=lambda x: x[0]) |
| |
| |
| return [[file[0], file[1]] for file in files] |
| |
| def update_training_buttons(self, status: str) -> Dict: |
| """Update training control buttons based on state""" |
| is_training = status in ["training", "initializing"] |
| is_paused = status == "paused" |
| is_completed = status in ["completed", "error", "stopped"] |
| return { |
| "start_btn": gr.Button( |
| interactive=not is_training and not is_paused, |
| variant="primary" if not is_training else "secondary", |
| ), |
| "stop_btn": gr.Button( |
| interactive=is_training or is_paused, |
| variant="stop", |
| ), |
| "pause_resume_btn": gr.Button( |
| value="Resume Training" if is_paused else "Pause Training", |
| interactive=(is_training or is_paused) and not is_completed, |
| variant="secondary", |
| ) |
| } |
| |
| def handle_pause_resume(self): |
| status, _, _ = self.get_latest_status_message_and_logs() |
|
|
| if status == "paused": |
| self.trainer.resume_training() |
| else: |
| self.trainer.pause_training() |
|
|
| return self.get_latest_status_message_logs_and_button_labels() |
|
|
| def handle_stop(self): |
| self.trainer.stop_training() |
| return self.get_latest_status_message_logs_and_button_labels() |
|
|
    def handle_training_dataset_select(self, evt: gr.SelectData) -> List[Any]:
| """Handle selection of both video clips and images""" |
| try: |
| if not evt: |
| return [ |
| gr.Image( |
| interactive=False, |
| visible=False |
| ), |
| gr.Video( |
| interactive=False, |
| visible=False |
| ), |
| gr.Textbox( |
| visible=False |
| ), |
| None, |
| "No file selected" |
| ] |
| |
| file_name = evt.value |
| if not file_name: |
| return [ |
| gr.Image( |
| interactive=False, |
| visible=False |
| ), |
| gr.Video( |
| interactive=False, |
| visible=False |
| ), |
| gr.Textbox( |
| visible=False |
| ), |
| None, |
| "No file selected" |
| ] |
| |
| |
            # The dataset listing mixes files from the staging and training
            # directories, so check both locations for the selected name
            possible_paths = [
                STAGING_PATH / file_name,
                TRAINING_VIDEOS_PATH / file_name
            ]
| |
| |
| file_path = None |
| for path in possible_paths: |
| if path.exists(): |
| file_path = path |
| break |
| |
| if not file_path: |
| return [ |
| gr.Image( |
| interactive=False, |
| visible=False |
| ), |
| gr.Video( |
| interactive=False, |
| visible=False |
| ), |
| gr.Textbox( |
| visible=False |
| ), |
| None, |
| f"File not found: {file_name}" |
| ] |
| |
| txt_path = file_path.with_suffix('.txt') |
| caption = txt_path.read_text() if txt_path.exists() else "" |
| |
| |
| if is_video_file(file_path): |
| return [ |
| gr.Image( |
| interactive=False, |
| visible=False |
| ), |
| gr.Video( |
| label="Video Preview", |
| interactive=False, |
| visible=True, |
| value=str(file_path) |
| ), |
| gr.Textbox( |
| label="Caption", |
| lines=6, |
| interactive=True, |
| visible=True, |
| value=str(caption) |
| ), |
| str(file_path), |
| None |
| ] |
| |
| elif is_image_file(file_path): |
| return [ |
| gr.Image( |
| label="Image Preview", |
| interactive=False, |
| visible=True, |
| value=str(file_path) |
| ), |
| gr.Video( |
| interactive=False, |
| visible=False |
| ), |
| gr.Textbox( |
| label="Caption", |
| lines=6, |
| interactive=True, |
| visible=True, |
| value=str(caption) |
| ), |
| str(file_path), |
| None |
| ] |
| else: |
| return [ |
| gr.Image( |
| interactive=False, |
| visible=False |
| ), |
| gr.Video( |
| interactive=False, |
| visible=False |
| ), |
| gr.Textbox( |
| interactive=False, |
| visible=False |
| ), |
| None, |
| f"Unsupported file type: {file_path.suffix}" |
| ] |
| except Exception as e: |
| logger.error(f"Error handling selection: {str(e)}") |
| return [ |
| gr.Image( |
| interactive=False, |
| visible=False |
| ), |
| gr.Video( |
| interactive=False, |
| visible=False |
| ), |
| gr.Textbox( |
| interactive=False, |
| visible=False |
| ), |
| None, |
| f"Error handling selection: {str(e)}" |
| ] |
|
|
| def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, original_file_path: str, prompt_prefix: str): |
| """Save changes to caption""" |
| try: |
| |
| if original_file_path: |
| file_path = Path(original_file_path) |
| self.captioner.update_file_caption(file_path, preview_caption) |
| |
| return gr.update(value="Caption saved successfully!") |
| else: |
| return gr.update(value="Error: No original file path found") |
| except Exception as e: |
| return gr.update(value=f"Error saving caption: {str(e)}") |
|
|
| def get_model_info(self, model_type: str) -> str: |
| """Get information about the selected model type""" |
| if model_type == "hunyuan_video": |
| return """### HunyuanVideo (LoRA) |
| - Required VRAM: ~48GB minimum |
| - Recommended batch size: 1-2 |
| - Typical training time: 2-4 hours |
| - Default resolution: 49x512x768 |
| - Default LoRA rank: 128 (~600 MB)""" |
| |
| elif model_type == "ltx_video": |
| return """### LTX-Video (LoRA) |
| - Required VRAM: ~18GB minimum |
| - Recommended batch size: 1-4 |
| - Typical training time: 1-3 hours |
| - Default resolution: 49x512x768 |
| - Default LoRA rank: 128""" |
| |
| return "" |
|
|
| def get_default_params(self, model_type: str) -> Dict[str, Any]: |
| """Get default training parameters for model type""" |
| if model_type == "hunyuan_video": |
| return { |
| "num_epochs": 70, |
| "batch_size": 1, |
| "learning_rate": 2e-5, |
| "save_iterations": 500, |
| "video_resolution_buckets": SMALL_TRAINING_BUCKETS, |
| "video_reshape_mode": "center", |
| "caption_dropout_p": 0.05, |
| "gradient_accumulation_steps": 1, |
| "rank": 128, |
| "lora_alpha": 128 |
| } |
| else: |
| return { |
| "num_epochs": 70, |
| "batch_size": 1, |
| "learning_rate": 3e-5, |
| "save_iterations": 500, |
| "video_resolution_buckets": SMALL_TRAINING_BUCKETS, |
| "video_reshape_mode": "center", |
| "caption_dropout_p": 0.05, |
| "gradient_accumulation_steps": 4, |
| "rank": 128, |
| "lora_alpha": 128 |
| } |
|
|
| def preview_file(self, selected_text: str) -> Dict: |
| """Generate preview based on selected file |
| |
| Args: |
| selected_text: Text of the selected item containing filename |
| |
| Returns: |
| Dict with preview content for each preview component |
| """ |
| if not selected_text or "Caption:" in selected_text: |
| return { |
| "video": None, |
| "image": None, |
| "text": None |
| } |
| |
| |
| filename = selected_text.split(" (")[0].strip() |
| file_path = TRAINING_VIDEOS_PATH / filename |
| |
| if not file_path.exists(): |
| return { |
| "video": None, |
| "image": None, |
| "text": f"File not found: {filename}" |
| } |
|
|
| |
| mime_type, _ = mimetypes.guess_type(str(file_path)) |
| if not mime_type: |
| return { |
| "video": None, |
| "image": None, |
| "text": f"Unknown file type: {filename}" |
| } |
|
|
| |
| if mime_type.startswith('video/'): |
| return { |
| "video": str(file_path), |
| "image": None, |
| "text": None |
| } |
| elif mime_type.startswith('image/'): |
| return { |
| "video": None, |
| "image": str(file_path), |
| "text": None |
| } |
| elif mime_type.startswith('text/'): |
| try: |
| text_content = file_path.read_text() |
| return { |
| "video": None, |
| "image": None, |
| "text": text_content |
| } |
| except Exception as e: |
| return { |
| "video": None, |
| "image": None, |
| "text": f"Error reading file: {str(e)}" |
| } |
| else: |
| return { |
| "video": None, |
| "image": None, |
| "text": f"Unsupported file type: {mime_type}" |
| } |
|
|
| def list_unprocessed_videos(self) -> gr.Dataframe: |
| """Update list of unprocessed videos""" |
| videos = self.splitter.list_unprocessed_videos() |
| |
| return gr.Dataframe( |
| headers=["name", "status"], |
| value=videos, |
| interactive=False |
| ) |
|
|
| async def start_scene_detection(self, enable_splitting: bool) -> str: |
| """Start background scene detection process |
| |
| Args: |
| enable_splitting: Whether to split videos into scenes |
| """ |
| if self.splitter.is_processing(): |
| return "Scene detection already running" |
| |
| try: |
| await self.splitter.start_processing(enable_splitting) |
| return "Scene detection completed" |
| except Exception as e: |
| return f"Error during scene detection: {str(e)}" |
|
|
|
|
| def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]: |
| state = self.trainer.get_status() |
| logs = self.trainer.get_logs() |
|
|
| |
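        # Parse the raw logs to build a richer status message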
| if logs: |
| last_state = None |
| for line in logs.splitlines(): |
| state_update = self.log_parser.parse_line(line) |
| if state_update: |
| last_state = state_update |
| |
| if last_state: |
| ui_updates = self.update_training_ui(last_state) |
| state["message"] = ui_updates.get("status_box", state["message"]) |
| |
| |
| if "completed" in state["message"].lower(): |
| state["status"] = "completed" |
|
|
| return (state["status"], state["message"], logs) |
|
|
| def get_latest_status_message_logs_and_button_labels(self) -> Tuple[str, str, Any, Any, Any]: |
| status, message, logs = self.get_latest_status_message_and_logs() |
| return ( |
| message, |
| logs, |
| *self.update_training_buttons(status).values() |
| ) |
|
|
    def get_latest_button_labels(self) -> Tuple[Any, Any, Any]:
        status, message, logs = self.get_latest_status_message_and_logs()
        # Convert dict_values to a tuple so Gradio can map it onto the outputs
        return tuple(self.update_training_buttons(status).values())
| |
| def refresh_dataset(self): |
| """Refresh all dynamic lists and training state""" |
| video_list = self.splitter.list_unprocessed_videos() |
| training_dataset = self.list_training_files_to_caption() |
|
|
| return ( |
| video_list, |
| training_dataset |
| ) |
|
|
| def update_training_params(self, preset_name: str) -> Tuple: |
| """Update UI components based on selected preset""" |
| preset = TRAINING_PRESETS[preset_name] |
| |
| |
| model_display_name = next( |
| key for key, value in MODEL_TYPES.items() |
| if value == preset["model_type"] |
| ) |
| |
| |
| description = preset.get("description", "") |
| |
| |
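        # Derive the maximum video size supported by the preset's buckets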
| buckets = preset["training_buckets"] |
| max_frames = max(frames for frames, _, _ in buckets) |
| max_height = max(height for _, height, _ in buckets) |
| max_width = max(width for _, _, width in buckets) |
| bucket_info = f"\nMaximum video size: {max_frames} frames at {max_width}x{max_height} resolution" |
| |
| info_text = f"{description}{bucket_info}" |
| |
| |
| return ( |
| model_display_name, |
| preset["lora_rank"], |
| preset["lora_alpha"], |
| preset["num_epochs"], |
| preset["batch_size"], |
| preset["learning_rate"], |
| preset["save_iterations"], |
| info_text |
| ) |
|
|
| def create_ui(self): |
| """Create Gradio interface""" |
|
|
| with gr.Blocks(title="🎥 Video Model Studio") as app: |
| gr.Markdown("# 🎥 Video Model Studio") |
|
|
| with gr.Tabs() as tabs: |
| with gr.TabItem("1️⃣ Import", id="import_tab"): |
|
|
| with gr.Row(): |
| gr.Markdown("## Automatic splitting and captioning") |
| |
| with gr.Row(): |
| enable_automatic_video_split = gr.Checkbox( |
| label="Automatically split videos into smaller clips", |
| info="Note: a clip is a single camera shot, usually a few seconds", |
| value=True, |
| visible=True |
| ) |
| enable_automatic_content_captioning = gr.Checkbox( |
| label="Automatically caption photos and videos", |
| info="Note: this uses LlaVA and takes some extra time to load and process", |
| value=False, |
| visible=True, |
| ) |
| |
| with gr.Row(): |
| with gr.Column(scale=3): |
| with gr.Row(): |
| with gr.Column(): |
| gr.Markdown("## Import video files") |
| gr.Markdown("You can upload either:") |
| gr.Markdown("- A single MP4 video file") |
| gr.Markdown("- A ZIP archive containing multiple videos and optional caption files") |
| gr.Markdown("For ZIP files: Create a folder containing videos (name is not important) and optional caption files with the same name (eg. `some_video.txt` for `some_video.mp4`)") |
| |
| with gr.Row(): |
| files = gr.Files( |
| label="Upload Images, Videos or ZIP", |
| |
| file_types=[".jpg", ".jpeg", ".png", ".webp", ".webp", ".avif", ".heic", ".mp4", ".zip"], |
| type="filepath" |
| ) |
| |
| with gr.Column(scale=3): |
| with gr.Row(): |
| with gr.Column(): |
| gr.Markdown("## Import a YouTube video") |
| gr.Markdown("You can also use a YouTube video as reference, by pasting its URL here:") |
|
|
| with gr.Row(): |
| youtube_url = gr.Textbox( |
| label="Import YouTube Video", |
| placeholder="https://www.youtube.com/watch?v=..." |
| ) |
| with gr.Row(): |
| youtube_download_btn = gr.Button("Download YouTube Video", variant="secondary") |
| with gr.Row(): |
| import_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
|
| with gr.TabItem("2️⃣ Split", id="split_tab"): |
| with gr.Row(): |
| split_title = gr.Markdown("## Splitting of 0 videos (0 bytes)") |
| |
| with gr.Row(): |
| with gr.Column(): |
| detect_btn = gr.Button("Split videos into single-camera shots", variant="primary") |
| detect_status = gr.Textbox(label="Status", interactive=False) |
|
|
| with gr.Column(): |
|
|
| video_list = gr.Dataframe( |
| headers=["name", "status"], |
| label="Videos to split", |
| interactive=False, |
| wrap=True, |
| |
| ) |
| |
| |
| with gr.TabItem("3️⃣ Caption"): |
| with gr.Row(): |
| caption_title = gr.Markdown("## Captioning of 0 files (0 bytes)") |
| |
| with gr.Row(): |
| |
| with gr.Column(): |
| with gr.Row(): |
| custom_prompt_prefix = gr.Textbox( |
| scale=3, |
                                label='Prefix to add to ALL captions (e.g. "In the style of TOK, ")',
| placeholder="In the style of TOK, ", |
| lines=2, |
| value=DEFAULT_PROMPT_PREFIX |
| ) |
| captioning_bot_instructions = gr.Textbox( |
| scale=6, |
| label="System instructions for the automatic captioning model", |
| placeholder="Please generate a full description of...", |
| lines=5, |
| value=DEFAULT_CAPTIONING_BOT_INSTRUCTIONS |
| ) |
| with gr.Row(): |
| run_autocaption_btn = gr.Button( |
| "Automatically fill missing captions", |
| variant="primary" |
| ) |
| copy_files_to_training_dir_btn = gr.Button( |
| "Copy assets to training directory", |
| variant="primary" |
| ) |
| stop_autocaption_btn = gr.Button( |
| "Stop Captioning", |
| variant="stop", |
| interactive=False |
| ) |
|
|
| with gr.Row(): |
| with gr.Column(): |
| training_dataset = gr.Dataframe( |
| headers=["name", "status"], |
| interactive=False, |
| wrap=True, |
| value=self.list_training_files_to_caption(), |
| row_count=10, |
| |
| ) |
|
|
| with gr.Column(): |
| preview_video = gr.Video( |
| label="Video Preview", |
| interactive=False, |
| visible=False |
| ) |
| preview_image = gr.Image( |
| label="Image Preview", |
| interactive=False, |
| visible=False |
| ) |
| preview_caption = gr.Textbox( |
| label="Caption", |
| lines=6, |
| interactive=True |
| ) |
| save_caption_btn = gr.Button("Save Caption") |
| preview_status = gr.Textbox( |
| label="Status", |
| interactive=False, |
| visible=True |
| ) |
|
|
| with gr.TabItem("4️⃣ Train"): |
| with gr.Row(): |
| with gr.Column(): |
|
|
| with gr.Row(): |
| train_title = gr.Markdown("## 0 files available for training (0 bytes)") |
|
|
| with gr.Row(): |
| with gr.Column(): |
| training_preset = gr.Dropdown( |
| choices=list(TRAINING_PRESETS.keys()), |
| label="Training Preset", |
| value=list(TRAINING_PRESETS.keys())[0] |
| ) |
| preset_info = gr.Markdown() |
|
|
| with gr.Row(): |
| with gr.Column(): |
| model_type = gr.Dropdown( |
| choices=list(MODEL_TYPES.keys()), |
| label="Model Type", |
| value=list(MODEL_TYPES.keys())[0] |
| ) |
| model_info = gr.Markdown( |
| value=self.get_model_info(list(MODEL_TYPES.keys())[0]) |
| ) |
|
|
| with gr.Row(): |
| lora_rank = gr.Dropdown( |
| label="LoRA Rank", |
| choices=["16", "32", "64", "128", "256", "512", "1024"], |
| value="128", |
| type="value" |
| ) |
| lora_alpha = gr.Dropdown( |
| label="LoRA Alpha", |
| choices=["16", "32", "64", "128", "256", "512", "1024"], |
| value="128", |
| type="value" |
| ) |
| with gr.Row(): |
| num_epochs = gr.Number( |
| label="Number of Epochs", |
| value=70, |
| minimum=1, |
| precision=0 |
| ) |
| batch_size = gr.Number( |
| label="Batch Size", |
| value=1, |
| minimum=1, |
| precision=0 |
| ) |
| with gr.Row(): |
| learning_rate = gr.Number( |
| label="Learning Rate", |
| value=2e-5, |
| minimum=1e-7 |
| ) |
| save_iterations = gr.Number( |
| label="Save checkpoint every N iterations", |
| value=500, |
| minimum=50, |
| precision=0, |
| info="Model will be saved periodically after these many steps" |
| ) |
| |
| with gr.Column(): |
| with gr.Row(): |
| start_btn = gr.Button( |
| "Start Training", |
| variant="primary", |
| interactive=not ASK_USER_TO_DUPLICATE_SPACE |
| ) |
| pause_resume_btn = gr.Button( |
| "Resume Training", |
| variant="secondary", |
| interactive=False |
| ) |
| stop_btn = gr.Button( |
| "Stop Training", |
| variant="stop", |
| interactive=False |
| ) |
|
|
| with gr.Row(): |
| with gr.Column(): |
| status_box = gr.Textbox( |
| label="Training Status", |
| interactive=False, |
| lines=4 |
| ) |
| with gr.Accordion("See training logs"): |
| log_box = gr.TextArea( |
| label="Finetrainers output (see HF Space logs for more details)", |
| interactive=False, |
| lines=40, |
| max_lines=200, |
| autoscroll=True |
| ) |
|
|
| with gr.TabItem("5️⃣ Manage"): |
|
|
| with gr.Column(): |
| with gr.Row(): |
| with gr.Column(): |
| gr.Markdown("## Publishing") |
| gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)") |
|
|
| with gr.Row(): |
|
|
| with gr.Column(): |
| repo_id = gr.Textbox( |
| label="HuggingFace Model Repository", |
| placeholder="username/model-name", |
| info="The repository will be created if it doesn't exist" |
| ) |
| gr.Checkbox(label="Check this to make your model public (ie. visible and downloadable by anyone)", info="You model is private by default"), |
| global_stop_btn = gr.Button( |
| "Push my model", |
| |
| ) |
|
|
| |
| with gr.Row(): |
| with gr.Column(): |
| with gr.Row(): |
| with gr.Column(): |
| gr.Markdown("## Storage management") |
| with gr.Row(): |
| download_dataset_btn = gr.DownloadButton( |
| "Download dataset", |
| variant="secondary", |
| size="lg" |
| ) |
| download_model_btn = gr.DownloadButton( |
| "Download model", |
| variant="secondary", |
| size="lg" |
| ) |
|
|
|
|
| with gr.Row(): |
| global_stop_btn = gr.Button( |
| "Stop everything and delete my data", |
| variant="stop" |
| ) |
| global_status = gr.Textbox( |
| label="Global Status", |
| interactive=False, |
| visible=False |
| ) |
| |
|
|
| |
| |
| def update_model_info(model): |
| params = self.get_default_params(MODEL_TYPES[model]) |
| info = self.get_model_info(MODEL_TYPES[model]) |
| return { |
| model_info: info, |
| num_epochs: params["num_epochs"], |
| batch_size: params["batch_size"], |
| learning_rate: params["learning_rate"], |
| save_iterations: params["save_iterations"] |
| } |
| |
| def validate_repo(repo_id: str) -> dict: |
| validation = validate_model_repo(repo_id) |
| if validation["error"]: |
| return gr.update(value=repo_id, error=validation["error"]) |
| return gr.update(value=repo_id, error=None) |
| |
| |
|
|
| |
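            # Persist each training form field to the saved UI state when it changes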
| model_type.change( |
| fn=lambda v: self.update_ui_state(model_type=v), |
| inputs=[model_type], |
| outputs=[] |
| ).then( |
| fn=update_model_info, |
| inputs=[model_type], |
| outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations] |
| ) |
|
|
| |
| lora_rank.change( |
| fn=lambda v: self.update_ui_state(lora_rank=v), |
| inputs=[lora_rank], |
| outputs=[] |
| ) |
|
|
| lora_alpha.change( |
| fn=lambda v: self.update_ui_state(lora_alpha=v), |
| inputs=[lora_alpha], |
| outputs=[] |
| ) |
|
|
| num_epochs.change( |
| fn=lambda v: self.update_ui_state(num_epochs=v), |
| inputs=[num_epochs], |
| outputs=[] |
| ) |
|
|
| batch_size.change( |
| fn=lambda v: self.update_ui_state(batch_size=v), |
| inputs=[batch_size], |
| outputs=[] |
| ) |
|
|
| learning_rate.change( |
| fn=lambda v: self.update_ui_state(learning_rate=v), |
| inputs=[learning_rate], |
| outputs=[] |
| ) |
|
|
| save_iterations.change( |
| fn=lambda v: self.update_ui_state(save_iterations=v), |
| inputs=[save_iterations], |
| outputs=[] |
| ) |
|
|
| async def on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix): |
| videos = self.list_unprocessed_videos() |
| |
| |
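                # Start automatic scene detection unless a split is already running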
| if videos and not self.splitter.is_processing() and enable_splitting: |
| await self.start_scene_detection(enable_splitting) |
| msg = "Starting automatic scene detection..." |
| else: |
| |
| for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"): |
| await self.splitter.process_video(video_file, enable_splitting=False) |
| msg = "Copying videos without splitting..." |
| |
| copy_files_to_training_dir(prompt_prefix) |
|
|
| |
| if enable_automatic_content_captioning: |
                    # start_caption_generation is an async generator and must
                    # be iterated rather than awaited
                    async for _ in self.start_caption_generation(
                        DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
                        prompt_prefix
                    ):
                        pass
| |
| return { |
| tabs: gr.Tabs(selected="split_tab"), |
| video_list: videos, |
| detect_status: msg |
| } |
|
|
|
|
            async def update_titles_after_import(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
                """Handle post-import updates including titles"""
                import_result = await on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix)
                titles = self.update_titles()
                # import_result is a dict keyed by components; unpack its values
                # (tabs, video_list, detect_status) in insertion order
                return (*import_result.values(), *titles)
|
|
| files.upload( |
| fn=lambda x: self.importer.process_uploaded_files(x), |
| inputs=[files], |
| outputs=[import_status] |
| ).success( |
| fn=update_titles_after_import, |
| inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix], |
| outputs=[ |
| tabs, video_list, detect_status, |
| split_title, caption_title, train_title |
| ] |
| ) |
| |
| youtube_download_btn.click( |
| fn=self.importer.download_youtube_video, |
| inputs=[youtube_url], |
| outputs=[import_status] |
| ).success( |
| fn=on_import_success, |
| inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix], |
| outputs=[tabs, video_list, detect_status] |
| ) |
|
|
| |
| detect_btn.click( |
| fn=self.start_scene_detection, |
| inputs=[enable_automatic_video_split], |
| outputs=[detect_status] |
| ) |
|
|
|
|
| |
| def update_button_states(is_running): |
| return { |
| run_autocaption_btn: gr.Button( |
| interactive=not is_running, |
| variant="secondary" if is_running else "primary", |
| ), |
| stop_autocaption_btn: gr.Button( |
| interactive=is_running, |
| variant="secondary", |
| ), |
| } |
| |
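            # Chain the captioning steps so the UI reflects progress between stages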
| run_autocaption_btn.click( |
| fn=self.show_refreshing_status, |
| outputs=[training_dataset] |
| ).then( |
| fn=lambda: self.update_captioning_buttons_start(), |
| outputs=[run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn] |
| ).then( |
| fn=self.start_caption_generation, |
| inputs=[captioning_bot_instructions, custom_prompt_prefix], |
| outputs=[training_dataset], |
| ).then( |
| fn=lambda: self.update_captioning_buttons_end(), |
| outputs=[run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn] |
| ) |
| |
| copy_files_to_training_dir_btn.click( |
| fn=self.copy_files_to_training_dir, |
| inputs=[custom_prompt_prefix] |
| ) |
| stop_autocaption_btn.click( |
| fn=self.stop_captioning, |
| outputs=[training_dataset, run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn] |
| ) |
|
|
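            # Holds the path of the currently selected dataset file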
| original_file_path = gr.State(value=None) |
| training_dataset.select( |
| fn=self.handle_training_dataset_select, |
| outputs=[preview_image, preview_video, preview_caption, original_file_path, preview_status] |
| ) |
|
|
| save_caption_btn.click( |
| fn=self.save_caption_changes, |
| inputs=[preview_caption, preview_image, preview_video, original_file_path, custom_prompt_prefix], |
| outputs=[preview_status] |
| ).success( |
| fn=self.list_training_files_to_caption, |
| outputs=[training_dataset] |
| ) |
|
|
| |
| training_preset.change( |
| fn=lambda v: self.update_ui_state(training_preset=v), |
| inputs=[training_preset], |
| outputs=[] |
| ).then( |
| fn=self.update_training_params, |
| inputs=[training_preset], |
| outputs=[ |
| model_type, lora_rank, lora_alpha, |
| num_epochs, batch_size, learning_rate, |
| save_iterations, preset_info |
| ] |
| ) |
|
|
| |
| start_btn.click( |
| fn=lambda preset, model_type, *args: ( |
| self.log_parser.reset(), |
| self.trainer.start_training( |
| MODEL_TYPES[model_type], |
| *args, |
| preset_name=preset |
| ) |
| ), |
| inputs=[ |
| training_preset, |
| model_type, |
| lora_rank, |
| lora_alpha, |
| num_epochs, |
| batch_size, |
| learning_rate, |
| save_iterations, |
| repo_id |
| ], |
| outputs=[status_box, log_box] |
| ).success( |
| fn=self.get_latest_status_message_logs_and_button_labels, |
| outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn] |
| ) |
|
|
| pause_resume_btn.click( |
| fn=self.handle_pause_resume, |
| outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn] |
| ) |
|
|
| stop_btn.click( |
| fn=self.handle_stop, |
| outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn] |
| ) |
|
|
| def handle_global_stop(): |
| result = self.stop_all_and_clear() |
| |
| status = result["status"] |
| details = "\n".join(f"{k}: {v}" for k, v in result["details"].items()) |
| full_status = f"{status}\n\nDetails:\n{details}" |
| |
| |
| videos = self.splitter.list_unprocessed_videos() |
| clips = self.list_training_files_to_caption() |
| |
| return { |
| global_status: gr.update(value=full_status, visible=True), |
| video_list: videos, |
| training_dataset: clips, |
| status_box: "Training stopped and data cleared", |
| log_box: "", |
| detect_status: "Scene detection stopped", |
| import_status: "All data cleared", |
| preview_status: "Captioning stopped" |
| } |
| |
| download_dataset_btn.click( |
| fn=self.trainer.create_training_dataset_zip, |
| outputs=[download_dataset_btn] |
| ) |
|
|
| download_model_btn.click( |
| fn=self.trainer.get_model_output_safetensors, |
| outputs=[download_model_btn] |
| ) |
|
|
| global_stop_btn.click( |
| fn=handle_global_stop, |
| outputs=[ |
| global_status, |
| video_list, |
| training_dataset, |
| status_box, |
| log_box, |
| detect_status, |
| import_status, |
| preview_status |
| ] |
| ) |
|
|
|
|
| app.load( |
| fn=self.initialize_app_state, |
| outputs=[ |
| video_list, training_dataset, |
| start_btn, stop_btn, pause_resume_btn, |
| training_preset, model_type, lora_rank, lora_alpha, |
| num_epochs, batch_size, learning_rate, save_iterations |
| ] |
| ) |
| |
| |
            # Poll training status every second
            status_timer = gr.Timer(value=1)
            status_timer.tick(
                fn=self.get_latest_status_message_logs_and_button_labels,
                outputs=[
                    status_box,
                    log_box,
                    start_btn,
                    stop_btn,
                    pause_resume_btn
                ]
            )
|
|
            # Refresh the dataset listings every five seconds
            dataset_timer = gr.Timer(value=5)
            dataset_timer.tick(
                fn=self.refresh_dataset,
                outputs=[
                    video_list, training_dataset
                ]
            )
|
|
            # Refresh the dynamic tab titles every six seconds
            titles_timer = gr.Timer(value=6)
            titles_timer.tick(
                fn=self.update_titles,
                outputs=[
                    split_title, caption_title, train_title
                ]
            )
|
|
| return app |
|
|
| def create_app(): |
| if ASK_USER_TO_DUPLICATE_SPACE: |
| with gr.Blocks() as app: |
| gr.Markdown("""# Finetrainers UI |
| |
| This Hugging Face space needs to be duplicated to your own billing account to work. |
| |
| Click the 'Duplicate Space' button at the top of the page to create your own copy. |
| |
| It is recommended to use a Nvidia L40S and a persistent storage space. |
| To avoid overpaying for your space, you can configure the auto-sleep settings to fit your personal budget.""") |
| return app |
|
|
| ui = VideoTrainerUI() |
| return ui.create_ui() |
|
|
| if __name__ == "__main__": |
| app = create_app() |
|
|
| allowed_paths = [ |
| str(STORAGE_PATH), |
| str(VIDEOS_TO_SPLIT_PATH), |
| str(STAGING_PATH), |
| str(TRAINING_PATH), |
| str(TRAINING_VIDEOS_PATH), |
| str(MODEL_PATH), |
| str(OUTPUT_PATH) |
| ] |
| app.queue(default_concurrency_limit=1).launch( |
| server_name="0.0.0.0", |
| allowed_paths=allowed_paths |
| ) |