Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from gradio import update as gr_update | |
| import subprocess | |
| import threading | |
| import time | |
| import re | |
| import os | |
| import random | |
| import tiktoken | |
| import sys | |
| import ffmpeg | |
| from typing import List, Tuple, Optional, Generator, Dict | |
| import json | |
| from gradio import themes | |
| from gradio.themes.utils import colors | |
| import subprocess | |
| from PIL import Image | |
| import math | |
| import cv2 | |
# Module-wide flag: the batch entry points clear it when a run starts, and the
# generation loops poll it to terminate the running subprocess.
stop_event = threading.Event()
def get_dit_models(dit_folder: str) -> List[str]:
    """Return the DiT checkpoint filenames found in ``dit_folder``.

    Files ending in ``.pt`` or ``.safetensors`` are listed, sorted
    case-insensitively.  When the path is missing, is not a directory, or
    holds no checkpoints, a single default filename is returned so callers
    always receive at least one choice.
    """
    default_models = ["mp_rank_00_model_states.pt"]
    # isdir rather than exists: a path pointing at a regular file would make
    # os.listdir raise NotADirectoryError.
    if not os.path.isdir(dit_folder):
        return default_models
    models = sorted(
        (name for name in os.listdir(dit_folder) if name.endswith(('.pt', '.safetensors'))),
        key=str.lower,
    )
    return models or default_models
def update_dit_and_lora_dropdowns(dit_folder: str, lora_folder: str, *current_values) -> List[gr.update]:
    """Rescan both folders and build updates for the DiT and LoRA widgets.

    ``current_values`` is read positionally: index 0 is the selected DiT
    model, indices 1-4 are the four LoRA selections and indices 5-8 their
    multipliers.  The result holds one update for the DiT dropdown followed
    by a (weight dropdown, multiplier) pair of updates per LoRA slot.
    """
    available_dit = get_dit_models(dit_folder)
    available_loras = get_lora_options(lora_folder)

    selected_dit = current_values[0]
    if selected_dit not in available_dit:
        # The previous selection is gone; fall back to the first listed model.
        selected_dit = available_dit[0] if available_dit else None

    chosen_loras = current_values[1:5]
    chosen_mults = current_values[5:9]

    updates = [gr.update(choices=available_dit, value=selected_dit)]
    for slot in range(4):
        lora_name = chosen_loras[slot] if slot < len(chosen_loras) else "None"
        mult = chosen_mults[slot] if slot < len(chosen_mults) else 1.0
        if lora_name not in available_loras:
            lora_name = "None"
        updates.append(gr.update(choices=available_loras, value=lora_name))
        updates.append(gr.update(value=mult))
    return updates
def extract_video_metadata(video_path: str) -> Dict:
    """Read the generation parameters stored in a video's ``comment`` tag.

    Runs ``ffprobe`` on ``video_path`` and JSON-decodes the ``comment`` entry
    of the reported format tags.

    Returns:
        The decoded parameter dict, or ``{}`` when ffprobe fails, the tag is
        absent, or the tag does not decode to a JSON object.
    """
    cmd = [
        'ffprobe',
        '-v', 'quiet',
        '-print_format', 'json',
        '-show_format',
        video_path,
    ]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
        metadata = json.loads(result.stdout.decode('utf-8'))
        if 'format' in metadata and 'tags' in metadata['format']:
            comment = metadata['format']['tags'].get('comment', '{}')
            parsed = json.loads(comment)
            # A comment may hold valid JSON that is not an object (a string,
            # a list); callers iterate the result with .items(), so only a
            # dict is handed back.
            return parsed if isinstance(parsed, dict) else {}
        return {}
    except Exception as e:
        # Deliberately best-effort: a missing ffprobe binary, a failed probe
        # or an undecodable comment all mean "no metadata".
        print(f"Metadata extraction failed: {str(e)}")
        return {}
def create_parameter_transfer_map(metadata: Dict, target_tab: str) -> Dict:
    """Translate stored metadata keys into component names for a tab.

    Common parameters keep their own name when ``target_tab`` is ``'t2v'``
    and get a ``v2v_`` prefix for any other tab.  LoRA entries use the
    ``v2v_lora_weights[i]`` / ``v2v_lora_multipliers[i]`` names only when
    ``target_tab`` is ``'v2v'`` and the ``lora<N>`` / ``lora<N>_multiplier``
    names otherwise.  At most four LoRA entries are read, and multipliers are
    converted to ``float``.  Unrecognised metadata keys are ignored.
    """
    common_keys = (
        'prompt', 'width', 'height', 'batch_size', 'video_length', 'fps',
        'infer_steps', 'seed', 'model', 'vae', 'te1', 'te2', 'save_path',
        'flow_shift', 'cfg_scale', 'output_type', 'attn_mode', 'block_swap',
    )
    common_prefix = '' if target_tab == 't2v' else 'v2v_'
    lora_to_v2v = target_tab == 'v2v'

    transferred = {}
    for key, stored in metadata.items():
        if key in common_keys:
            transferred[f'{common_prefix}{key}'] = stored

        if key == 'lora_weights':
            for slot, lora_name in enumerate(stored[:4]):
                if lora_to_v2v:
                    component = f'v2v_lora_weights[{slot}]'
                else:
                    component = f'lora{slot + 1}'
                transferred[component] = lora_name

        if key == 'lora_multipliers':
            for slot, raw_mult in enumerate(stored[:4]):
                if lora_to_v2v:
                    component = f'v2v_lora_multipliers[{slot}]'
                else:
                    component = f'lora{slot + 1}_multiplier'
                transferred[component] = float(raw_mult)
    return transferred
def add_metadata_to_video(video_path: str, parameters: dict) -> None:
    """Embed ``parameters`` as JSON in the video's ``comment`` metadata tag.

    ffmpeg copies the streams (no re-encoding) into a temporary sibling file
    which then replaces the original.  Failures are reported with ``print``
    and never raised; the temporary file is removed in every case.
    """
    params_json = json.dumps(parameters, indent=2)

    # Derive the temp name from the real extension.  str.replace(".mp4", ...)
    # left the path unchanged for any other extension (temp == input) and
    # rewrote every ".mp4" occurrence, including ones inside directory names.
    root, ext = os.path.splitext(video_path)
    temp_path = f"{root}_temp{ext}"

    cmd = [
        'ffmpeg',
        '-y',  # overwrite a stale temp file instead of waiting on a prompt
        '-i', video_path,
        '-metadata', f'comment={params_json}',
        '-codec', 'copy',
        temp_path,
    ]
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Replace original file with the metadata-enhanced version
        os.replace(temp_path, video_path)
    except subprocess.CalledProcessError as e:
        print(f"Failed to add metadata: {e.stderr.decode(errors='replace')}")
    except Exception as e:
        print(f"Error: {str(e)}")
    finally:
        # After a successful os.replace the temp file no longer exists; on
        # any failure path a partial file may remain and is cleaned up here.
        if os.path.exists(temp_path):
            os.remove(temp_path)
def count_prompt_tokens(prompt: str) -> int:
    """Return how many ``cl100k_base`` tokens tiktoken produces for ``prompt``."""
    encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(prompt))
def get_lora_options(lora_folder: str = "lora") -> List[str]:
    """Return the LoRA dropdown choices for ``lora_folder``.

    The list always starts with the literal ``"None"`` entry, followed by the
    ``.safetensors`` / ``.pt`` filenames in the folder sorted
    case-insensitively.  A missing path, or one that is not a directory,
    yields just ``["None"]``.
    """
    # isdir rather than exists: a path pointing at a regular file would make
    # os.listdir raise NotADirectoryError.
    if not os.path.isdir(lora_folder):
        return ["None"]
    lora_files = sorted(
        (name for name in os.listdir(lora_folder) if name.endswith(('.safetensors', '.pt'))),
        key=str.lower,
    )
    return ["None"] + lora_files
def update_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]:
    """Rescan ``lora_folder`` and build updates for the four LoRA slots.

    ``current_values`` is read positionally: indices 0-3 are the selected
    LoRA files and indices 4-7 their multipliers.  A selection that is no
    longer among the choices is reset to ``"None"``.  The result alternates
    (weight dropdown update, multiplier update) for each slot.
    """
    choices = get_lora_options(lora_folder)
    selected = current_values[:4]
    mults = current_values[4:8]

    updates = []
    for slot in range(4):
        lora_name = selected[slot] if slot < len(selected) else "None"
        mult = mults[slot] if slot < len(mults) else 1.0
        if lora_name not in choices:
            lora_name = "None"
        updates.append(gr.update(choices=choices, value=lora_name))
        updates.append(gr.update(value=mult))
    return updates
def send_to_v2v(evt: gr.SelectData, gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str, int]:
    """Return the clicked gallery video and prompt for the Video2Video tab.

    Gallery items may be dicts (``"name"`` preferred over ``"data"``),
    tuples/lists (first element is taken) or bare values.  The clicked index
    is also written to ``selected_index.value``.  An empty gallery or an
    out-of-range click returns ``(None, "", selected_index.value)``.
    """
    if not gallery or evt.index >= len(gallery):
        return None, "", selected_index.value

    entry = gallery[evt.index]
    if isinstance(entry, dict):
        fallback = entry.get("data", None)
        path = entry.get("name", fallback)
    elif isinstance(entry, (tuple, list)):
        path = entry[0]
    else:
        path = entry

    # The extracted value may itself be a tuple; unwrap one more level.
    if isinstance(path, tuple):
        path = path[0]

    selected_index.value = evt.index
    return str(path), prompt, evt.index
def send_selected_to_v2v(gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str]:
    """Return the remembered gallery selection and prompt for the V2V tab.

    The index is read from ``selected_index.value``.  An empty gallery, a
    ``None`` index or an out-of-range index returns ``(None, "")``.  Gallery
    items may be dicts (``"name"`` preferred over ``"data"``), tuples/lists
    (first element is taken) or bare values.
    """
    index = selected_index.value
    if not gallery or index is None or index >= len(gallery):
        return None, ""

    entry = gallery[index]
    if isinstance(entry, dict):
        fallback = entry.get("data", None)
        path = entry.get("name", fallback)
    elif isinstance(entry, (tuple, list)):
        path = entry[0]
    else:
        path = entry

    # The extracted value may itself be a tuple; unwrap one more level.
    if isinstance(path, tuple):
        path = path[0]
    return str(path), prompt
def clear_cuda_cache():
    """Empty torch's CUDA cache; does nothing when CUDA is unavailable."""
    import torch

    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    # Synchronize afterwards so the call returns only once the device is idle.
    torch.cuda.synchronize()
def process_single_video(
    prompt: str,
    width: int,
    height: int,
    batch_size: int,
    video_length: int,
    fps: int,
    infer_steps: int,
    seed: int,
    dit_folder: str,
    model: str,
    vae: str,
    te1: str,
    te2: str,
    save_path: str,
    flow_shift: float,
    cfg_scale: float,
    output_type: str,
    attn_mode: str,
    block_swap: int,
    exclude_single_blocks: bool,
    use_split_attn: bool,
    lora_folder: str,
    lora1: str = "",
    lora2: str = "",
    lora3: str = "",
    lora4: str = "",
    lora1_multiplier: float = 1.0,
    lora2_multiplier: float = 1.0,
    lora3_multiplier: float = 1.0,
    lora4_multiplier: float = 1.0,
    video_path: Optional[str] = None,
    image_path: Optional[str] = None,
    strength: Optional[float] = None,
    negative_prompt: Optional[str] = None,
    embedded_cfg_scale: Optional[float] = None,
    split_uncond: Optional[bool] = None,
    guidance_scale: Optional[float] = None,
    use_fp8: bool = True
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
    """Generate one video by running ``hv_generate_video.py`` in a subprocess.

    Yields ``(videos, status, progress)`` tuples: one per progress-bar line
    printed by the subprocess, then a final tuple whose ``videos`` list holds
    ``(path, caption)`` for the newest ``.mp4`` in ``save_path`` whose name
    contains ``_<seed>`` (the list is empty when no such file exists).  The
    found file gets the generation parameters embedded as metadata.

    Seed handling: ``seed == -1`` draws a random seed; otherwise the seed is
    used as given unless ``batch_size > 1``, in which case it is offset by a
    value derived from the current time.

    Models whose name contains ``"skyreels"`` additionally receive
    ``--dit_in_channels``, ``--guidance_scale`` and, when set,
    ``--negative_prompt`` / ``--split_uncond``.  For every other model
    ``embedded_cfg_scale`` is overridden with ``cfg_scale``.

    If the module-level ``stop_event`` is set, the subprocess is terminated
    and a "stopped" tuple is yielded instead.
    """
    global stop_event
    if stop_event.is_set():
        yield [], "", ""
        return

    # Determine if this is a SkyReels model and what type
    model_name = model.lower()
    is_skyreels = "skyreels" in model_name
    is_skyreels_i2v = is_skyreels and "i2v" in model_name

    if is_skyreels:
        # Fill in SkyReels defaults for anything the caller left unset.
        if negative_prompt is None:
            negative_prompt = ""
        if embedded_cfg_scale is None:
            embedded_cfg_scale = 1.0
        if split_uncond is None:
            split_uncond = True
        if guidance_scale is None:
            guidance_scale = cfg_scale
        # SkyReels I2V uses 32 input channels; T2V uses 16 like regular models.
        dit_in_channels = 32 if is_skyreels_i2v else 16
    else:
        dit_in_channels = 16
        embedded_cfg_scale = cfg_scale

    if os.path.isabs(model):
        model_path = model
    else:
        model_path = os.path.normpath(os.path.join(dit_folder, model))

    env = os.environ.copy()
    env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "")
    env["PYTHONIOENCODING"] = "utf-8"
    env["BATCH_RUN_ID"] = f"{time.time()}"

    if seed == -1:
        current_seed = random.randint(0, 2**32 - 1)
    else:
        # The fractional digits of the timestamp vary between batch items.
        batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1])
        if batch_size > 1:  # Only modify seed for batch generation
            current_seed = (seed + batch_id * 100003) % (2**32)
        else:
            current_seed = seed

    clear_cuda_cache()

    command = [
        sys.executable,
        "hv_generate_video.py",
        "--dit", model_path,
        "--vae", vae,
        "--text_encoder1", te1,
        "--text_encoder2", te2,
        "--prompt", prompt,
        "--video_size", str(height), str(width),
        "--video_length", str(video_length),
        "--fps", str(fps),
        "--infer_steps", str(infer_steps),
        "--save_path", save_path,
        "--seed", str(current_seed),
        "--flow_shift", str(flow_shift),
        # Pass the resolved value (forced/defaulted for SkyReels, equal to
        # cfg_scale otherwise) rather than the raw cfg_scale, so the command
        # agrees with what is recorded in the metadata.
        "--embedded_cfg_scale", str(embedded_cfg_scale),
        "--output_type", output_type,
        "--attn_mode", attn_mode,
        "--blocks_to_swap", str(block_swap),
        "--fp8_llm",
        "--vae_chunk_size", "32",
        "--vae_spatial_tile_sample_min_size", "128"
    ]
    if use_fp8:
        command.append("--fp8")

    if is_skyreels:
        command.extend(["--dit_in_channels", str(dit_in_channels)])
        command.extend(["--guidance_scale", str(guidance_scale)])
        if negative_prompt:
            command.extend(["--negative_prompt", negative_prompt])
        if split_uncond:
            command.append("--split_uncond")

    # Add LoRA weights and multipliers if provided
    valid_loras = [
        (os.path.join(lora_folder, weight), mult)
        for weight, mult in zip(
            [lora1, lora2, lora3, lora4],
            [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier],
        )
        if weight and weight != "None"
    ]
    if valid_loras:
        command.extend(["--lora_weight"] + [weight for weight, _ in valid_loras])
        command.extend(["--lora_multiplier"] + [str(mult) for _, mult in valid_loras])

    if exclude_single_blocks:
        command.append("--exclude_single_blocks")
    if use_split_attn:
        command.append("--split_attn")

    # Handle input paths: a video input takes precedence over an image input.
    if video_path:
        command.extend(["--video_path", video_path])
        if strength is not None:
            command.extend(["--strength", str(strength)])
    elif image_path:
        command.extend(["--image_path", image_path])
        # SkyReels I2V doesn't use the strength parameter.
        if strength is not None and not is_skyreels_i2v:
            command.extend(["--strength", str(strength)])

    print(f"{command}")
    p = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        env=env,
        text=True,
        encoding='utf-8',
        errors='replace',
        bufsize=1
    )
    videos = []
    while True:
        if stop_event.is_set():
            p.terminate()
            p.wait()
            yield [], "", "Generation stopped by user."
            return
        line = p.stdout.readline()
        if not line:
            if p.poll() is not None:
                break
            continue
        print(line, end='')
        # Forward lines that look like a progress bar ("12%|###   | [..]").
        if '|' in line and '%' in line and '[' in line and ']' in line:
            yield videos.copy(), f"Processing (seed: {current_seed})", line.strip()
    p.stdout.close()
    p.wait()
    clear_cuda_cache()
    time.sleep(0.5)

    # Collect generated video
    save_path_abs = os.path.abspath(save_path)
    if os.path.exists(save_path_abs):
        all_videos = sorted(
            [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')],
            key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)),
            reverse=True
        )
        matching_videos = [v for v in all_videos if f"_{current_seed}" in v]
        if matching_videos:
            # Kept in its own variable: reusing ``video_path`` here made the
            # metadata's "input_video" record the output file, not the input.
            output_path = os.path.join(save_path_abs, matching_videos[0])
            # Collect parameters for metadata
            parameters = {
                "prompt": prompt,
                "width": width,
                "height": height,
                "video_length": video_length,
                "fps": fps,
                "infer_steps": infer_steps,
                "seed": current_seed,
                "model": model,
                "vae": vae,
                "te1": te1,
                "te2": te2,
                "save_path": save_path,
                "flow_shift": flow_shift,
                "cfg_scale": cfg_scale,
                "output_type": output_type,
                "attn_mode": attn_mode,
                "block_swap": block_swap,
                "lora_weights": [lora1, lora2, lora3, lora4],
                "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier],
                "input_video": video_path if video_path else None,
                "input_image": image_path if image_path else None,
                "strength": strength,
                "negative_prompt": negative_prompt if is_skyreels else None,
                "embedded_cfg_scale": embedded_cfg_scale if is_skyreels else None
            }
            add_metadata_to_video(output_path, parameters)
            videos.append((str(output_path), f"Seed: {current_seed}"))
    yield videos, f"Completed (seed: {current_seed})", ""
# Batch entry point: runs process_single_video once per requested video.
def process_batch(
    prompt: str,
    width: int,
    height: int,
    batch_size: int,
    video_length: int,
    fps: int,
    infer_steps: int,
    seed: int,
    dit_folder: str,
    model: str,
    vae: str,
    te1: str,
    te2: str,
    save_path: str,
    flow_shift: float,
    cfg_scale: float,
    output_type: str,
    attn_mode: str,
    block_swap: int,
    exclude_single_blocks: bool,
    use_split_attn: bool,
    lora_folder: str,
    *args
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
    """Generate ``batch_size`` videos, one ``process_single_video`` run each.

    ``args`` is read positionally: the first four values are LoRA weights,
    the next four their multipliers, and the remainder ("extras") is
    ``[input_path, strength, negative_prompt, guidance_scale, split_uncond]``
    with ``use_fp8`` taken from the last extra when at least three extras are
    present (otherwise it defaults to ``True``).

    Clears the module-level ``stop_event`` on entry and stops starting new
    videos once it is set.  Yields ``(all_videos, status, progress)`` tuples
    and finishes with a ``"Batch complete"`` tuple.
    """
    global stop_event
    stop_event.clear()
    all_videos = []
    progress_text = "Starting generation..."
    yield [], "Preparing...", progress_text

    # Extract additional arguments
    num_lora_weights = 4
    lora_weights = args[:num_lora_weights]
    lora_multipliers = args[num_lora_weights:num_lora_weights * 2]
    extra_args = args[num_lora_weights * 2:]

    # Determine if this is a SkyReels model and what type
    model_name = model.lower()
    is_skyreels = "skyreels" in model_name
    is_skyreels_i2v = is_skyreels and "i2v" in model_name

    # Handle input paths and additional parameters
    input_path = extra_args[0] if extra_args else None
    # Guard None like the neighbouring fields do: float(None) raises.
    if len(extra_args) > 1 and extra_args[1] is not None:
        strength = float(extra_args[1])
    else:
        strength = None
    # Get use_fp8 flag (it should be the last parameter)
    use_fp8 = bool(extra_args[-1]) if extra_args and len(extra_args) >= 3 else True

    has_negative = len(extra_args) > 2 and extra_args[2] is not None
    if is_skyreels:
        # Always set embedded_cfg_scale to 1.0 for SkyReels models
        embedded_cfg_scale = 1.0
        negative_prompt = str(extra_args[2]) if has_negative else ""
        # Fall back to cfg_scale when no explicit guidance scale was given.
        if len(extra_args) > 3 and extra_args[3] is not None:
            guidance_scale = float(extra_args[3])
        else:
            guidance_scale = cfg_scale
        split_uncond = True if len(extra_args) > 4 and extra_args[4] else False
    else:
        negative_prompt = str(extra_args[2]) if has_negative else None
        guidance_scale = cfg_scale
        embedded_cfg_scale = cfg_scale
        split_uncond = bool(extra_args[4]) if len(extra_args) > 4 else None

    # Classify the input once; it does not change between batch items.
    video_path = None
    image_path = None
    if input_path:
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp')
        is_image = input_path.lower().endswith(image_extensions)
        # Only use image_path for SkyReels I2V models and actual image files
        if is_skyreels_i2v and is_image:
            image_path = input_path
        else:
            video_path = input_path

    # Positional arguments in the order process_single_video declares them.
    single_video_args = [
        prompt, width, height, batch_size, video_length, fps, infer_steps,
        seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale,
        output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn,
        lora_folder,
        *lora_weights,
        *lora_multipliers,
        video_path, image_path, strength, negative_prompt,
        embedded_cfg_scale, split_uncond, guidance_scale, use_fp8,
    ]

    for i in range(batch_size):
        if stop_event.is_set():
            break
        batch_text = f"Generating video {i + 1} of {batch_size}"
        yield all_videos.copy(), batch_text, progress_text

        for videos, status, progress in process_single_video(*single_video_args):
            # Progress updates carry an empty list; only the final update of
            # a run carries the finished video.
            if videos:
                all_videos.extend(videos)
            yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress
    yield all_videos, "Batch complete", ""
def update_wanx_image_dimensions(image):
    """Read an uploaded image's size and round it down to multiples of 32.

    Returns:
        ``("<w>x<h>", w, h)`` for an image, or an empty dims string plus
        updates setting 832 and 480 when ``image`` is ``None``.
    """
    if image is None:
        return "", gr.update(value=832), gr.update(value=480)
    # Only the size is needed; the context manager closes the file handle
    # that Image.open would otherwise leave open.
    with Image.open(image) as img:
        w, h = img.size
    w = (w // 32) * 32
    h = (h // 32) * 32
    return f"{w}x{h}", w, h
def calculate_wanx_width(height, original_dims):
    """Derive the width matching ``height`` at the source aspect ratio.

    ``original_dims`` is a ``"<width>x<height>"`` string.  The result is
    floored to a multiple of 32.  An empty ``original_dims`` or a source
    height of zero yields a no-op update.
    """
    if not original_dims:
        return gr.update()
    orig_w, orig_h = map(int, original_dims.split('x'))
    # update_wanx_image_dimensions floors sizes to multiples of 32, so an
    # image under 32 px tall is stored with height 0; avoid dividing by it.
    if orig_h == 0:
        return gr.update()
    aspect_ratio = orig_w / orig_h
    new_width = math.floor((height * aspect_ratio) / 32) * 32
    return gr.update(value=new_width)
def calculate_wanx_height(width, original_dims):
    """Derive the height matching ``width`` at the source aspect ratio.

    ``original_dims`` is a ``"<width>x<height>"`` string.  The result is
    floored to a multiple of 32.  An empty ``original_dims`` or a zero source
    dimension yields a no-op update.
    """
    if not original_dims:
        return gr.update()
    orig_w, orig_h = map(int, original_dims.split('x'))
    # update_wanx_image_dimensions floors sizes to multiples of 32, so a side
    # under 32 px is stored as 0.  A zero height breaks the ratio and a zero
    # width makes the ratio itself zero; both would divide by zero below.
    if orig_w == 0 or orig_h == 0:
        return gr.update()
    aspect_ratio = orig_w / orig_h
    new_height = math.floor((width / aspect_ratio) / 32) * 32
    return gr.update(value=new_height)
def update_wanx_from_scale(scale, original_dims):
    """Scale the source dimensions by ``scale`` percent, floored to 32s.

    ``original_dims`` is a ``"<width>x<height>"`` string; when it is empty
    two no-op updates are returned.
    """
    if not original_dims:
        return gr.update(), gr.update()

    source_w, source_h = map(int, original_dims.split('x'))
    scaled_w = math.floor((source_w * scale / 100) / 32) * 32
    scaled_h = math.floor((source_h * scale / 100) / 32) * 32
    width_update = gr.update(value=scaled_w)
    height_update = gr.update(value=scaled_h)
    return width_update, height_update
def recommend_wanx_flow_shift(width, height):
    """Return an update with flow shift 3.0 for 832x480 / 480x832, else 5.0."""
    if width == 832 and height == 480:
        shift = 3.0
    elif width == 480 and height == 832:
        shift = 3.0
    else:
        shift = 5.0
    return gr.update(value=shift)
def handle_wanx_gallery_select(evt: gr.SelectData) -> int:
    """Return the index carried by a gallery selection event."""
    clicked_index = evt.index
    return clicked_index
def wanx_generate_video(
    prompt,
    negative_prompt,
    input_image,
    width,
    height,
    video_length,
    fps,
    infer_steps,
    flow_shift,
    guidance_scale,
    seed,
    task,
    dit_path,
    vae_path,
    t5_path,
    clip_path,
    save_path,
    output_type,
    sample_solver,
    attn_mode,
    block_swap,
    fp8,
    fp8_t5
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
    """Generate one WanX video by running ``wan_generate_video.py``.

    Yields ``(videos, status, progress)`` tuples: one per progress-bar line
    from the subprocess, then a final tuple whose ``videos`` list holds
    ``(path, caption)`` for the newest ``.mp4`` in ``save_path`` whose name
    contains ``_<seed>``.  ``seed == -1`` draws a random seed.  Tasks whose
    name contains ``"i2v"`` require ``input_image`` and also pass the CLIP
    path.  Setting the module-level ``stop_event`` terminates the subprocess.
    """
    global stop_event
    if stop_event.is_set():
        yield [], "", ""
        return

    current_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed

    # An input image is required for i2v tasks, not for t2v.
    is_i2v = "i2v" in task
    if is_i2v and not input_image:
        yield [], "Error: No input image provided", "Please provide an input image for image-to-video generation"
        return

    # Put the interpreter's directory first on PATH for the child process.
    child_env = os.environ.copy()
    child_env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + child_env.get("PATH", "")
    child_env["PYTHONIOENCODING"] = "utf-8"

    clear_cuda_cache()

    command = [sys.executable, "wan_generate_video.py"]
    command += ["--task", task]
    command += ["--prompt", prompt]
    command += ["--video_size", str(height), str(width)]
    command += ["--video_length", str(video_length)]
    command += ["--fps", str(fps)]
    command += ["--infer_steps", str(infer_steps)]
    command += ["--save_path", save_path]
    command += ["--seed", str(current_seed)]
    command += ["--flow_shift", str(flow_shift)]
    command += ["--guidance_scale", str(guidance_scale)]
    command += ["--output_type", output_type]
    command += ["--attn_mode", attn_mode]
    command += ["--blocks_to_swap", str(block_swap)]
    command += ["--dit", dit_path]
    command += ["--vae", vae_path]
    command += ["--t5", t5_path]
    command += ["--sample_solver", sample_solver]

    if is_i2v and input_image:
        command += ["--image_path", input_image]
        command += ["--clip", clip_path]  # CLIP is only needed for i2v
    if negative_prompt:
        command += ["--negative_prompt", negative_prompt]
    if fp8:
        command += ["--fp8"]
    if fp8_t5:
        command += ["--fp8_t5"]

    print(f"Running: {' '.join(command)}")
    proc = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        env=child_env,
        text=True,
        encoding='utf-8',
        errors='replace',
        bufsize=1
    )

    videos = []
    while True:
        if stop_event.is_set():
            proc.terminate()
            proc.wait()
            yield [], "", "Generation stopped by user."
            return

        output_line = proc.stdout.readline()
        if output_line:
            print(output_line, end='')
            # Forward only lines that look like a progress bar.
            if all(marker in output_line for marker in ('|', '%', '[', ']')):
                yield videos.copy(), f"Processing (seed: {current_seed})", output_line.strip()
        elif proc.poll() is not None:
            break

    proc.stdout.close()
    proc.wait()
    clear_cuda_cache()
    time.sleep(0.5)

    # Look for the newest .mp4 whose name carries this run's seed.
    out_dir = os.path.abspath(save_path)
    if os.path.exists(out_dir):
        mp4_names = [name for name in os.listdir(out_dir) if name.endswith('.mp4')]
        mp4_names.sort(
            key=lambda name: os.path.getmtime(os.path.join(out_dir, name)),
            reverse=True,
        )
        seed_tag = f"_{current_seed}"
        newest = next((name for name in mp4_names if seed_tag in name), None)
        if newest is not None:
            video_path = os.path.join(out_dir, newest)
            parameters = {
                "prompt": prompt,
                "width": width,
                "height": height,
                "video_length": video_length,
                "fps": fps,
                "infer_steps": infer_steps,
                "seed": current_seed,
                "task": task,
                "flow_shift": flow_shift,
                "guidance_scale": guidance_scale,
                "output_type": output_type,
                "attn_mode": attn_mode,
                "block_swap": block_swap,
                "input_image": input_image if is_i2v else None
            }
            add_metadata_to_video(video_path, parameters)
            videos.append((str(video_path), f"Seed: {current_seed}"))

    yield videos, f"Completed (seed: {current_seed})", ""
def send_wanx_to_v2v(
    gallery: list,
    prompt: str,
    selected_index: int,
    width: int,
    height: int,
    video_length: int,
    fps: int,
    infer_steps: int,
    seed: int,
    flow_shift: float,
    guidance_scale: float,
    negative_prompt: str
) -> Tuple:
    """Return the selected WanX video plus its settings for the V2V tab.

    The result is ``(video_path, prompt)`` followed by the nine settings
    passed in, unchanged.  With an empty gallery, a ``None`` index or an
    out-of-range index the first two entries are ``None`` and ``""``.
    Gallery items may be dicts (``"name"`` preferred over ``"data"``),
    tuples/lists (first element is taken) or bare values.
    """
    settings = (width, height, video_length, fps, infer_steps, seed,
                flow_shift, guidance_scale, negative_prompt)

    if not gallery or selected_index is None or selected_index >= len(gallery):
        return (None, "") + settings

    entry = gallery[selected_index]
    if isinstance(entry, dict):
        fallback = entry.get("data", None)
        path = entry.get("name", fallback)
    elif isinstance(entry, (tuple, list)):
        path = entry[0]
    else:
        path = entry

    # The extracted value may itself be a tuple; unwrap one more level.
    if isinstance(path, tuple):
        path = path[0]
    return (str(path), prompt) + settings
def wanx_generate_video_batch(
    prompt,
    negative_prompt,
    width,
    height,
    video_length,
    fps,
    infer_steps,
    flow_shift,
    guidance_scale,
    seed,
    task,
    dit_path,
    vae_path,
    t5_path,
    clip_path,
    save_path,
    output_type,
    sample_solver,
    attn_mode,
    block_swap,
    fp8,
    fp8_t5,
    batch_size=1,
    input_image=None,  # Optional for i2v
    lora_folder=None,
    *args
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
    """Generate ``batch_size`` WanX videos, with optional LoRA weights.

    ``args`` is read positionally: four LoRA weight names, four multipliers,
    then an optional ``exclude_single_blocks`` flag.  Each batch item runs
    ``wan_generate_video.py`` in a subprocess.  ``seed == -1`` draws a random
    seed per item; otherwise item ``i`` uses ``seed + i`` when
    ``batch_size > 1``.

    Clears the module-level ``stop_event`` on entry; when it is set the
    running subprocess is terminated and the generator ends.  Yields
    ``(all_videos, status, progress)`` tuples and finishes with a
    ``"Batch complete"`` tuple.
    """
    global stop_event
    stop_event.clear()
    all_videos = []
    progress_text = "Starting generation..."
    yield [], "Preparing...", progress_text

    # Extract LoRA parameters from args
    num_loras = 4  # Fixed number of LoRA inputs
    lora_weights = args[:num_loras]
    lora_multipliers = args[num_loras:num_loras * 2]
    exclude_single_blocks = args[num_loras * 2] if len(args) > num_loras * 2 else False

    # Process each item in the batch
    for i in range(batch_size):
        if stop_event.is_set():
            yield all_videos, "Generation stopped by user", ""
            return

        # Calculate seed for this batch item
        current_seed = seed
        if seed == -1:
            current_seed = random.randint(0, 2**32 - 1)
        elif batch_size > 1:
            current_seed = seed + i

        batch_text = f"Generating video {i + 1} of {batch_size}"
        yield all_videos.copy(), batch_text, progress_text

        # Prepare command
        env = os.environ.copy()
        env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "")
        env["PYTHONIOENCODING"] = "utf-8"
        command = [
            sys.executable,
            "wan_generate_video.py",
            "--task", task,
            "--prompt", prompt,
            "--video_size", str(height), str(width),
            "--video_length", str(video_length),
            "--fps", str(fps),
            "--infer_steps", str(infer_steps),
            "--save_path", save_path,
            "--seed", str(current_seed),
            "--flow_shift", str(flow_shift),
            "--guidance_scale", str(guidance_scale),
            "--output_type", output_type,
            "--attn_mode", attn_mode,
            "--dit", dit_path,
            "--vae", vae_path,
            "--t5", t5_path,
            "--sample_solver", sample_solver
        ]
        # Add image path if provided (for i2v)
        if input_image and "i2v" in task:
            command.extend(["--image_path", input_image])
            command.extend(["--clip", clip_path])  # CLIP is needed for i2v
        if negative_prompt:
            command.extend(["--negative_prompt", negative_prompt])
        if block_swap > 0:
            command.extend(["--blocks_to_swap", str(block_swap)])
        if fp8:
            command.append("--fp8")
        if fp8_t5:
            command.append("--fp8_t5")

        # Add LoRA parameters
        valid_loras = []
        for weight, mult in zip(lora_weights, lora_multipliers):
            if weight and weight != "None":
                # lora_folder defaults to None, which os.path.join rejects;
                # without a folder the weight name is used as given.
                weight_path = os.path.join(lora_folder, weight) if lora_folder else weight
                valid_loras.append((weight_path, float(mult)))
        if valid_loras:
            command.extend(["--lora_weight"] + [weight for weight, _ in valid_loras])
            command.extend(["--lora_multiplier"] + [str(mult) for _, mult in valid_loras])
            # Add LoRA options
            if exclude_single_blocks:
                command.append("--exclude_single_blocks")

        print(f"Running: {' '.join(command)}")
        # Execute command
        p = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            env=env,
            text=True,
            encoding='utf-8',
            errors='replace',
            bufsize=1
        )
        # Process output
        while True:
            if stop_event.is_set():
                p.terminate()
                p.wait()
                yield all_videos, "Generation stopped by user", ""
                return
            line = p.stdout.readline()
            if not line:
                if p.poll() is not None:
                    break
                continue
            print(line, end='')
            # Forward lines that look like a progress bar.
            if '|' in line and '%' in line and '[' in line and ']' in line:
                yield all_videos.copy(), f"Batch {i+1}/{batch_size}: Processing (seed: {current_seed})", line.strip()
        p.stdout.close()
        p.wait()

        # Clean CUDA cache
        clear_cuda_cache()
        time.sleep(0.5)

        # Collect the newest .mp4 whose name carries this item's seed.
        save_path_abs = os.path.abspath(save_path)
        if os.path.exists(save_path_abs):
            all_video_files = sorted(
                [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')],
                key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)),
                reverse=True
            )
            matching_videos = [v for v in all_video_files if f"_{current_seed}" in v]
            if matching_videos:
                video_path = os.path.join(save_path_abs, matching_videos[0])
                all_videos.append((str(video_path), f"Seed: {current_seed}"))
    yield all_videos, "Batch complete", ""
def update_wanx_t2v_dimensions(size):
    """Turn a ``"<width>*<height>"`` size string into two component updates.

    Args:
        size: Size selection formatted as two integers joined by ``*``.

    Returns:
        A pair of ``gr.update`` objects carrying the parsed width and the
        parsed height, in that order.
    """
    # Parse lazily, exactly two integer fields are expected.
    width, height = (int(part) for part in size.split('*'))
    width_update = gr.update(value=width)
    height_update = gr.update(value=height)
    return width_update, height_update
def handle_wanx_t2v_gallery_select(evt: gr.SelectData) -> int:
    """Report which gallery item was clicked.

    Args:
        evt: Selection event delivered for the gallery click.

    Returns:
        The ``index`` attribute of the event, unchanged.
    """
    chosen_index = evt.index
    return chosen_index
def send_wanx_t2v_to_v2v(
    gallery, prompt, selected_index, width, height, video_length,
    fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt
) -> Tuple:
    """Send the selected WanX T2V video to the Video2Video tab.

    Args:
        gallery: Current gallery contents. Each item may be a dict (the path
            is read from its ``"name"`` key, falling back to ``"data"``), a
            tuple/list whose first element is the path, or the path itself.
        prompt: Prompt text forwarded unchanged.
        selected_index: Index of the selected gallery item, or ``None``.
        width, height, video_length, fps, infer_steps, seed, flow_shift,
        guidance_scale, negative_prompt: Generation settings forwarded
            unchanged.

    Returns:
        An 11-tuple ``(video_path, prompt, width, height, video_length, fps,
        infer_steps, seed, flow_shift, guidance_scale, negative_prompt)``.
        When nothing usable is selected (empty gallery, missing or
        out-of-range index, or an item that carries no path) the first two
        entries are ``None`` and ``""``.
    """
    settings = (width, height, video_length, fps, infer_steps, seed,
                flow_shift, guidance_scale, negative_prompt)
    nothing_selected = (None, "") + settings

    if not gallery or selected_index is None or selected_index >= len(gallery):
        return nothing_selected

    selected_item = gallery[selected_index]
    if isinstance(selected_item, dict):
        video_path = selected_item.get("name", selected_item.get("data"))
    elif isinstance(selected_item, (tuple, list)):
        video_path = selected_item[0]
    else:
        video_path = selected_item

    # The path itself may still be wrapped, e.g. in a (path, caption) pair.
    if isinstance(video_path, (tuple, list)):
        video_path = video_path[0]

    # Without this guard str(None) would hand the literal string "None" to
    # the Video2Video input as if it were a file path.
    if video_path is None:
        return nothing_selected

    return (str(video_path), prompt) + settings
| # UI setup | |
| with gr.Blocks( | |
| theme=themes.Default( | |
| primary_hue=colors.Color( | |
| name="custom", | |
| c50="#E6F0FF", | |
| c100="#CCE0FF", | |
| c200="#99C1FF", | |
| c300="#66A3FF", | |
| c400="#3384FF", | |
| c500="#0060df", # This is your main color | |
| c600="#0052C2", | |
| c700="#003D91", | |
| c800="#002961", | |
| c900="#001430", | |
| c950="#000A18" | |
| ) | |
| ), | |
| css=""" | |
| .gallery-item:first-child { border: 2px solid #4CAF50 !important; } | |
| .gallery-item:first-child:hover { border-color: #45a049 !important; } | |
| .green-btn { | |
| background: linear-gradient(to bottom right, #2ecc71, #27ae60) !important; | |
| color: white !important; | |
| border: none !important; | |
| } | |
| .green-btn:hover { | |
| background: linear-gradient(to bottom right, #27ae60, #219651) !important; | |
| } | |
| .refresh-btn { | |
| max-width: 40px !important; | |
| min-width: 40px !important; | |
| height: 40px !important; | |
| border-radius: 50% !important; | |
| padding: 0 !important; | |
| display: flex !important; | |
| align-items: center !important; | |
| justify-content: center !important; | |
| } | |
| """, | |
| ) as demo: | |
| # Add state for tracking selected video indices in both tabs | |
| selected_index = gr.State(value=None) # For Text to Video | |
| v2v_selected_index = gr.State(value=None) # For Video to Video | |
| params_state = gr.State() #New addition | |
| i2v_selected_index = gr.State(value=None) | |
| skyreels_selected_index = gr.State(value=None) | |
| demo.load(None, None, None, js=""" | |
| () => { | |
| document.title = 'H1111'; | |
| function updateTitle(text) { | |
| if (text && text.trim()) { | |
| const progressMatch = text.match(/(\d+)%.*\[.*<(\d+:\d+),/); | |
| if (progressMatch) { | |
| const percentage = progressMatch[1]; | |
| const timeRemaining = progressMatch[2]; | |
| document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; | |
| } | |
| } | |
| } | |
| setTimeout(() => { | |
| const progressElements = document.querySelectorAll('textarea.scroll-hide'); | |
| progressElements.forEach(element => { | |
| if (element) { | |
| new MutationObserver(() => { | |
| updateTitle(element.value); | |
| }).observe(element, { | |
| attributes: true, | |
| childList: true, | |
| characterData: true | |
| }); | |
| } | |
| }); | |
| }, 1000); | |
| } | |
| """) | |
| with gr.Tabs() as tabs: | |
| # Text to Video Tab | |
| with gr.Tab(id=1, label="Text to Video"): | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) | |
| with gr.Column(scale=1): | |
| token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| t2v_width = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Width") | |
| t2v_height = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Height") | |
| video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25, elem_id="my_special_slider") | |
| fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24, elem_id="my_special_slider") | |
| infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30, elem_id="my_special_slider") | |
| flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0, elem_id="my_special_slider") | |
| cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg Scale", value=7.0, elem_id="my_special_slider") | |
| with gr.Column(): | |
| with gr.Row(): | |
| video_output = gr.Gallery( | |
| label="Generated Videos (Click to select)", | |
| columns=[2], | |
| rows=[2], | |
| object_fit="contain", | |
| height="auto", | |
| show_label=True, | |
| elem_id="gallery", | |
| allow_preview=True, | |
| preview=True | |
| ) | |
| with gr.Row():send_t2v_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
| with gr.Row(): | |
| refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| lora_weights = [] | |
| lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
| model = gr.Dropdown( | |
| label="DiT Model", | |
| choices=get_dit_models("hunyuan"), | |
| value="mp_rank_00_model_states.pt", | |
| allow_custom_value=True, | |
| interactive=True | |
| ) | |
| vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
| te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
| te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
| save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
| use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
| attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
| #Image to Video Tab | |
| with gr.Tab(label="Image to Video") as i2v_tab: | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| i2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) | |
| with gr.Column(scale=1): | |
| i2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| i2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| i2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| i2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| i2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| i2v_stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| i2v_input = gr.Image(label="Input Image", type="filepath") | |
| i2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") | |
| # Scale slider as percentage | |
| scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
| original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
| # Width and height inputs | |
| with gr.Row(): | |
| width = gr.Number(label="New Width", value=544, step=16) | |
| calc_height_btn = gr.Button("→") | |
| calc_width_btn = gr.Button("←") | |
| height = gr.Number(label="New Height", value=544, step=16) | |
| i2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) | |
| i2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) | |
| i2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
| i2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) | |
| i2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) | |
| with gr.Column(): | |
| i2v_output = gr.Gallery( | |
| label="Generated Videos (Click to select)", | |
| columns=[2], | |
| rows=[2], | |
| object_fit="contain", | |
| height="auto", | |
| show_label=True, | |
| elem_id="gallery", | |
| allow_preview=True, | |
| preview=True | |
| ) | |
| i2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
| # Add LoRA section for Image2Video | |
| i2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| i2v_lora_weights = [] | |
| i2v_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| i2v_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| i2v_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| i2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| i2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| i2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
| i2v_model = gr.Dropdown( | |
| label="DiT Model", | |
| choices=get_dit_models("hunyuan"), | |
| value="mp_rank_00_model_states.pt", | |
| allow_custom_value=True, | |
| interactive=True | |
| ) | |
| i2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
| i2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
| i2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
| i2v_save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| i2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| i2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| i2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
| i2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
| i2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| i2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
| # Video to Video Tab | |
| with gr.Tab(id=2, label="Video to Video") as v2v_tab: | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| v2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) | |
| v2v_negative_prompt = gr.Textbox( | |
| scale=3, | |
| label="Negative Prompt (for SkyReels models)", | |
| value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", | |
| lines=3 | |
| ) | |
| with gr.Column(scale=1): | |
| v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| v2v_stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| v2v_input = gr.Video(label="Input Video", format="mp4") | |
| v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") | |
| v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
| v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
| # Width and Height Inputs | |
| with gr.Row(): | |
| v2v_width = gr.Number(label="New Width", value=544, step=16) | |
| v2v_calc_height_btn = gr.Button("→") | |
| v2v_calc_width_btn = gr.Button("←") | |
| v2v_height = gr.Number(label="New Height", value=544, step=16) | |
| v2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) | |
| v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) | |
| v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
| v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) | |
| v2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) | |
| with gr.Column(): | |
| v2v_output = gr.Gallery( | |
| label="Generated Videos", | |
| columns=[1], | |
| rows=[1], | |
| object_fit="contain", | |
| height="auto" | |
| ) | |
| v2v_send_to_input_btn = gr.Button("Send Selected to Input") # New button | |
| v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| v2v_lora_weights = [] | |
| v2v_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| v2v_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| v2v_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| v2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
| v2v_model = gr.Dropdown( | |
| label="DiT Model", | |
| choices=get_dit_models("hunyuan"), | |
| value="mp_rank_00_model_states.pt", | |
| allow_custom_value=True, | |
| interactive=True | |
| ) | |
| v2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
| v2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
| v2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
| v2v_save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| v2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
| v2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
| v2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| v2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
| v2v_split_uncond = gr.Checkbox(label="Split Unconditional (for SkyReels)", value=True) | |
| with gr.Tab(label="SkyReels-i2v") as skyreels_tab: | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| skyreels_prompt = gr.Textbox( | |
| scale=3, | |
| label="Enter your prompt", | |
| value="A person walking on a beach at sunset", | |
| lines=5 | |
| ) | |
| skyreels_negative_prompt = gr.Textbox( | |
| scale=3, | |
| label="Negative Prompt", | |
| value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", | |
| lines=3 | |
| ) | |
| with gr.Column(scale=1): | |
| skyreels_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| skyreels_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| skyreels_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| skyreels_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| skyreels_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| skyreels_stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| skyreels_input = gr.Image(label="Input Image (optional)", type="filepath") | |
| skyreels_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") | |
| # Scale slider as percentage | |
| skyreels_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
| skyreels_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
| # Width and height inputs | |
| with gr.Row(): | |
| skyreels_width = gr.Number(label="New Width", value=544, step=16) | |
| skyreels_calc_height_btn = gr.Button("→") | |
| skyreels_calc_width_btn = gr.Button("←") | |
| skyreels_height = gr.Number(label="New Height", value=544, step=16) | |
| skyreels_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) | |
| skyreels_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) | |
| skyreels_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
| skyreels_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) | |
| skyreels_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=6.0) | |
| skyreels_embedded_cfg_scale = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, label="Embedded CFG Scale", value=1.0) | |
| with gr.Column(): | |
| skyreels_output = gr.Gallery( | |
| label="Generated Videos (Click to select)", | |
| columns=[2], | |
| rows=[2], | |
| object_fit="contain", | |
| height="auto", | |
| show_label=True, | |
| elem_id="gallery", | |
| allow_preview=True, | |
| preview=True | |
| ) | |
| skyreels_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
| # Add LoRA section for SKYREELS | |
| skyreels_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| skyreels_lora_weights = [] | |
| skyreels_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| skyreels_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| skyreels_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| skyreels_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| skyreels_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| skyreels_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
| skyreels_model = gr.Dropdown( | |
| label="DiT Model", | |
| choices=get_dit_models("skyreels"), | |
| value="skyreels_hunyuan_i2v_bf16.safetensors", | |
| allow_custom_value=True, | |
| interactive=True | |
| ) | |
| skyreels_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
| skyreels_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
| skyreels_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
| skyreels_save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| skyreels_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| skyreels_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| skyreels_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
| skyreels_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
| skyreels_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| skyreels_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
| skyreels_split_uncond = gr.Checkbox(label="Split Unconditional", value=True) | |
| # WanX Image to Video Tab | |
| with gr.Tab(label="WanX-i2v") as wanx_i2v_tab: | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| wanx_prompt = gr.Textbox( | |
| scale=3, | |
| label="Enter your prompt", | |
| value="A person walking on a beach at sunset", | |
| lines=5 | |
| ) | |
| wanx_negative_prompt = gr.Textbox( | |
| scale=3, | |
| label="Negative Prompt", | |
| value="", | |
| lines=3, | |
| info="Leave empty to use default negative prompt" | |
| ) | |
| with gr.Column(scale=1): | |
| wanx_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| wanx_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| wanx_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| wanx_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| wanx_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| wanx_stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| wanx_input = gr.Image(label="Input Image", type="filepath") | |
| wanx_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
| wanx_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
| # Width and height display | |
| with gr.Row(): | |
| wanx_width = gr.Number(label="Width", value=832, interactive=True) | |
| wanx_calc_height_btn = gr.Button("→") | |
| wanx_calc_width_btn = gr.Button("←") | |
| wanx_height = gr.Number(label="Height", value=480, interactive=True) | |
| wanx_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") | |
| wanx_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) | |
| wanx_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) | |
| wanx_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) | |
| wanx_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=3.0, | |
| info="Recommended: 3.0 for 480p, 5.0 for others") | |
| wanx_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) | |
| with gr.Column(): | |
| wanx_output = gr.Gallery( | |
| label="Generated Videos (Click to select)", | |
| columns=[2], | |
| rows=[2], | |
| object_fit="contain", | |
| height="auto", | |
| show_label=True, | |
| elem_id="gallery", | |
| allow_preview=True, | |
| preview=True | |
| ) | |
| wanx_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
| with gr.Row(): | |
| wanx_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| wanx_lora_weights = [] | |
| wanx_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| wanx_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| wanx_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| wanx_task = gr.Dropdown( | |
| label="Task", | |
| choices=["i2v-14B"], | |
| value="i2v-14B", | |
| info="Currently only i2v-14B is supported" | |
| ) | |
| wanx_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_i2v_480p_14B_bf16.safetensors") | |
| wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") | |
| wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") | |
| wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") | |
| wanx_save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc") | |
| wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) | |
| wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) | |
| wanx_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) | |
| wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| #WanX-t2v Tab | |
| # WanX Text to Video Tab | |
| with gr.Tab(label="WanX-t2v") as wanx_t2v_tab: | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| wanx_t2v_prompt = gr.Textbox( | |
| scale=3, | |
| label="Enter your prompt", | |
| value="A person walking on a beach at sunset", | |
| lines=5 | |
| ) | |
| wanx_t2v_negative_prompt = gr.Textbox( | |
| scale=3, | |
| label="Negative Prompt", | |
| value="", | |
| lines=3, | |
| info="Leave empty to use default negative prompt" | |
| ) | |
| with gr.Column(scale=1): | |
| wanx_t2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| wanx_t2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| wanx_t2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| wanx_t2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| wanx_t2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| wanx_t2v_stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| with gr.Row(): | |
| wanx_t2v_width = gr.Number(label="Width", value=832, interactive=True, info="Should be divisible by 32") | |
| wanx_t2v_height = gr.Number(label="Height", value=480, interactive=True, info="Should be divisible by 32") | |
| wanx_t2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") | |
| wanx_t2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) | |
| wanx_t2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) | |
| wanx_t2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) | |
| wanx_t2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, | |
| info="Recommended: 3.0 for I2V with 480p, 5.0 for others") | |
| wanx_t2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) | |
| with gr.Column(): | |
| wanx_t2v_output = gr.Gallery( | |
| label="Generated Videos (Click to select)", | |
| columns=[2], | |
| rows=[2], | |
| object_fit="contain", | |
| height="auto", | |
| show_label=True, | |
| elem_id="gallery", | |
| allow_preview=True, | |
| preview=True | |
| ) | |
| wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
| with gr.Row(): | |
| wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| wanx_t2v_lora_weights = [] | |
| wanx_t2v_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| wanx_t2v_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| wanx_t2v_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| wanx_t2v_task = gr.Dropdown( | |
| label="Task", | |
| choices=["t2v-1.3B", "t2v-14B", "t2i-14B"], | |
| value="t2v-14B", | |
| info="Select model size: t2v-1.3B is faster, t2v-14B has higher quality" | |
| ) | |
| wanx_t2v_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_t2v_14B_bf16.safetensors") | |
| wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") | |
| wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") | |
| wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") | |
| wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc") | |
| wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, | |
| info="Max 39 for 14B model, 29 for 1.3B model") | |
| wanx_t2v_fp8 = gr.Checkbox(label="Use FP8", value=True) | |
| wanx_t2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) | |
| wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| # Video Info tab: upload a video and view its generation parameters as JSON. | |
| with gr.Tab("Video Info") as video_info_tab: | |
| with gr.Row(): | |
| # Video upload widget and the JSON view labelled "Generation Parameters". | |
| video_input = gr.Video(label="Upload Video", interactive=True) | |
| metadata_output = gr.JSON(label="Generation Parameters") | |
| with gr.Row(): | |
| # Buttons labelled for sending the parameters to the Text2Video / Video2Video tabs. | |
| send_to_t2v_btn = gr.Button("Send to Text2Video", variant="primary") | |
| send_to_v2v_btn = gr.Button("Send to Video2Video", variant="primary") | |
| with gr.Row(): | |
| # Read-only status line. | |
| status = gr.Textbox(label="Status", interactive=False) | |
| # Convert LoRA tab | |
| with gr.Tab("Convert LoRA") as convert_lora_tab: | |
def suggest_output_name(file_obj) -> str:
    """Suggest an output name derived from the selected input file.

    Args:
        file_obj: The selected file, either an object exposing its path
            through a ``name`` attribute or the path itself as a string.
            Any falsy value (``None``, ``""``) means no file is selected.

    Returns:
        The input file's base name without its extension, suffixed with
        ``_MUSUBI``; an empty string when no file is selected.
    """
    if not file_obj:
        return ""
    # Accept plain path strings as well as file objects carrying ``.name``.
    path = getattr(file_obj, "name", file_obj)
    stem, _ext = os.path.splitext(os.path.basename(str(path)))
    return f"{stem}_MUSUBI"
def convert_lora(input_file, output_name: str, target_format: str) -> str:
    """Convert a LoRA file by running ``convert_lora.py`` in a subprocess.

    The converted file is written to ``lora/<output_name>.safetensors``.

    Args:
        input_file: Selected input file; its path is read from ``.name``.
        output_name: Output file name without extension. Surrounding
            whitespace is ignored; a blank name is rejected.
        target_format: Value passed to the script's ``--target`` option.

    Returns:
        A status message describing success or the reason for failure.
        Errors are reported through the message rather than raised.
    """
    try:
        if not input_file:
            return "Error: No input file selected"
        # A blank name would otherwise silently produce "lora/.safetensors".
        output_name = (output_name or "").strip()
        if not output_name:
            return "Error: No output name specified"

        # Ensure output directory exists
        os.makedirs("lora", exist_ok=True)
        output_path = os.path.join("lora", f"{output_name}.safetensors")

        cmd = [
            sys.executable,
            "convert_lora.py",
            "--input", input_file.name,
            "--output", output_path,
            "--target", target_format
        ]
        print(f"Converting {input_file.name} to {output_path}")

        # check=True raises CalledProcessError on a non-zero exit status.
        subprocess.run(cmd, capture_output=True, text=True, check=True)

        if os.path.exists(output_path):
            return f"Successfully converted LoRA to {output_path}"
        return "Error: Output file not created"
    except subprocess.CalledProcessError as e:
        return f"Error during conversion: {e.stderr}"
    except Exception as e:
        return f"Error: {str(e)}"
| # Conversion inputs: source file (only ".safetensors" file types offered), output name, target format. | |
| with gr.Row(): | |
| input_file = gr.File(label="Input LoRA File", file_types=[".safetensors"]) | |
| output_name = gr.Textbox(label="Output Name", placeholder="Output filename (without extension)") | |
| # The selected value is handed to convert_lora, which passes it to the script as --target. | |
| format_radio = gr.Radio( | |
| choices=["default", "other"], | |
| value="default", | |
| label="Target Format", | |
| info="Choose 'default' for H1111/MUSUBI format or 'other' for diffusion pipe format" | |
| ) | |
| # Action button and read-only status line for the conversion result message. | |
| with gr.Row(): | |
| convert_btn = gr.Button("Convert LoRA", variant="primary") | |
| status_output = gr.Textbox(label="Status", interactive=False) | |
# Pre-fill the output-name box with a suggestion whenever the selected file changes.
input_file.change(
    fn=suggest_output_name,
    inputs=[input_file],
    outputs=[output_name],
)

# Run the conversion and show its status message.
convert_btn.click(
    fn=convert_lora,
    inputs=[input_file, output_name, format_radio],
    outputs=status_output,
)
| # Model Merging tab: choose a base DiT model and up to four LoRAs to merge into it. | |
| with gr.Tab("Model Merging") as model_merge_tab: | |
| with gr.Row(): | |
| with gr.Column(): | |
| # Model selection | |
| # Starts with a single default choice; custom values are allowed. | |
| dit_model = gr.Dropdown( | |
| label="Base DiT Model", | |
| choices=["mp_rank_00_model_states.pt"], | |
| value="mp_rank_00_model_states.pt", | |
| allow_custom_value=True, | |
| interactive=True | |
| ) | |
| merge_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| with gr.Row(): | |
| with gr.Column(): | |
| # Output model name | |
| output_model = gr.Textbox(label="Output Model Name", value="merged_model.safetensors") | |
| exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| merge_btn = gr.Button("Merge Models", variant="primary") | |
| merge_status = gr.Textbox(label="Status", interactive=False) | |
| with gr.Row(): | |
| # LoRA selection section (similar to Text2Video) | |
| # Four slots, each a dropdown plus a multiplier slider (0.0 to 2.0, step 0.05, default 1.0). | |
| # NOTE(review): get_lora_options() is called without a folder argument here, so the initial | |
| # choices presumably come from its default folder rather than the "LoRA Folder" textbox - confirm. | |
| merge_lora_weights = [] | |
| merge_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| merge_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| merge_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| # Folder textboxes, with defaults "lora" and "hunyuan". | |
| with gr.Row(): | |
| merge_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
# Text to video
def change_to_tab_one():
    """Return a ``gr.Tabs`` value with ``selected=1`` to navigate to that tab."""
    tabs_update = gr.Tabs(selected=1)
    return tabs_update
# Video to video
def change_to_tab_two():
    """Return a ``gr.Tabs`` value with ``selected=2`` to navigate to that tab."""
    tabs_update = gr.Tabs(selected=2)
    return tabs_update
def change_to_skyreels_tab():
    """Return a ``gr.Tabs`` value with ``selected=3`` to navigate to that tab."""
    tabs_update = gr.Tabs(selected=3)
    return tabs_update
# SKYREELS tab
# State management for dimensions
def sync_skyreels_dimensions(width, height):
    """Return a pair of component updates carrying ``width`` and ``height``."""
    width_update = gr.update(value=width)
    height_update = gr.update(value=height)
    return width_update, height_update
# Refresh the LoRA dropdowns in the SKYREELS tab
def update_skyreels_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]:
    """Build updates for the four LoRA dropdown/multiplier pairs.

    ``current_values`` holds the current dropdown selections in its first four
    positions and the current multipliers in the next four. A selection that is
    not among the options found for ``lora_folder`` is reset to "None";
    multipliers are passed through unchanged. Missing values default to
    "None" and 1.0.

    Returns:
        Eight updates, interleaved as (dropdown, multiplier) per slot.
    """
    available = get_lora_options(lora_folder)
    selected = current_values[:4]
    strengths = current_values[4:8]
    updates = []
    for slot in range(4):
        weight = selected[slot] if slot < len(selected) else "None"
        multiplier = strengths[slot] if slot < len(strengths) else 1.0
        if weight not in available:
            weight = "None"
        updates.append(gr.update(choices=available, value=weight))
        updates.append(gr.update(value=multiplier))
    return updates
# Refresh the models dropdown in the SKYREELS tab
def update_skyreels_model_dropdown(dit_folder: str) -> Dict:
    """Return a dropdown update listing the models found in ``dit_folder``.

    The first listed model becomes the selected value (``None`` if the list
    is empty).
    """
    found = get_dit_models(dit_folder)
    first = found[0] if found else None
    return gr.update(choices=found, value=first)
# Re-list the available models whenever the DiT folder textbox changes.
skyreels_dit_folder.change(
    fn=update_skyreels_model_dropdown,
    inputs=[skyreels_dit_folder],
    outputs=[skyreels_model],
)

# Refresh button: re-scan the LoRA folder and update every dropdown/multiplier
# pair (outputs are interleaved per slot, matching the handler's return order).
skyreels_refresh_btn.click(
    fn=update_skyreels_lora_dropdowns,
    inputs=[skyreels_lora_folder, *skyreels_lora_weights, *skyreels_lora_multipliers],
    outputs=[
        component
        for slot in range(4)
        for component in (skyreels_lora_weights[slot], skyreels_lora_multipliers[slot])
    ],
)
# Skyreels dimension handling
def calculate_skyreels_width(height, original_dims):
    """Compute the width that matches ``height`` at the original aspect ratio.

    ``original_dims`` is a "WIDTHxHEIGHT" string. The result is rounded down
    to a multiple of 16. An empty ``original_dims`` yields a no-op update.
    """
    if not original_dims:
        return gr.update()
    src_w, src_h = [int(part) for part in original_dims.split('x')]
    ratio = src_w / src_h
    snapped_width = math.floor((height * ratio) / 16) * 16
    return gr.update(value=snapped_width)
def calculate_skyreels_height(width, original_dims):
    """Compute the height that matches ``width`` at the original aspect ratio.

    ``original_dims`` is a "WIDTHxHEIGHT" string. The result is rounded down
    to a multiple of 16. An empty ``original_dims`` yields a no-op update.
    """
    if not original_dims:
        return gr.update()
    src_w, src_h = [int(part) for part in original_dims.split('x')]
    ratio = src_w / src_h
    snapped_height = math.floor((width / ratio) / 16) * 16
    return gr.update(value=snapped_height)
def update_skyreels_from_scale(scale, original_dims):
    """Scale the original dimensions by ``scale`` percent.

    ``original_dims`` is a "WIDTHxHEIGHT" string. Each scaled side is rounded
    down to a multiple of 16. An empty ``original_dims`` yields two no-op
    updates.
    """
    if not original_dims:
        return gr.update(), gr.update()
    sides = [int(part) for part in original_dims.split('x')]
    src_w, src_h = sides
    scaled_w, scaled_h = [
        math.floor((side * scale / 100) / 16) * 16 for side in (src_w, src_h)
    ]
    return gr.update(value=scaled_w), gr.update(value=scaled_h)
def update_skyreels_dimensions(image):
    """Read an image's size and report it rounded down to multiples of 16.

    Args:
        image: Path (or file object) accepted by ``Image.open``, or ``None``
            when no image is selected.

    Returns:
        A ``(dims, width, height)`` triple where ``dims`` is a
        "WIDTHxHEIGHT" string. With no image, ``dims`` is empty and both
        dimensions are reset to 544 via component updates.
    """
    if image is None:
        return "", gr.update(value=544), gr.update(value=544)
    # Only the size is needed; the context manager closes the file handle
    # that Image.open would otherwise leave open.
    with Image.open(image) as img:
        w, h = img.size
    w = (w // 16) * 16
    h = (h // 16) * 16
    return f"{w}x{h}", w, h
def handle_skyreels_gallery_select(evt: gr.SelectData) -> int:
    """Return the index carried by the gallery selection event."""
    chosen = evt.index
    return chosen
def send_skyreels_to_v2v(
    gallery: list,
    prompt: str,
    selected_index: int,
    width: int,
    height: int,
    video_length: int,
    fps: int,
    infer_steps: int,
    seed: int,
    flow_shift: float,
    cfg_scale: float,
    lora1: str,
    lora2: str,
    lora3: str,
    lora4: str,
    lora1_multiplier: float,
    lora2_multiplier: float,
    lora3_multiplier: float,
    lora4_multiplier: float,
    negative_prompt: str = ""
) -> Tuple:
    """Package the selected gallery video and its settings for another tab.

    Returns a tuple of ``(video_path, prompt, <all settings>, negative_prompt)``.
    When nothing valid is selected (empty gallery, ``None`` index, or index
    past the end) the video path is ``None`` and the prompt is empty, while
    the settings are still passed through.

    Gallery items may be dicts (path under "name", else "data"), tuples/lists
    (path first), or the path itself.
    """
    passthrough = (
        width, height, video_length, fps, infer_steps, seed,
        flow_shift, cfg_scale, lora1, lora2, lora3, lora4,
        lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier,
        negative_prompt,
    )

    if not gallery or selected_index is None or selected_index >= len(gallery):
        return (None, "") + passthrough

    entry = gallery[selected_index]
    if isinstance(entry, dict):
        path = entry.get("name", entry.get("data", None))
    elif isinstance(entry, (tuple, list)):
        path = entry[0]
    else:
        path = entry

    # The extracted value may itself be a tuple; take its first element.
    if isinstance(path, tuple):
        path = path[0]

    return (str(path), prompt) + passthrough
# Event handlers for the SKYREELS tab
skyreels_prompt.change(
    fn=count_prompt_tokens,
    inputs=skyreels_prompt,
    outputs=skyreels_token_counter,
)
# Stop button sets the shared stop event (not queued).
skyreels_stop_btn.click(fn=lambda: stop_event.set(), queue=False)

# Image input handling: record the original size and reset width/height.
skyreels_input.change(
    fn=update_skyreels_dimensions,
    inputs=[skyreels_input],
    outputs=[skyreels_original_dims, skyreels_width, skyreels_height],
)

# Scale slider: recompute both dimensions from the original size.
skyreels_scale_slider.change(
    fn=update_skyreels_from_scale,
    inputs=[skyreels_scale_slider, skyreels_original_dims],
    outputs=[skyreels_width, skyreels_height],
)

# Aspect-ratio helpers: derive one dimension from the other.
skyreels_calc_width_btn.click(
    fn=calculate_skyreels_width,
    inputs=[skyreels_height, skyreels_original_dims],
    outputs=[skyreels_width],
)
skyreels_calc_height_btn.click(
    fn=calculate_skyreels_height,
    inputs=[skyreels_width, skyreels_original_dims],
    outputs=[skyreels_height],
)
# SKYREELS tab generate button: run process_batch, then pre-select the first
# gallery item when the batch size is 1 (otherwise clear the selection).
skyreels_generate_btn.click(
    fn=process_batch,
    inputs=[
        # Prompt, size and sampling settings
        skyreels_prompt, skyreels_width, skyreels_height, skyreels_batch_size,
        skyreels_video_length, skyreels_fps, skyreels_infer_steps, skyreels_seed,
        # Model files and output location
        skyreels_dit_folder, skyreels_model, skyreels_vae, skyreels_te1, skyreels_te2,
        skyreels_save_path,
        # Generation options
        skyreels_flow_shift, skyreels_embedded_cfg_scale, skyreels_output_type,
        skyreels_attn_mode, skyreels_block_swap, skyreels_exclude_single_blocks,
        skyreels_use_split_attn,
        # LoRA settings
        skyreels_lora_folder,
        *skyreels_lora_weights,
        *skyreels_lora_multipliers,
        # Remaining inputs, in the order the handler call passes them
        skyreels_input, skyreels_strength, skyreels_negative_prompt,
        skyreels_guidance_scale, skyreels_split_uncond, skyreels_use_fp8,
    ],
    outputs=[skyreels_output, skyreels_batch_progress, skyreels_progress_text],
    queue=True,
).then(
    fn=lambda batch_size: 0 if batch_size == 1 else None,
    inputs=[skyreels_batch_size],
    outputs=skyreels_selected_index,
)
# Gallery selection handling: remember which item was clicked.
skyreels_output.select(
    fn=handle_skyreels_gallery_select,
    outputs=skyreels_selected_index,
)

# Send to Video2Video: copy the selected video and settings into the
# Video2Video controls, then switch tabs.
skyreels_send_to_v2v_btn.click(
    fn=send_skyreels_to_v2v,
    inputs=[
        skyreels_output, skyreels_prompt, skyreels_selected_index,
        skyreels_width, skyreels_height, skyreels_video_length,
        skyreels_fps, skyreels_infer_steps, skyreels_seed,
        skyreels_flow_shift, skyreels_guidance_scale,
        *skyreels_lora_weights,
        *skyreels_lora_multipliers,
        skyreels_negative_prompt,
    ],
    outputs=[
        v2v_input, v2v_prompt, v2v_width, v2v_height,
        v2v_video_length, v2v_fps, v2v_infer_steps,
        v2v_seed, v2v_flow_shift, v2v_cfg_scale,
        *v2v_lora_weights,
        *v2v_lora_multipliers,
        v2v_negative_prompt,
    ],
).then(
    fn=change_to_tab_two,
    inputs=None,
    outputs=[tabs],
)
# Refresh button handler: refresh the model dropdown together with every LoRA
# dropdown/multiplier pair.
# NOTE(review): skyreels_refresh_btn already has a click handler registered
# earlier in this file (update_skyreels_lora_dropdowns); both run on each click.
skyreels_refresh_outputs = [skyreels_model] + [
    component
    for slot in range(4)
    for component in (skyreels_lora_weights[slot], skyreels_lora_multipliers[slot])
]
skyreels_refresh_btn.click(
    fn=update_dit_and_lora_dropdowns,
    inputs=[
        skyreels_dit_folder, skyreels_lora_folder, skyreels_model,
        *skyreels_lora_weights,
        *skyreels_lora_multipliers,
    ],
    outputs=skyreels_refresh_outputs,
)

# NOTE(review): this state is bound after the handlers above already reference
# skyreels_selected_index. If it is also declared earlier, this rebinding
# creates a second, unconnected State; if not, those handlers fail with a
# NameError. Move it next to the other state declarations.
skyreels_selected_index = gr.State(value=None)
def calculate_v2v_width(height, original_dims):
    """Return an update with the width matching ``height`` at the source aspect ratio.

    ``original_dims`` is a "WIDTHxHEIGHT" string; an empty value produces a
    no-op update. The width is floored to a multiple of 16.
    """
    if original_dims:
        source_w, source_h = (int(token) for token in original_dims.split('x'))
        aspect = source_w / source_h
        # Ensure divisible by 16
        return gr.update(value=math.floor((height * aspect) / 16) * 16)
    return gr.update()
def calculate_v2v_height(width, original_dims):
    """Return an update with the height matching ``width`` at the source aspect ratio.

    ``original_dims`` is a "WIDTHxHEIGHT" string; an empty value produces a
    no-op update. The height is floored to a multiple of 16.
    """
    if original_dims:
        source_w, source_h = (int(token) for token in original_dims.split('x'))
        aspect = source_w / source_h
        # Ensure divisible by 16
        return gr.update(value=math.floor((width / aspect) / 16) * 16)
    return gr.update()
def update_v2v_from_scale(scale, original_dims):
    """Return width/height updates for the source size scaled by ``scale`` percent.

    ``original_dims`` is a "WIDTHxHEIGHT" string; an empty value produces two
    no-op updates. Both sides are floored to a multiple of 16.
    """
    if original_dims:
        source_w, source_h = (int(token) for token in original_dims.split('x'))
        # Ensure divisible by 16
        scaled_w = math.floor((source_w * scale / 100) / 16) * 16
        scaled_h = math.floor((source_h * scale / 100) / 16) * 16
        return gr.update(value=scaled_w), gr.update(value=scaled_h)
    return gr.update(), gr.update()
def update_v2v_dimensions(video):
    """Read a video's frame size and report it rounded down to multiples of 16.

    Returns a ``(dims, width, height)`` triple where ``dims`` is a
    "WIDTHxHEIGHT" string. With no video, ``dims`` is empty and both
    dimensions are reset to 544 via component updates.
    """
    if video is None:
        return "", gr.update(value=544), gr.update(value=544)

    capture = cv2.VideoCapture(video)
    frame_w = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_h = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    capture.release()

    # Make dimensions divisible by 16
    frame_w -= frame_w % 16
    frame_h -= frame_h % 16
    return f"{frame_w}x{frame_h}", frame_w, frame_h
# Event handlers for the Video to Video tab

# New input video: record its size and reset width/height.
v2v_input.change(
    fn=update_v2v_dimensions,
    inputs=[v2v_input],
    outputs=[v2v_original_dims, v2v_width, v2v_height],
)

# Scale slider: recompute both dimensions from the original size.
v2v_scale_slider.change(
    fn=update_v2v_from_scale,
    inputs=[v2v_scale_slider, v2v_original_dims],
    outputs=[v2v_width, v2v_height],
)

# Aspect-ratio helpers: derive one dimension from the other.
v2v_calc_width_btn.click(
    fn=calculate_v2v_width,
    inputs=[v2v_height, v2v_original_dims],
    outputs=[v2v_width],
)
v2v_calc_height_btn.click(
    fn=calculate_v2v_height,
    inputs=[v2v_width, v2v_original_dims],
    outputs=[v2v_height],
)
# Image to video dimension logic
def calculate_width(height, original_dims):
    """Derive the width for ``height`` that keeps the original aspect ratio.

    ``original_dims`` is a "WIDTHxHEIGHT" string. Returns a no-op update when
    it is empty; otherwise an update whose value is floored to a multiple of 16.
    """
    if not original_dims:
        return gr.update()
    parts = original_dims.split('x')
    base_w, base_h = map(int, parts)
    proportion = base_w / base_h
    unrounded = (height * proportion) / 16
    return gr.update(value=math.floor(unrounded) * 16)
def calculate_height(width, original_dims):
    """Derive the height for ``width`` that keeps the original aspect ratio.

    ``original_dims`` is a "WIDTHxHEIGHT" string. Returns a no-op update when
    it is empty; otherwise an update whose value is floored to a multiple of 16.
    """
    if not original_dims:
        return gr.update()
    parts = original_dims.split('x')
    base_w, base_h = map(int, parts)
    proportion = base_w / base_h
    unrounded = (width / proportion) / 16
    return gr.update(value=math.floor(unrounded) * 16)
def update_from_scale(scale, original_dims):
    """Scale the original size by ``scale`` percent, flooring to multiples of 16.

    ``original_dims`` is a "WIDTHxHEIGHT" string. Returns two no-op updates
    when it is empty; otherwise a (width update, height update) pair.
    """
    if not original_dims:
        return gr.update(), gr.update()
    parts = original_dims.split('x')
    base_w, base_h = map(int, parts)
    width_units = (base_w * scale / 100) / 16
    height_units = (base_h * scale / 100) / 16
    return (
        gr.update(value=math.floor(width_units) * 16),
        gr.update(value=math.floor(height_units) * 16),
    )
def update_dimensions(image):
    """Read an image's size and report it rounded down to multiples of 16.

    Args:
        image: Path (or file object) accepted by ``Image.open``, or ``None``
            when no image is selected.

    Returns:
        A ``(dims, width, height)`` triple where ``dims`` is a
        "WIDTHxHEIGHT" string. With no image, ``dims`` is empty and both
        dimensions are reset to 544 via component updates.
    """
    if image is None:
        return "", gr.update(value=544), gr.update(value=544)
    # Only the size is needed; the context manager closes the file handle
    # that Image.open would otherwise leave open.
    with Image.open(image) as img:
        w, h = img.size
    # Make dimensions divisible by 16
    w = (w // 16) * 16
    h = (h // 16) * 16
    return f"{w}x{h}", w, h
# New input image: record its size and reset width/height.
i2v_input.change(
    fn=update_dimensions,
    inputs=[i2v_input],
    outputs=[original_dims, width, height],
)

# Scale slider: recompute both dimensions from the original size.
scale_slider.change(
    fn=update_from_scale,
    inputs=[scale_slider, original_dims],
    outputs=[width, height],
)

# Aspect-ratio helpers: derive one dimension from the other.
calc_width_btn.click(
    fn=calculate_width,
    inputs=[height, original_dims],
    outputs=[width],
)
calc_height_btn.click(
    fn=calculate_height,
    inputs=[width, original_dims],
    outputs=[height],
)
# Function to get available DiT models
def get_dit_models(dit_folder: str) -> List[str]:
    """List the model files in ``dit_folder``.

    Args:
        dit_folder: Directory to scan.

    Returns:
        Names ending in ``.pt`` or ``.safetensors``, sorted
        case-insensitively. Falls back to the single default name
        ``mp_rank_00_model_states.pt`` when the folder is missing, is not a
        directory, or contains no matching files.
    """
    fallback = ["mp_rank_00_model_states.pt"]
    # isdir rather than exists: os.listdir raises when given a regular file.
    if not os.path.isdir(dit_folder):
        return fallback
    models = sorted(
        (f for f in os.listdir(dit_folder) if f.endswith(('.pt', '.safetensors'))),
        key=str.lower,
    )
    return models or fallback
# Function to perform model merging
def merge_models(
    dit_folder: str,
    dit_model: str,
    output_model: str,
    exclude_single_blocks: bool,
    merge_lora_folder: str,
    *lora_params  # Will contain both weights and multipliers
) -> str:
    """Merge the selected LoRAs into a DiT model by running ``merge_lora.py``.

    ``lora_params`` holds the LoRA file names in its first half and their
    multipliers in its second half. Empty and "None" selections are skipped.
    The merged model is written to ``dit_folder/output_model``.

    Returns:
        A status message; failures are reported in the message, not raised.
    """
    try:
        # First half: weights; second half: multipliers.
        half = len(lora_params) // 2
        selected = [
            (os.path.join(merge_lora_folder, name), strength)
            for name, strength in zip(lora_params[:half], lora_params[half:])
            if name and name != "None"
        ]
        if not selected:
            return "No LoRA models selected for merging"

        # The merged model is saved inside the DiT folder.
        os.makedirs(dit_folder, exist_ok=True)
        output_path = os.path.join(dit_folder, output_model)

        cmd = [
            sys.executable,
            "merge_lora.py",
            "--dit", os.path.join(dit_folder, dit_model),
            "--save_merged_model", output_path,
        ]
        cmd.append("--lora_weight")
        cmd.extend(path for path, _ in selected)
        cmd.append("--lora_multiplier")
        cmd.extend(str(strength) for _, strength in selected)
        if exclude_single_blocks:
            cmd.append("--exclude_single_blocks")

        # check=True turns a non-zero exit status into CalledProcessError.
        subprocess.run(cmd, capture_output=True, text=True, check=True)

        if not os.path.exists(output_path):
            return "Error: Output file not created"
        return f"Successfully merged model and saved to {output_path}"
    except subprocess.CalledProcessError as e:
        return f"Error during merging: {e.stderr}"
    except Exception as e:
        return f"Error: {str(e)}"
# Update DiT model dropdown
def update_dit_dropdown(dit_folder: str) -> Dict:
    """Return a dropdown update listing the models found in ``dit_folder``.

    The first listed model becomes the selected value (``None`` if the list
    is empty).
    """
    available = get_dit_models(dit_folder)
    preselected = available[0] if available else None
    return gr.update(choices=available, value=preselected)
# Connect events
merge_btn.click(
    fn=merge_models,
    inputs=[
        dit_folder, dit_model, output_model, exclude_single_blocks, merge_lora_folder,
        *merge_lora_weights,
        *merge_lora_multipliers,
    ],
    outputs=merge_status,
)

# The refresh button has two click handlers. This one re-lists the DiT models.
merge_refresh_btn.click(
    fn=lambda f: update_dit_dropdown(f),
    inputs=[dit_folder],
    outputs=[dit_model],
)

# This one refreshes each LoRA dropdown/multiplier pair (outputs interleaved per slot).
merge_refresh_outputs = [
    component
    for slot in range(4)
    for component in (merge_lora_weights[slot], merge_lora_multipliers[slot])
]
merge_refresh_btn.click(
    fn=update_lora_dropdowns,
    inputs=[merge_lora_folder, *merge_lora_weights, *merge_lora_multipliers],
    outputs=merge_refresh_outputs,
)
# Event handlers: keep the token counters in sync with their prompts.
prompt.change(
    fn=count_prompt_tokens,
    inputs=prompt,
    outputs=token_counter,
)
v2v_prompt.change(
    fn=count_prompt_tokens,
    inputs=v2v_prompt,
    outputs=v2v_token_counter,
)

# Both stop buttons set the shared stop event (not queued).
stop_btn.click(fn=lambda: stop_event.set(), queue=False)
v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False)
# Image to video
def image_to_video(image_path, output_path, width, height, frames=240):
    """Render a still image as a 24 fps H.264 video using ffmpeg.

    The image is resized to ``width`` x ``height``, saved as a temporary PNG
    next to ``output_path``, and looped for ``frames / 24`` seconds.

    Args:
        image_path: Source image, opened with ``Image.open``.
        output_path: Destination video file; overwritten if it already exists.
        width: Target width in pixels (converted to ``int``).
        height: Target height in pixels (converted to ``int``).
        frames: Number of frames the video should contain.

    Returns:
        ``True`` if ffmpeg succeeded, ``False`` if it exited with an error.
        Errors while reading or resizing the image propagate to the caller.
    """
    temp_image_path = os.path.join(os.path.dirname(output_path), "temp_resized_image.png")
    frame_rate = 24
    duration = frames / frame_rate
    command = [
        # -y: overwrite an existing output instead of prompting or refusing.
        "ffmpeg", "-y", "-loop", "1", "-i", temp_image_path, "-c:v", "libx264",
        "-t", str(duration), "-pix_fmt", "yuv420p",
        "-vf", f"fps={frame_rate}", output_path
    ]
    try:
        # The context manager closes the source image even if resize/save fails.
        with Image.open(image_path) as img:
            # UI number fields may deliver floats; resize needs integer sizes.
            img.resize((int(width), int(height)), Image.LANCZOS).save(temp_image_path)
        subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        print(f"Video saved to {output_path}")
        return True
    except subprocess.CalledProcessError as e:
        print(f"An error occurred while creating the video: {e}")
        return False
    finally:
        # Clean up the temporary image file on every exit path.
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
def generate_from_image(
    image_path,
    prompt, width, height, video_length, fps, infer_steps,
    seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale,
    output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn,
    lora_folder, strength, batch_size, *lora_params
):
    """Generate video from an input image, yielding progressive updates.

    The image is first rendered to a temporary video in ``save_path`` (via
    ``image_to_video``), whose dimensions are read back with ffmpeg's probe and
    passed to ``process_single_video`` together with the temporary video path
    and ``strength``. Every yielded item is a
    ``(videos, batch_text, progress_text)`` triple; failures are reported as a
    triple with an empty list. The temporary video is deleted when the
    generator finishes.
    """
    stop_event.clear()

    # Temporary video lives in the save folder, named after the image.
    image_name = os.path.basename(image_path)
    temp_video_path = os.path.join(save_path, f"temp_{image_name}.mp4")

    try:
        # Step 1: convert the image to a video.
        created = image_to_video(image_path, temp_video_path, width, height, frames=video_length)
        if not created:
            yield [], "Failed to create temporary video", "Error in video creation"
            return

        # Give the file a moment to be fully written before checking it.
        time.sleep(1)
        if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0:
            yield [], "Failed to create temporary video", "Temporary video file is empty or missing"
            return

        # Step 2: read the actual dimensions back from the video.
        try:
            probe = ffmpeg.probe(temp_video_path)
            video_stream = None
            for stream in probe['streams']:
                if stream['codec_type'] == 'video':
                    video_stream = stream
                    break
            if video_stream is None:
                raise ValueError("No video stream found")
            width = int(video_stream['width'])
            height = int(video_stream['height'])
        except Exception as e:
            yield [], f"Error reading video dimensions: {str(e)}", "Video processing error"
            return

        # Step 3: run generation on the temporary video, forwarding every update.
        try:
            updates = process_single_video(
                prompt, width, height, batch_size, video_length, fps, infer_steps,
                seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale,
                output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn,
                lora_folder, *lora_params, video_path=temp_video_path, strength=strength
            )
            for clips, batch_status, progress_status in updates:
                yield clips, batch_status, progress_status
        except Exception as e:
            yield [], f"Error in video generation: {str(e)}", "Generation error"
            return

    except Exception as e:
        yield [], f"Unexpected error: {str(e)}", "Error occurred"
        return

    finally:
        # Best-effort removal of the temporary file; cleanup errors are ignored.
        try:
            if os.path.exists(temp_video_path):
                os.remove(temp_video_path)
        except Exception:
            pass
# Image2Video: token counter follows the prompt; stop button sets the shared stop event.
i2v_prompt.change(
    fn=count_prompt_tokens,
    inputs=i2v_prompt,
    outputs=i2v_token_counter,
)
i2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False)
def handle_i2v_gallery_select(evt: gr.SelectData) -> int:
    """Return the index of the I2V gallery item that was clicked."""
    clicked = evt.index
    return clicked
def send_i2v_to_v2v(
    gallery: list,
    prompt: str,
    selected_index: int,
    width: int,
    height: int,
    video_length: int,
    fps: int,
    infer_steps: int,
    seed: int,
    flow_shift: float,
    cfg_scale: float,
    lora1: str,
    lora2: str,
    lora3: str,
    lora4: str,
    lora1_multiplier: float,
    lora2_multiplier: float,
    lora3_multiplier: float,
    lora4_multiplier: float
) -> Tuple[Optional[str], str, int, int, int, int, int, int, float, float, str, str, str, str, float, float, float, float]:
    """Send the selected video and parameters from Image2Video to Video2Video.

    Returns ``(video_path, prompt, <all settings>)``. When nothing valid is
    selected (empty gallery, ``None`` index, or index past the end) the video
    path is ``None`` and the prompt is empty, while the settings are still
    passed through unchanged (width and height are not altered).

    Gallery items may be dicts (path under "name", else "data"), tuples/lists
    (path first), or the path itself.
    """
    settings = (
        width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale,
        lora1, lora2, lora3, lora4,
        lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier,
    )

    if not gallery or selected_index is None or selected_index >= len(gallery):
        return (None, "") + settings

    # Handle the different gallery item formats.
    entry = gallery[selected_index]
    if isinstance(entry, dict):
        path = entry.get("name", entry.get("data", None))
    elif isinstance(entry, (tuple, list)):
        path = entry[0]
    else:
        path = entry

    # The extracted value may itself be a tuple; take its first element.
    if isinstance(path, tuple):
        path = path[0]

    return (str(path), prompt) + settings
# Generate button handler: run process_batch, then pre-select the first gallery
# item when the batch size is 1 (otherwise clear the selection).
i2v_generate_btn.click(
    fn=process_batch,
    inputs=[
        # Prompt, size and sampling settings
        i2v_prompt, width, height, i2v_batch_size, i2v_video_length,
        i2v_fps, i2v_infer_steps, i2v_seed,
        # Model files and output location
        i2v_dit_folder, i2v_model, i2v_vae, i2v_te1, i2v_te2, i2v_save_path,
        # Generation options
        i2v_flow_shift, i2v_cfg_scale, i2v_output_type, i2v_attn_mode,
        i2v_block_swap, i2v_exclude_single_blocks, i2v_use_split_attn,
        # LoRA settings
        i2v_lora_folder,
        *i2v_lora_weights,
        *i2v_lora_multipliers,
        # Remaining inputs, in the order the handler call passes them
        i2v_input, i2v_strength, i2v_use_fp8,
    ],
    outputs=[i2v_output, i2v_batch_progress, i2v_progress_text],
    queue=True,
).then(
    fn=lambda batch_size: 0 if batch_size == 1 else None,
    inputs=[i2v_batch_size],
    outputs=i2v_selected_index,
)
# Send to Video2Video
# Remember which gallery item was clicked.
i2v_output.select(
    fn=handle_i2v_gallery_select,
    outputs=i2v_selected_index,
)

# Copy the selected video and settings into the Video2Video controls, then switch tabs.
i2v_send_to_v2v_btn.click(
    fn=send_i2v_to_v2v,
    inputs=[
        i2v_output, i2v_prompt, i2v_selected_index,
        width, height,
        i2v_video_length, i2v_fps, i2v_infer_steps,
        i2v_seed, i2v_flow_shift, i2v_cfg_scale,
        *i2v_lora_weights,
        *i2v_lora_multipliers,
    ],
    outputs=[
        v2v_input, v2v_prompt,
        v2v_width, v2v_height,
        v2v_video_length, v2v_fps, v2v_infer_steps,
        v2v_seed, v2v_flow_shift, v2v_cfg_scale,
        *v2v_lora_weights,
        *v2v_lora_multipliers,
    ],
).then(
    fn=change_to_tab_two,
    inputs=None,
    outputs=[tabs],
)
# Video Info
def clean_video_path(video_path) -> str:
    """Extract a plain path from the various shapes a video value can take.

    Dicts yield their "name" entry (empty string if absent), tuples and lists
    yield their first element, strings are returned as-is, and anything else
    yields an empty string. The input and the result are printed for
    debugging.
    """
    print(f"Input video_path: {video_path}, type: {type(video_path)}")

    path = ""
    if isinstance(video_path, dict):
        path = video_path.get("name", "")
    elif isinstance(video_path, (tuple, list)):
        path = video_path[0]
    elif isinstance(video_path, str):
        path = video_path

    print(f"Cleaned path: {path}")
    return path
def handle_video_upload(video_path: str) -> Tuple[Dict, str]:
    """Handle a video upload and extract its embedded metadata.

    Args:
        video_path: Path of the uploaded video; falsy when nothing was uploaded.

    Returns:
        A ``(metadata, status_message)`` pair. ``metadata`` is an empty dict
        when no video was given or no metadata was found.
    """
    if not video_path:
        return {}, "No video uploaded"
    metadata = extract_video_metadata(video_path)
    if not metadata:
        return {}, "No metadata found in video"
    return metadata, "Metadata extracted successfully"
def get_video_info(video_path: str) -> dict:
    """Probe a video file for its basic stream properties.

    Args:
        video_path: Path of the video to probe.

    Returns:
        A dict with ``width``, ``height``, ``fps``, ``total_frames`` (capped
        at 201) and ``duration`` (shortened to match when the cap applies).
        An empty dict is returned, and the error printed, if probing or
        parsing fails for any reason.
    """
    from fractions import Fraction

    try:
        probe = ffmpeg.probe(video_path)
        video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video')
        width = int(video_info['width'])
        height = int(video_info['height'])
        # r_frame_rate is a rational string such as "30/1". It comes from the
        # uploaded file's metadata, so parse it rather than eval() it.
        fps = float(Fraction(video_info['r_frame_rate']))

        # Calculate total frames
        duration = float(probe['format']['duration'])
        total_frames = int(duration * fps)

        # Ensure video length does not exceed 201 frames
        if total_frames > 201:
            total_frames = 201
            duration = total_frames / fps  # Adjust duration accordingly

        return {
            'width': width,
            'height': height,
            'fps': fps,
            'total_frames': total_frames,
            'duration': duration  # Might be useful in some contexts
        }
    except Exception as e:
        print(f"Error extracting video info: {e}")
        return {}
def extract_video_details(video_path: str) -> Tuple[dict, str]:
    """Combine a video's embedded metadata with its probed stream details.

    Probed values only fill in keys the embedded metadata does not already
    have. ``video_length`` is always set and capped at 201: the metadata's own
    value if present, otherwise the probed ``total_frames`` (0 if unavailable).

    Returns:
        A ``(metadata, status_message)`` pair.
    """
    metadata = extract_video_metadata(video_path)
    video_details = get_video_info(video_path)

    # Embedded metadata takes precedence over probed details.
    for key, value in video_details.items():
        metadata.setdefault(key, value)

    # Ensure video length does not exceed 201 frames.
    if 'video_length' in metadata:
        length = metadata['video_length']
    else:
        length = video_details.get('total_frames', 0)
    metadata['video_length'] = min(length, 201)

    return metadata, "Video details extracted successfully"
def send_parameters_to_tab(metadata: Dict, target_tab: str) -> Tuple[str, Dict]:
    """Create the parameter mapping for the target tab.

    Returns a ``(status_message, mapping)`` pair. The mapping is empty when
    there is no metadata or when building it raises; in the latter case the
    message carries the error text.
    """
    if not metadata:
        return "No parameters to send", {}

    # "t2v" means Text2Video; any other value is reported as Video2Video.
    tab_name = "Video2Video" if target_tab != "t2v" else "Text2Video"
    try:
        mapping = create_parameter_transfer_map(metadata, target_tab)
    except Exception as e:
        return f"Error: {str(e)}", {}
    return f"Parameters ready for {tab_name}", mapping
# On upload, show the combined metadata and a status message.
video_input.upload(
    fn=extract_video_details,
    inputs=video_input,
    outputs=[metadata_output, status],
)
# Send the extracted parameters to the Text2Video tab: store the mapping in
# params_state, switch tabs, then copy each value into the matching control.
send_to_t2v_btn.click(
    fn=lambda m: send_parameters_to_tab(m, "t2v"),
    inputs=metadata_output,
    outputs=[status, params_state]
).then(
    fn=change_to_tab_one, inputs=None, outputs=[tabs]
).then(
    # One value per output below (26 in total); with no parameters every
    # control is left untouched.
    lambda params: [
        params.get("prompt", ""),
        params.get("width", 544),
        params.get("height", 544),
        params.get("batch_size", 1),
        params.get("video_length", 25),
        params.get("fps", 24),
        params.get("infer_steps", 30),
        params.get("seed", -1),
        params.get("model", "hunyuan/mp_rank_00_model_states.pt"),
        params.get("vae", "hunyuan/pytorch_model.pt"),
        params.get("te1", "hunyuan/llava_llama3_fp16.safetensors"),
        params.get("te2", "hunyuan/clip_l.safetensors"),
        params.get("save_path", "outputs"),
        params.get("flow_shift", 11.0),
        params.get("cfg_scale", 7.0),
        params.get("output_type", "video"),
        params.get("attn_mode", "sdpa"),
        params.get("block_swap", "0"),
        *[params.get(f"lora{i+1}", "") for i in range(4)],
        *[params.get(f"lora{i+1}_multiplier", 1.0) for i in range(4)]
    ] if params else [gr.update()]*26,
    inputs=params_state,
    # The Text2Video size controls are t2v_width/t2v_height (the ones the
    # Text2Video generate handler reads); plain width/height are the
    # Image2Video size controls and must not be overwritten here.
    outputs=[prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, seed,
             model, vae, te1, te2, save_path, flow_shift, cfg_scale,
             output_type, attn_mode, block_swap] + lora_weights + lora_multipliers
)
# Text to Video generation
# The two hidden components are created inline and fill the positions where the
# Video2Video call passes v2v_input and v2v_strength.
# NOTE(review): the Video2Video call passes negative prompt, cfg scale and
# split_uncond between the strength slot and use_fp8, while here use_fp8
# follows the strength slot directly — confirm process_batch maps it correctly.
t2v_generate_inputs = [
    prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps,
    seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale,
    output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn,
    lora_folder,
    *lora_weights,
    *lora_multipliers,
    gr.Textbox(visible=False),
    gr.Number(visible=False),
    use_fp8,
]
generate_btn.click(
    fn=process_batch,
    inputs=t2v_generate_inputs,
    outputs=[video_output, batch_progress, progress_text],
    queue=True,
).then(
    # A single-video batch pre-selects item 0; otherwise the selection is cleared.
    fn=lambda count: 0 if count == 1 else None,
    inputs=[batch_size],
    outputs=selected_index,
)
# Update gallery selection handling
def handle_gallery_select(evt: gr.SelectData) -> int:
    """Return the index carried by the gallery selection event."""
    chosen = evt.index
    return chosen


# Track selected index when gallery item is clicked
video_output.select(fn=handle_gallery_select, outputs=selected_index)
# Track selected index when Video2Video gallery item is clicked
def handle_v2v_gallery_select(evt: gr.SelectData) -> int:
    """Return the selected index only; the Video2Video input is not touched."""
    chosen = evt.index
    return chosen


# Update the gallery selection event
v2v_output.select(fn=handle_v2v_gallery_select, outputs=v2v_selected_index)
| # Send button handler with gallery selection | |
def handle_send_button(
    gallery: list,
    prompt: str,
    idx: int,
    width: int,
    height: int,
    batch_size: int,
    video_length: int,
    fps: int,
    infer_steps: int,
    seed: int,
    flow_shift: float,
    cfg_scale: float,
    lora1: str,
    lora2: str,
    lora3: str,
    lora4: str,
    lora1_multiplier: float,
    lora2_multiplier: float,
    lora3_multiplier: float,
    lora4_multiplier: float
) -> tuple:
    """Pick the selected gallery video and forward it with the current settings.

    Args:
        gallery: Current gallery value. Items may be dicts (path under "name"
            or "data"), ``(path, caption)`` tuples/lists, or plain paths.
        prompt: Prompt to forward together with the video.
        idx: Index of the selected gallery item, or None when nothing is
            selected.
        width, height, batch_size, video_length, fps, infer_steps, seed,
        flow_shift, cfg_scale, lora1..lora4, lora1_multiplier..lora4_multiplier:
            Settings passed through unchanged.

    Returns:
        A 20-tuple: video path, prompt, the 17 pass-through settings in
        parameter order, and an empty negative prompt. When no video can be
        resolved the first two entries are ``None`` and ``""``.
    """
    passthrough = (
        width, height, batch_size, video_length, fps, infer_steps,
        seed, flow_shift, cfg_scale,
        lora1, lora2, lora3, lora4,
        lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier,
    )
    # Trailing empty string is the negative prompt.
    nothing_to_send = (None, "", *passthrough, "")

    # Auto-select the first item if only one exists and no selection was made.
    # This must run before the validity check below, which rejects idx=None.
    if idx is None and gallery and len(gallery) == 1:
        idx = 0

    if not gallery or idx is None or idx >= len(gallery):
        return nothing_to_send

    selected_item = gallery[idx]

    # Handle different gallery item formats
    if isinstance(selected_item, dict):
        video_path = selected_item.get("name", selected_item.get("data", None))
    elif isinstance(selected_item, (tuple, list)):
        video_path = selected_item[0] if selected_item else None
    else:
        video_path = selected_item

    # Final cleanup for Gradio Video component
    if isinstance(video_path, tuple):
        video_path = video_path[0] if video_path else None

    # An item without a usable path must not be forwarded as the string "None".
    if not video_path:
        return nothing_to_send

    return (str(video_path), prompt, *passthrough, "")
# Text2Video -> Video2Video: copy the selected video, prompt, settings and LoRA
# choices, then switch tabs. The output order mirrors handle_send_button's
# return tuple, which ends with the negative prompt.
send_t2v_to_v2v_btn.click(
    fn=handle_send_button,
    inputs=[
        video_output, prompt, selected_index,
        t2v_width, t2v_height, batch_size, video_length,
        fps, infer_steps, seed, flow_shift, cfg_scale,
        *lora_weights,
        *lora_multipliers,
    ],
    outputs=[
        v2v_input, v2v_prompt, v2v_width, v2v_height, v2v_batch_size,
        v2v_video_length, v2v_fps, v2v_infer_steps, v2v_seed,
        v2v_flow_shift, v2v_cfg_scale,
        *v2v_lora_weights,
        *v2v_lora_multipliers,
        v2v_negative_prompt,
    ],
).then(fn=change_to_tab_two, inputs=None, outputs=[tabs])
def handle_send_to_v2v(metadata: dict, video_path: str) -> Tuple[str, dict, str]:
    """Build the Video2Video parameter mapping and pass the video path through.

    Returns:
        ``(status message, mapping, video_path)``; the first two come from
        ``send_parameters_to_tab(metadata, "v2v")``.
    """
    message, mapping = send_parameters_to_tab(metadata, "v2v")
    return message, mapping, video_path
def handle_info_to_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]:
    """Transfer parameters and the video from Video Info to the V2V tab.

    Returns:
        ``(status message, mapping, video_path)``. When no video path is
        given, the status is "No video selected", the mapping is empty and
        the path slot is ``None``.
    """
    if video_path:
        message, mapping = send_parameters_to_tab(metadata, "v2v")
        # The path is forwarded as-is.
        return message, mapping, video_path
    return "No video selected", {}, None
# Send button click handler
# (key, fallback) pairs for the scalar Video2Video settings, listed in the same
# order as the first 18 output components of the fill step below.
v2v_param_defaults = (
    ("v2v_prompt", ""),
    ("v2v_width", 544),
    ("v2v_height", 544),
    ("v2v_batch_size", 1),
    ("v2v_video_length", 25),
    ("v2v_fps", 24),
    ("v2v_infer_steps", 30),
    ("v2v_seed", -1),
    ("v2v_model", "hunyuan/mp_rank_00_model_states.pt"),
    ("v2v_vae", "hunyuan/pytorch_model.pt"),
    ("v2v_te1", "hunyuan/llava_llama3_fp16.safetensors"),
    ("v2v_te2", "hunyuan/clip_l.safetensors"),
    ("v2v_save_path", "outputs"),
    ("v2v_flow_shift", 11.0),
    ("v2v_cfg_scale", 7.0),
    ("v2v_output_type", "video"),
    ("v2v_attn_mode", "sdpa"),
    ("v2v_block_swap", "0"),
)

# Video Info -> Video2Video: store the mapping and video, fill the tab's
# components (26 no-op updates when the mapping is empty), then switch tabs.
send_to_v2v_btn.click(
    fn=handle_info_to_v2v,
    inputs=[metadata_output, video_input],
    outputs=[status, params_state, v2v_input],
).then(
    lambda params: [
        *(params.get(key, fallback) for key, fallback in v2v_param_defaults),
        *(params.get(f"v2v_lora_weights[{i}]", "") for i in range(4)),
        *(params.get(f"v2v_lora_multipliers[{i}]", 1.0) for i in range(4)),
    ] if params else [gr.update()] * 26,
    inputs=params_state,
    outputs=[
        v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length,
        v2v_fps, v2v_infer_steps, v2v_seed, v2v_model, v2v_vae, v2v_te1,
        v2v_te2, v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type,
        v2v_attn_mode, v2v_block_swap,
        *v2v_lora_weights,
        *v2v_lora_multipliers,
    ],
).then(
    # NOTE(review): debug-only step that prints the tabs object on every click;
    # candidate for removal.
    lambda: print(f"Tabs object: {tabs}"),  # Debug print
    outputs=None,
).then(fn=change_to_tab_two, inputs=None, outputs=[tabs])
| # Handler for sending selected video from Video2Video gallery to input | |
def handle_v2v_send_button(gallery: list, prompt: str, idx: int) -> Tuple[Optional[str], str]:
    """Send the currently selected video in the V2V gallery to the V2V input.

    Args:
        gallery: Current gallery value. Items may be ``(path, caption)``
            tuples or lists, dicts (path under "name" or "data"), or plain
            path strings.
        prompt: Prompt to keep alongside the video.
        idx: Index of the selected gallery item, or None when nothing is
            selected.

    Returns:
        ``(video_path, prompt)`` when the selected item resolves to an
        existing file, otherwise ``(None, "")``.
    """
    if not gallery or idx is None or idx >= len(gallery):
        return None, ""

    selected_item = gallery[idx]
    video_path = None

    # Handle different gallery item formats. Lists are accepted as well as
    # tuples, matching handle_send_button, so a [path, caption] item is not
    # silently dropped.
    if isinstance(selected_item, (tuple, list)):
        video_path = selected_item[0] if selected_item else None
    elif isinstance(selected_item, dict):
        video_path = selected_item.get("name", selected_item.get("data", None))
    elif isinstance(selected_item, str):
        video_path = selected_item

    # Only a non-empty string is usable as a path; anything else would make
    # os.path.exists raise or point nowhere.
    if not video_path or not isinstance(video_path, str):
        return None, ""

    # Check if the file exists and is accessible
    if not os.path.exists(video_path):
        print(f"Warning: Video file not found at {video_path}")
        return None, ""

    return video_path, prompt
# Copy the selected Video2Video result back into the Video2Video input.
v2v_send_to_input_btn.click(
    fn=handle_v2v_send_button,
    inputs=[v2v_output, v2v_prompt, v2v_selected_index],
    outputs=[v2v_input, v2v_prompt],
).then(
    lambda: gr.update(visible=True),  # Ensure the video input is visible
    outputs=v2v_input,
)
# Video to Video generation
# v2v_cfg_scale is intentionally listed twice, exactly as in the positional
# layout passed to process_batch.
v2v_generate_inputs = [
    v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length,
    v2v_fps, v2v_infer_steps, v2v_seed, v2v_dit_folder, v2v_model, v2v_vae,
    v2v_te1, v2v_te2, v2v_save_path, v2v_flow_shift, v2v_cfg_scale,
    v2v_output_type, v2v_attn_mode, v2v_block_swap, v2v_exclude_single_blocks,
    v2v_use_split_attn, v2v_lora_folder,
    *v2v_lora_weights,
    *v2v_lora_multipliers,
    v2v_input, v2v_strength,
    v2v_negative_prompt, v2v_cfg_scale, v2v_split_uncond, v2v_use_fp8,
]
v2v_generate_btn.click(
    fn=process_batch,
    inputs=v2v_generate_inputs,
    outputs=[v2v_output, v2v_batch_progress, v2v_progress_text],
    queue=True,
).then(
    # A single-video batch pre-selects item 0; otherwise the selection is cleared.
    fn=lambda count: 0 if count == 1 else None,
    inputs=[v2v_batch_size],
    outputs=v2v_selected_index,
)
# Each refresh button re-reads its DiT and LoRA folders. The outputs are the
# model dropdown followed by interleaved (weight, multiplier) pairs, matching
# the order of updates returned by update_dit_and_lora_dropdowns.

# Text2Video refresh
refresh_outputs = [model] + [
    component
    for slot in range(4)
    for component in (lora_weights[slot], lora_multipliers[slot])
]
refresh_btn.click(
    fn=update_dit_and_lora_dropdowns,
    inputs=[dit_folder, lora_folder, model, *lora_weights, *lora_multipliers],
    outputs=refresh_outputs,
)

# Image2Video refresh
i2v_refresh_outputs = [i2v_model] + [
    component
    for slot in range(4)
    for component in (i2v_lora_weights[slot], i2v_lora_multipliers[slot])
]
i2v_refresh_btn.click(
    fn=update_dit_and_lora_dropdowns,
    inputs=[i2v_dit_folder, i2v_lora_folder, i2v_model, *i2v_lora_weights, *i2v_lora_multipliers],
    outputs=i2v_refresh_outputs,
)

# Video2Video refresh
v2v_refresh_outputs = [v2v_model] + [
    component
    for slot in range(4)
    for component in (v2v_lora_weights[slot], v2v_lora_multipliers[slot])
]
v2v_refresh_btn.click(
    fn=update_dit_and_lora_dropdowns,
    inputs=[v2v_dit_folder, v2v_lora_folder, v2v_model, *v2v_lora_weights, *v2v_lora_multipliers],
    outputs=v2v_refresh_outputs,
)
# WanX-i2v tab connections
wanx_prompt.change(
    fn=count_prompt_tokens,
    inputs=wanx_prompt,
    outputs=wanx_token_counter,
)
# The stop button sets the shared stop_event and bypasses the queue.
wanx_stop_btn.click(fn=lambda: stop_event.set(), queue=False)

# Image input handling for WanX-i2v
wanx_input.change(
    fn=update_wanx_image_dimensions,
    inputs=[wanx_input],
    outputs=[wanx_original_dims, wanx_width, wanx_height],
)

# Scale slider handling for WanX-i2v
wanx_scale_slider.change(
    fn=update_wanx_from_scale,
    inputs=[wanx_scale_slider, wanx_original_dims],
    outputs=[wanx_width, wanx_height],
)

# Width/height calculation buttons for WanX-i2v: each one derives one
# dimension from the other plus the stored original dimensions.
wanx_calc_width_btn.click(
    fn=calculate_wanx_width,
    inputs=[wanx_height, wanx_original_dims],
    outputs=[wanx_width],
)
wanx_calc_height_btn.click(
    fn=calculate_wanx_height,
    inputs=[wanx_width, wanx_original_dims],
    outputs=[wanx_height],
)
# Flow shift recommendation button for WanX-i2v.
# The WanX-t2v recommend button is wired once, together with the other
# WanX-t2v handlers further down; registering it here as well attached the
# same listener twice, so every click ran recommend_wanx_flow_shift twice.
wanx_recommend_flow_btn.click(
    fn=recommend_wanx_flow_shift,
    inputs=[wanx_width, wanx_height],
    outputs=[wanx_flow_shift],
)
# Generate button handler
# Positional inputs for wanx_generate_video_batch; the order must stay exactly
# as listed.
wanx_generate_inputs = [
    wanx_prompt, wanx_negative_prompt,
    wanx_width, wanx_height, wanx_video_length, wanx_fps,
    wanx_infer_steps, wanx_flow_shift, wanx_guidance_scale, wanx_seed,
    wanx_task,
    wanx_dit_path, wanx_vae_path, wanx_t5_path, wanx_clip_path,
    wanx_save_path, wanx_output_type, wanx_sample_solver,
    wanx_attn_mode, wanx_block_swap,
    wanx_fp8, wanx_fp8_t5,
    wanx_batch_size,
    wanx_input,  # Image input
    wanx_lora_folder,
    *wanx_lora_weights,
    *wanx_lora_multipliers,
    wanx_exclude_single_blocks,
]
wanx_generate_btn.click(
    fn=wanx_generate_video_batch,
    inputs=wanx_generate_inputs,
    outputs=[wanx_output, wanx_batch_progress, wanx_progress_text],
    queue=True,
).then(
    # A single-video batch pre-selects item 0; otherwise the selection is
    # cleared. The selection lives in the state shared with the SkyReels tab.
    fn=lambda count: 0 if count == 1 else None,
    inputs=[wanx_batch_size],
    outputs=skyreels_selected_index,
)
# Gallery selection handling
wanx_output.select(
    fn=handle_wanx_gallery_select,
    outputs=skyreels_selected_index,  # Reuse the skyreels_selected_index
)

# Send to Video2Video handler: copy the selected WanX-i2v result and its
# settings into the Video2Video tab, then switch to that tab.
wanx_send_to_v2v_btn.click(
    fn=send_wanx_to_v2v,
    inputs=[
        wanx_output, wanx_prompt,
        skyreels_selected_index,  # Reuse the skyreels_selected_index
        wanx_width, wanx_height, wanx_video_length, wanx_fps,
        wanx_infer_steps, wanx_seed, wanx_flow_shift, wanx_guidance_scale,
        wanx_negative_prompt,
    ],
    outputs=[
        v2v_input, v2v_prompt,
        v2v_width, v2v_height, v2v_video_length, v2v_fps,
        v2v_infer_steps, v2v_seed, v2v_flow_shift, v2v_cfg_scale,
        v2v_negative_prompt,
    ],
).then(fn=change_to_tab_two, inputs=None, outputs=[tabs])
# Add state for T2V tab selected index
wanx_t2v_selected_index = gr.State(value=None)

# Connect prompt token counter
wanx_t2v_prompt.change(
    fn=count_prompt_tokens,
    inputs=wanx_t2v_prompt,
    outputs=wanx_t2v_token_counter,
)

# Stop button handler: sets the shared stop_event and bypasses the queue.
wanx_t2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False)

# Flow shift recommendation button
wanx_t2v_recommend_flow_btn.click(
    fn=recommend_wanx_flow_shift,
    inputs=[wanx_t2v_width, wanx_t2v_height],
    outputs=[wanx_t2v_flow_shift],
)
# Task change handler to update CLIP visibility and path
def update_clip_visibility(task):
    """Return an update that shows the target only when "i2v" occurs in task."""
    return gr.update(visible="i2v" in task)


wanx_t2v_task.change(
    fn=update_clip_visibility,
    inputs=[wanx_t2v_task],
    outputs=[wanx_t2v_clip_path],
)
# Generate button handler for T2V
# NOTE(review): the WanX-i2v call passes its image input between the batch size
# and the LoRA folder; this list has no value in that position, so the LoRA
# folder follows the batch size directly — confirm that
# wanx_generate_video_batch accepts this positional layout.
wanx_t2v_generate_inputs = [
    wanx_t2v_prompt, wanx_t2v_negative_prompt,
    wanx_t2v_width, wanx_t2v_height, wanx_t2v_video_length, wanx_t2v_fps,
    wanx_t2v_infer_steps, wanx_t2v_flow_shift, wanx_t2v_guidance_scale,
    wanx_t2v_seed,
    wanx_t2v_task,
    wanx_t2v_dit_path, wanx_t2v_vae_path, wanx_t2v_t5_path, wanx_t2v_clip_path,
    wanx_t2v_save_path, wanx_t2v_output_type, wanx_t2v_sample_solver,
    wanx_t2v_attn_mode, wanx_t2v_block_swap,
    wanx_t2v_fp8, wanx_t2v_fp8_t5,
    wanx_t2v_batch_size,
    wanx_t2v_lora_folder,
    *wanx_t2v_lora_weights,
    *wanx_t2v_lora_multipliers,
    wanx_t2v_exclude_single_blocks,
]
wanx_t2v_generate_btn.click(
    fn=wanx_generate_video_batch,
    inputs=wanx_t2v_generate_inputs,
    outputs=[wanx_t2v_output, wanx_t2v_batch_progress, wanx_t2v_progress_text],
    queue=True,
).then(
    # A single-video batch pre-selects item 0; otherwise the selection is cleared.
    fn=lambda count: 0 if count == 1 else None,
    inputs=[wanx_t2v_batch_size],
    outputs=wanx_t2v_selected_index,
)
# Gallery selection handling
wanx_t2v_output.select(
    fn=handle_wanx_t2v_gallery_select,
    outputs=wanx_t2v_selected_index,
)

# Send to Video2Video handler: copy the selected WanX-t2v result and its
# settings into the Video2Video tab, then switch to that tab.
wanx_t2v_send_to_v2v_btn.click(
    fn=send_wanx_t2v_to_v2v,
    inputs=[
        wanx_t2v_output, wanx_t2v_prompt, wanx_t2v_selected_index,
        wanx_t2v_width, wanx_t2v_height, wanx_t2v_video_length, wanx_t2v_fps,
        wanx_t2v_infer_steps, wanx_t2v_seed, wanx_t2v_flow_shift,
        wanx_t2v_guidance_scale,
        wanx_t2v_negative_prompt,
    ],
    outputs=[
        v2v_input, v2v_prompt,
        v2v_width, v2v_height, v2v_video_length, v2v_fps,
        v2v_infer_steps, v2v_seed, v2v_flow_shift, v2v_cfg_scale,
        v2v_negative_prompt,
    ],
).then(fn=change_to_tab_two, inputs=None, outputs=[tabs])
# Refresh handlers for WanX-i2v
# Outputs are interleaved (weight, multiplier) pairs for the four LoRA slots.
wanx_refresh_outputs = [
    component
    for slot in range(4)
    for component in (wanx_lora_weights[slot], wanx_lora_multipliers[slot])
]
wanx_refresh_btn.click(
    fn=update_lora_dropdowns,
    inputs=[wanx_lora_folder, *wanx_lora_weights, *wanx_lora_multipliers],
    outputs=wanx_refresh_outputs,
)

# Refresh handlers for WanX-t2v
wanx_t2v_refresh_outputs = [
    component
    for slot in range(4)
    for component in (wanx_t2v_lora_weights[slot], wanx_t2v_lora_multipliers[slot])
]
wanx_t2v_refresh_btn.click(
    fn=update_lora_dropdowns,
    inputs=[wanx_t2v_lora_folder, *wanx_t2v_lora_weights, *wanx_t2v_lora_multipliers],
    outputs=wanx_t2v_refresh_outputs,
)

# Enable the queue, then serve on all network interfaces without a share link.
queued_demo = demo.queue()
queued_demo.launch(server_name="0.0.0.0", share=False)