import gradio as gr
from gradio import update as gr_update
import subprocess
import threading
import time
import re
import os
import random
import tiktoken
import sys
import ffmpeg
from typing import List, Tuple, Optional, Generator, Dict
import json
from gradio import themes
from gradio.themes.utils import colors
from PIL import Image
import math
import cv2
import glob
import shutil
from pathlib import Path
import logging
from datetime import datetime
from tqdm import tqdm

# Global stop event used for cooperative cancellation across all generation loops
stop_event = threading.Event()

logger = logging.getLogger(__name__)
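# Illustrative sketch (the component names are hypothetical, not part of this
# file): a Gradio "Stop" button only needs to set the event; every generation
# loop below polls stop_event between subprocess output lines and exits cleanly:
#     stop_btn = gr.Button("Stop")
#     stop_btn.click(fn=stop_event.set, inputs=None, outputs=None)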
def process_hunyuani2v_video(
    prompt: str,
    width: int,
    height: int,
    batch_size: int,
    video_length: int,
    fps: int,
    infer_steps: int,
    seed: int,
    dit_folder: str,
    model: str,
    vae: str,
    te1: str,
    te2: str,
    save_path: str,
    flow_shift: float,
    cfg_scale: float,
    output_type: str,
    attn_mode: str,
    block_swap: int,
    exclude_single_blocks: bool,
    use_split_attn: bool,
    lora_folder: str,
    lora1: str = "",
    lora2: str = "",
    lora3: str = "",
    lora4: str = "",
    lora1_multiplier: float = 1.0,
    lora2_multiplier: float = 1.0,
    lora3_multiplier: float = 1.0,
    lora4_multiplier: float = 1.0,
    video_path: Optional[str] = None,
    image_path: Optional[str] = None,
    strength: Optional[float] = None,
    negative_prompt: Optional[str] = None,
    embedded_cfg_scale: Optional[float] = None,
    split_uncond: Optional[bool] = None,
    guidance_scale: Optional[float] = None,
    use_fp8: bool = True,
    clip_vision_path: Optional[str] = None,
    i2v_stability: bool = False,
    fp8_fast: bool = False,
    compile_model: bool = False,
    compile_backend: str = "inductor",
    compile_mode: str = "max-autotune-no-cudagraphs",
    compile_dynamic: bool = False,
    compile_fullgraph: bool = False
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
    """Generate a single video with the hunyuani2v script, using the updated parameter set."""
    global stop_event
    if stop_event.is_set():
        yield [], "", ""
        return

    # Determine whether this is a SkyReels model and, if so, which variant
    is_skyreels = "skyreels" in model.lower()
    is_skyreels_i2v = is_skyreels and "i2v" in model.lower()
    is_skyreels_t2v = is_skyreels and "t2v" in model.lower()

    # Set defaults for hunyuani2v-specific parameters
    if is_skyreels:
        # Force certain parameters for SkyReels
        if negative_prompt is None:
            negative_prompt = ""
        if embedded_cfg_scale is None:
            embedded_cfg_scale = 1.0  # Forced to 1.0 for SkyReels
        if split_uncond is None:
            split_uncond = True
        if guidance_scale is None:
            guidance_scale = cfg_scale  # Fall back to cfg_scale when not provided
    else:
        embedded_cfg_scale = cfg_scale

    if os.path.isabs(model):
        model_path = model
    else:
        model_path = os.path.normpath(os.path.join(dit_folder, model))

    env = os.environ.copy()
    env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "")
    env["PYTHONIOENCODING"] = "utf-8"
    env["BATCH_RUN_ID"] = f"{time.time()}"

    if seed == -1:
        current_seed = random.randint(0, 2**32 - 1)
    else:
        batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1])
        if batch_size > 1:  # Only vary the seed for batch generation
            current_seed = (seed + batch_id * 100003) % (2**32)
        else:
            current_seed = seed

    clear_cuda_cache()

    # Use hv_generate_video_with_hunyuani2v.py instead of the stock script
    command = [
        sys.executable,
        "hv_generate_video_with_hunyuani2v.py",
        "--dit", model_path,
        "--vae", vae,
        "--text_encoder1", te1,
        "--text_encoder2", te2,
        "--prompt", prompt,
        "--video_size", str(height), str(width),
        "--video_length", str(video_length),
        "--fps", str(fps),
        "--infer_steps", str(infer_steps),
        "--save_path", save_path,
        "--seed", str(current_seed),
        "--flow_shift", str(flow_shift),
        "--embedded_cfg_scale", str(embedded_cfg_scale),  # honors the SkyReels override above
        "--output_type", output_type,
        "--attn_mode", attn_mode,
        "--blocks_to_swap", str(block_swap),
        "--fp8_llm",
        "--vae_chunk_size", "32",
        "--vae_spatial_tile_sample_min_size", "128"
    ]
    if use_fp8:
        command.append("--fp8")

    # Parameters specific to the hunyuani2v script
    if clip_vision_path:
        command.extend(["--clip_vision_path", clip_vision_path])
    if i2v_stability:
        command.append("--i2v_stability")
    if fp8_fast:
        command.append("--fp8_fast")
    if compile_model:
        command.append("--compile")
        command.extend([
            "--compile_args",
            compile_backend,
            compile_mode,
            str(compile_dynamic).lower(),
            str(compile_fullgraph).lower()
        ])

    # Guidance scale, negative prompt, and split unconditional pass
    command.extend(["--guidance_scale", str(guidance_scale)])
    if negative_prompt:
        command.extend(["--negative_prompt", negative_prompt])
    if split_uncond:
        command.append("--split_uncond")

    # Add LoRA weights and multipliers if provided
    valid_loras = []
    for weight, mult in zip([lora1, lora2, lora3, lora4],
                            [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]):
        if weight and weight != "None":
            valid_loras.append((os.path.join(lora_folder, weight), mult))
    if valid_loras:
        weights = [weight for weight, _ in valid_loras]
        multipliers = [str(mult) for _, mult in valid_loras]
        command.extend(["--lora_weight"] + weights)
        command.extend(["--lora_multiplier"] + multipliers)

    if exclude_single_blocks:
        command.append("--exclude_single_blocks")
    if use_split_attn:
        command.append("--split_attn")

    # Handle input paths
    if video_path:
        command.extend(["--video_path", video_path])
        if strength is not None:
            command.extend(["--strength", str(strength)])
    elif image_path:
        command.extend(["--image_path", image_path])
        # Only add the strength parameter for non-SkyReels I2V models;
        # SkyReels I2V does not use strength for image-to-video generation
        if strength is not None and not is_skyreels_i2v:
            command.extend(["--strength", str(strength)])

    print(f"{command}")
    p = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        env=env,
        text=True,
        encoding='utf-8',
        errors='replace',
        bufsize=1
    )

    videos = []
    while True:
        if stop_event.is_set():
            p.terminate()
            p.wait()
            yield [], "", "Generation stopped by user."
            return
        line = p.stdout.readline()
        if not line:
            if p.poll() is not None:
                break
            continue
        print(line, end='')
        if '|' in line and '%' in line and '[' in line and ']' in line:
            yield videos.copy(), f"Processing (seed: {current_seed})", line.strip()

    p.stdout.close()
    p.wait()
    clear_cuda_cache()
    time.sleep(0.5)

    # Collect the generated video
    save_path_abs = os.path.abspath(save_path)
    if os.path.exists(save_path_abs):
        all_videos = sorted(
            [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')],
            key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)),
            reverse=True
        )
        matching_videos = [v for v in all_videos if f"_{current_seed}" in v]
        if matching_videos:
            # Keep the output path in its own variable so the input video_path
            # is not clobbered before it is recorded in the metadata below
            output_video_path = os.path.join(save_path_abs, matching_videos[0])
            # Collect parameters for metadata
            parameters = {
                "prompt": prompt,
                "width": width,
                "height": height,
                "video_length": video_length,
                "fps": fps,
                "infer_steps": infer_steps,
                "seed": current_seed,
                "model": model,
                "vae": vae,
                "te1": te1,
                "te2": te2,
                "save_path": save_path,
                "flow_shift": flow_shift,
                "cfg_scale": cfg_scale,
                "output_type": output_type,
                "attn_mode": attn_mode,
                "block_swap": block_swap,
                "lora_weights": [lora1, lora2, lora3, lora4],
                "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier],
                "input_video": video_path if video_path else None,
                "input_image": image_path if image_path else None,
                "strength": strength,
                "negative_prompt": negative_prompt,
                "embedded_cfg_scale": embedded_cfg_scale,
                "clip_vision_path": clip_vision_path,
                "i2v_stability": i2v_stability,
                "fp8_fast": fp8_fast,
                "compile_model": compile_model
            }
            add_metadata_to_video(output_video_path, parameters)
            videos.append((str(output_video_path), f"Seed: {current_seed}"))

    yield videos, f"Completed (seed: {current_seed})", ""
# Batch-processing wrapper around the hunyuani2v generation function
def process_hunyuani2v_batch(
    prompt: str,
    width: int,
    height: int,
    batch_size: int,
    video_length: int,
    fps: int,
    infer_steps: int,
    seed: int,
    dit_folder: str,
    model: str,
    vae: str,
    te1: str,
    te2: str,
    save_path: str,
    flow_shift: float,
    cfg_scale: float,
    output_type: str,
    attn_mode: str,
    block_swap: int,
    exclude_single_blocks: bool,
    use_split_attn: bool,
    lora_folder: str,
    *args
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
    """Process a batch of videos using the hunyuani2v script."""
    global stop_event
    stop_event.clear()
    all_videos = []
    progress_text = "Starting generation..."
    yield [], "Preparing...", progress_text

    # Extract additional arguments: four LoRA weights followed by four multipliers
    num_lora_weights = 4
    lora_weights = args[:num_lora_weights]
    lora_multipliers = args[num_lora_weights:num_lora_weights*2]

    # The remaining hunyuani2v parameters start after the LoRA weights and multipliers
    base_idx = num_lora_weights * 2
    input_path = args[base_idx] if len(args) > base_idx else None
    strength = float(args[base_idx+1]) if len(args) > base_idx+1 and args[base_idx+1] is not None else None
    negative_prompt = str(args[base_idx+2]) if len(args) > base_idx+2 and args[base_idx+2] is not None else None
    guidance_scale = float(args[base_idx+3]) if len(args) > base_idx+3 and args[base_idx+3] is not None else cfg_scale
    split_uncond = bool(args[base_idx+4]) if len(args) > base_idx+4 else None
    use_fp8 = bool(args[base_idx+5]) if len(args) > base_idx+5 else True

    # New hunyuani2v parameters
    clip_vision_path = str(args[base_idx+6]) if len(args) > base_idx+6 and args[base_idx+6] is not None else None
    i2v_stability = bool(args[base_idx+7]) if len(args) > base_idx+7 else False
    fp8_fast = bool(args[base_idx+8]) if len(args) > base_idx+8 else False
    compile_model = bool(args[base_idx+9]) if len(args) > base_idx+9 else False
    compile_backend = str(args[base_idx+10]) if len(args) > base_idx+10 and args[base_idx+10] is not None else "inductor"
    compile_mode = str(args[base_idx+11]) if len(args) > base_idx+11 and args[base_idx+11] is not None else "max-autotune-no-cudagraphs"
    compile_dynamic = bool(args[base_idx+12]) if len(args) > base_idx+12 else False
    compile_fullgraph = bool(args[base_idx+13]) if len(args) > base_idx+13 else False

    embedded_cfg_scale = cfg_scale

    for i in range(batch_size):
        if stop_event.is_set():
            break
        batch_text = f"Generating video {i + 1} of {batch_size}"
        yield all_videos.copy(), batch_text, progress_text

        # Route the input path to the right argument based on its extension
        video_path = None
        image_path = None
        if input_path:
            lower_path = input_path.lower()
            image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp')
            is_image = any(lower_path.endswith(ext) for ext in image_extensions)
            if is_image:
                image_path = input_path
            else:
                video_path = input_path

        # Derive a per-item seed: offset by the batch index unless the seed is random (-1)
        current_seed = seed + i if seed != -1 and batch_size > 1 else seed

        # Prepare arguments for process_hunyuani2v_video
        hunyuani2v_args = [
            prompt, width, height, batch_size, video_length, fps, infer_steps,
            current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale,
            output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn,
            lora_folder
        ]
        hunyuani2v_args.extend(lora_weights)
        hunyuani2v_args.extend(lora_multipliers)
        hunyuani2v_args.extend([
            video_path, image_path, strength, negative_prompt, embedded_cfg_scale,
            split_uncond, guidance_scale, use_fp8, clip_vision_path, i2v_stability,
            fp8_fast, compile_model, compile_backend, compile_mode, compile_dynamic, compile_fullgraph
        ])

        for videos, status, progress in process_hunyuani2v_video(*hunyuani2v_args):
            if videos:
                all_videos.extend(videos)
            yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress

    yield all_videos, "Batch complete", ""
def variance_of_laplacian(image):
    """
    Compute the variance of the Laplacian of the image.
    Higher variance indicates a sharper image.
    """
    return cv2.Laplacian(image, cv2.CV_64F).var()
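# Illustrative usage (assumes a local frame image; the ~100 cutoff is a common
# rule of thumb for this metric, not a value taken from this project):
#     gray = cv2.cvtColor(cv2.imread("frame.png"), cv2.COLOR_BGR2GRAY)
#     if variance_of_laplacian(gray) < 100:
#         print("frame looks blurry")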
def extract_sharpest_frame(video_path, frames_to_check=30):
    """
    Extract the sharpest frame from the last N frames of the video.

    Args:
        video_path (str): Path to the video file
        frames_to_check (int): Number of frames from the end to check

    Returns:
        tuple: (temp_image_path, frame_number, sharpness_score)
    """
    print(f"\n=== Extracting sharpest frame from the last {frames_to_check} frames ===")
    print(f"Input video path: {video_path}")
    if not video_path or not os.path.exists(video_path):
        print("❌ Error: Video file does not exist")
        return None, None, None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("❌ Error: Failed to open video file")
            return None, None, None

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        print(f"Total frames detected: {total_frames}, FPS: {fps:.2f}")
        if total_frames < 1:
            print("❌ Error: Video contains 0 frames")
            return None, None, None

        # Determine how many frames to check (the last N frames)
        if frames_to_check > total_frames:
            frames_to_check = total_frames
            start_frame = 0
        else:
            start_frame = total_frames - frames_to_check
        print(f"Checking frames {start_frame} to {total_frames-1}")

        # Track the sharpest frame seen so far
        sharpest_frame = None
        max_sharpness = -1
        sharpest_frame_number = -1

        # Seek to the starting position
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        # Process frames with a progress bar
        with tqdm(total=frames_to_check, desc="Finding sharpest frame") as pbar:
            frame_idx = start_frame
            while frame_idx < total_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                # Convert to grayscale and score the sharpness
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                sharpness = variance_of_laplacian(gray)
                # Keep this frame if it is the sharpest so far
                if sharpness > max_sharpness:
                    max_sharpness = sharpness
                    sharpest_frame = frame.copy()
                    sharpest_frame_number = frame_idx
                frame_idx += 1
                pbar.update(1)

        cap.release()

        if sharpest_frame is None:
            print("❌ Error: Failed to find a sharp frame")
            return None, None, None

        # Prepare the output path
        temp_dir = os.path.abspath("temp_frames")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, f"sharpest_frame_{os.path.basename(video_path)}.png")
        print(f"Saving frame to: {temp_path}")

        # Write and verify
        if not cv2.imwrite(temp_path, sharpest_frame):
            print("❌ Error: Failed to write frame to file")
            return None, None, None
        if not os.path.exists(temp_path):
            print("❌ Error: Output file not created")
            return None, None, None

        # Report the frame's timestamp in seconds
        frame_time = sharpest_frame_number / fps
        print(f"✅ Extracted sharpest frame: {sharpest_frame_number} (at {frame_time:.2f}s) with sharpness {max_sharpness:.2f}")
        return temp_path, sharpest_frame_number, max_sharpness
    except Exception as e:
        print(f"❌ Unexpected error: {str(e)}")
        return None, None, None
    finally:
        if 'cap' in locals():
            cap.release()
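# Illustrative usage (hypothetical path):
#     frame_path, frame_no, score = extract_sharpest_frame("outputs/clip_1234.mp4", frames_to_check=45)
#     if frame_path:
#         print(f"Best frame {frame_no} (sharpness {score:.1f}) saved to {frame_path}")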
def trim_video_to_frame(video_path, frame_number, output_dir="outputs"):
    """
    Trim the video up to the specified frame and save it as a new video.

    Args:
        video_path (str): Path to the video file
        frame_number (int): Frame number to trim to
        output_dir (str): Directory to save the trimmed video

    Returns:
        str: Path to the trimmed video file
    """
    print(f"\n=== Trimming video to frame {frame_number} ===")
    if not video_path or not os.path.exists(video_path):
        print("❌ Error: Video file does not exist")
        return None
    try:
        # Read the frame rate so the frame number can be converted to a timestamp
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("❌ Error: Failed to open video file")
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()

        # Convert the frame number to seconds
        time_seconds = frame_number / fps

        # Create the output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # Generate the output filename
        timestamp = f"{int(time_seconds)}s"
        base_name = Path(video_path).stem
        output_file = os.path.join(output_dir, f"{base_name}_trimmed_to_{timestamp}.mp4")

        # Use ffmpeg to trim the video without re-encoding
        (
            ffmpeg
            .input(video_path)
            .output(output_file, to=time_seconds, c="copy")
            .global_args('-y')  # Overwrite output files
            .run(quiet=True)
        )

        if not os.path.exists(output_file):
            print("❌ Error: Failed to create trimmed video")
            return None

        print(f"✅ Successfully trimmed video to {time_seconds:.2f}s: {output_file}")
        return output_file
    except Exception as e:
        print(f"❌ Error trimming video: {str(e)}")
        return None
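# Illustrative usage (hypothetical path). Because the trim uses stream copy
# (c="copy"), ffmpeg cuts at the nearest keyframe rather than the exact frame;
# re-encoding would be required for frame-exact cuts:
#     trimmed = trim_video_to_frame("outputs/clip_1234.mp4", frame_number=96)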
def send_sharpest_frame_handler(gallery, selected_idx, frames_to_check=30):
    """
    Extract the sharpest frame from the last N frames of the selected video.

    Args:
        gallery: Gradio gallery component with videos
        selected_idx: Index of the selected video
        frames_to_check: Number of frames from the end to check

    Returns:
        tuple: (image_path, video_path, frame_number, status_message)
    """
    if gallery is None or not gallery:
        return None, None, None, "No videos in gallery"
    if selected_idx is None and len(gallery) == 1:
        selected_idx = 0
    if selected_idx is None or selected_idx >= len(gallery):
        return None, None, None, "No video selected"

    # Resolve the video path from the gallery item
    item = gallery[selected_idx]
    if isinstance(item, tuple):
        video_path = item[0]
    elif isinstance(item, dict):
        video_path = item.get('name') or item.get('data')
    else:
        video_path = str(item)

    # Extract the sharpest frame
    image_path, frame_number, sharpness = extract_sharpest_frame(video_path, frames_to_check)
    if image_path is None:
        return None, None, None, "Failed to extract sharpest frame"
    return image_path, video_path, frame_number, f"Extracted frame {frame_number} with sharpness {sharpness:.2f}"
def trim_and_prepare_for_extension(video_path, frame_number, save_path="outputs"):
    """
    Trim the video to the specified frame and prepare it for extension.

    Args:
        video_path: Path to the video file
        frame_number: Frame number to trim to
        save_path: Directory to save the trimmed video

    Returns:
        tuple: (trimmed_video_path, status_message)
    """
    if not video_path or not os.path.exists(video_path):
        return None, "No video selected or video file does not exist"
    if frame_number is None:
        return None, "No frame number provided; please extract the sharpest frame first"

    # Trim the video
    trimmed_video = trim_video_to_frame(video_path, frame_number, save_path)
    if trimmed_video is None:
        return None, "Failed to trim video"
    return trimmed_video, f"Video trimmed to frame {frame_number} and ready for extension"
def send_last_frame_handler(gallery, selected_idx):
    """Handle sending the last frame to the input, with better error handling."""
    if gallery is None or not gallery:
        return None, None
    if selected_idx is None and len(gallery) == 1:
        selected_idx = 0
    if selected_idx is None or selected_idx >= len(gallery):
        return None, None

    # Get the frame and the video path
    frame = handle_last_frame_transfer(gallery, selected_idx)
    video_path = None
    if selected_idx < len(gallery):
        item = gallery[selected_idx]
        video_path = parse_video_path(item)
    return frame, video_path
def extract_last_frame(video_path: str) -> Optional[str]:
    """Extract the last frame from a video and return a temporary image path, with error handling."""
    print("\n=== Starting frame extraction ===")
    print(f"Input video path: {video_path}")
    if not video_path or not os.path.exists(video_path):
        print("❌ Error: Video file does not exist")
        return None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("❌ Error: Failed to open video file")
            return None

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Total frames detected: {total_frames}")
        if total_frames < 1:
            print("❌ Error: Video contains 0 frames")
            return None

        # Seek to and read the last frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
        success, frame = cap.read()
        if not success or frame is None:
            print("❌ Error: Failed to read last frame")
            return None

        # Prepare the output path
        temp_dir = os.path.abspath("temp_frames")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, f"last_frame_{os.path.basename(video_path)}.png")
        print(f"Saving frame to: {temp_path}")

        # Write and verify
        if not cv2.imwrite(temp_path, frame):
            print("❌ Error: Failed to write frame to file")
            return None
        if not os.path.exists(temp_path):
            print("❌ Error: Output file not created")
            return None

        print("✅ Frame extraction successful")
        return temp_path
    except Exception as e:
        print(f"❌ Unexpected error: {str(e)}")
        return None
    finally:
        if 'cap' in locals():
            cap.release()
def handle_last_frame_transfer(gallery: list, selected_idx: int) -> Optional[str]:
    """Frame transfer with validation of the selected gallery item."""
    try:
        if gallery is None or not gallery:
            raise ValueError("No videos generated yet")
        if selected_idx is None:
            # Auto-select the only video when batch_size == 1
            if len(gallery) == 1:
                selected_idx = 0
            else:
                raise ValueError("Please select a video first")
        if selected_idx >= len(gallery):
            raise ValueError("Invalid selection index")

        item = gallery[selected_idx]
        # Verify the video file exists before extracting
        video_path = parse_video_path(item)
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file missing: {video_path}")
        return extract_last_frame(video_path)
    except Exception as e:
        print(f"Frame transfer failed: {str(e)}")
        return None
def parse_video_path(item) -> str:
    """Parse the different gallery item formats into a plain path string."""
    if isinstance(item, tuple):
        return item[0]
    elif isinstance(item, dict):
        return item.get('name') or item.get('data')
    return str(item)
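# Illustrative examples of the formats this normalizes (hypothetical values):
#     parse_video_path(("outputs/a.mp4", "Seed: 42"))  -> "outputs/a.mp4"
#     parse_video_path({"name": "outputs/a.mp4"})      -> "outputs/a.mp4"
#     parse_video_path("outputs/a.mp4")                -> "outputs/a.mp4"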
def get_random_image_from_folder(folder_path):
    """Get a random image from the specified folder."""
    if not os.path.isdir(folder_path):
        return None, f"Error: {folder_path} is not a valid directory"

    # Gather all image files in the folder (lower- and upper-case extensions)
    image_files = []
    for ext in ('*.jpg', '*.jpeg', '*.png', '*.bmp', '*.webp'):
        image_files.extend(glob.glob(os.path.join(folder_path, ext)))
    for ext in ('*.JPG', '*.JPEG', '*.PNG', '*.BMP', '*.WEBP'):
        image_files.extend(glob.glob(os.path.join(folder_path, ext)))
    if not image_files:
        return None, f"Error: No image files found in {folder_path}"

    # Select a random image
    random_image = random.choice(image_files)
    return random_image, f"Selected: {os.path.basename(random_image)}"
def resize_image_keeping_aspect_ratio(image_path, max_width, max_height):
    """Resize an image, keeping its aspect ratio and making both dimensions divisible by 16."""
    try:
        img = Image.open(image_path)
        width, height = img.size

        # Calculate the aspect ratio
        aspect_ratio = width / height

        # Calculate new dimensions while maintaining the aspect ratio
        if width > height:
            new_width = min(max_width, width)
            new_height = int(new_width / aspect_ratio)
        else:
            new_height = min(max_height, height)
            new_width = int(new_height * aspect_ratio)

        # Round dimensions down to the nearest multiple of 16
        new_width = math.floor(new_width / 16) * 16
        new_height = math.floor(new_height / 16) * 16

        # Enforce a minimum size of 16x16
        new_width = max(16, new_width)
        new_height = max(16, new_height)

        # Resize the image
        resized_img = img.resize((new_width, new_height), Image.LANCZOS)

        # Save to a temporary file
        temp_path = f"temp_resized_{os.path.basename(image_path)}"
        resized_img.save(temp_path)
        return temp_path, (new_width, new_height)
    except Exception as e:
        return None, f"Error: {str(e)}"
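# Worked example of the sizing (illustrative numbers): a 1000x750 image with
# max 1280x720 keeps its 4:3 ratio at 1000x750, then both sides are floored to
# multiples of 16, giving 992x736. Note the width > height branch does not
# re-check max_height, so the result can slightly exceed it.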
# Process a batch of images from a folder, or a single image, depending on mode
def batch_handler(
    use_random,
    prompt, negative_prompt,
    width, height,
    video_length, fps, infer_steps,
    seed, flow_shift, guidance_scale, embedded_cfg_scale,
    batch_size, input_folder_path,
    dit_folder, model, vae, te1, te2, save_path, output_type, attn_mode,
    block_swap, exclude_single_blocks, use_split_attn, use_fp8, split_uncond,
    lora_folder, *lora_params
):
    """Handle both folder-based batch processing and regular batch processing."""
    global stop_event

    # Check whether this is a SkyReels model that needs special handling
    is_skyreels = "skyreels" in model.lower()
    is_skyreels_i2v = is_skyreels and "i2v" in model.lower()

    if use_random:
        # Random-image-from-folder mode
        stop_event.clear()
        all_videos = []
        progress_text = "Starting generation..."
        yield [], "Preparing...", progress_text

        for i in range(batch_size):
            if stop_event.is_set():
                break
            batch_text = f"Generating video {i + 1} of {batch_size}"
            yield all_videos.copy(), batch_text, progress_text

            # Get a random image from the folder
            random_image, status = get_random_image_from_folder(input_folder_path)
            if random_image is None:
                yield all_videos, f"Error in batch {i+1}: {status}", ""
                continue

            # Resize the image
            resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height)
            if resized_image is None:
                yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", ""
                continue

            # If the resize returned dimensions, use them
            local_width, local_height = width, height
            if isinstance(size_info, tuple):
                local_width, local_height = size_info
                progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height}"
            else:
                progress_text = f"Using image: {os.path.basename(random_image)}"
            yield all_videos.copy(), batch_text, progress_text

            # Calculate the seed for this batch item
            current_seed = seed
            if seed == -1:
                current_seed = random.randint(0, 2**32 - 1)
            elif batch_size > 1:
                current_seed = seed + i

            # Process the image. SkyReels I2V models need a direct command
            # with --dit_in_channels 32
            if is_skyreels_i2v:
                env = os.environ.copy()
                env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "")
                env["PYTHONIOENCODING"] = "utf-8"
                model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model

                # Extract LoRA weights and multipliers from lora_params
                num_lora_weights = 4
                lora_weights = lora_params[:num_lora_weights]
                lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2]

                cmd = [
                    sys.executable,
                    "hv_generate_video.py",
                    "--dit", model_path,
                    "--vae", vae,
                    "--text_encoder1", te1,
                    "--text_encoder2", te2,
                    "--prompt", prompt,
                    "--video_size", str(local_height), str(local_width),
                    "--video_length", str(video_length),
                    "--fps", str(fps),
                    "--infer_steps", str(infer_steps),
                    "--save_path", save_path,
                    "--seed", str(current_seed),
                    "--flow_shift", str(flow_shift),
                    "--embedded_cfg_scale", str(embedded_cfg_scale),
                    "--output_type", output_type,
                    "--attn_mode", attn_mode,
                    "--blocks_to_swap", str(block_swap),
                    "--fp8_llm",
                    "--vae_chunk_size", "32",
                    "--vae_spatial_tile_sample_min_size", "128",
                    "--dit_in_channels", "32",  # Crucial for SkyReels I2V
                    "--image_path", resized_image  # Pass the image directly
                ]
                if use_fp8:
                    cmd.append("--fp8")
                if split_uncond:
                    cmd.append("--split_uncond")
                if use_split_attn:
                    cmd.append("--split_attn")
                if exclude_single_blocks:
                    cmd.append("--exclude_single_blocks")
                if negative_prompt:
                    cmd.extend(["--negative_prompt", negative_prompt])
                if guidance_scale is not None:
                    cmd.extend(["--guidance_scale", str(guidance_scale)])

                # Add LoRA weights and multipliers if provided
                valid_loras = []
                for weight, mult in zip(lora_weights, lora_multipliers):
                    if weight and weight != "None":
                        valid_loras.append((os.path.join(lora_folder, weight), mult))
                if valid_loras:
                    weights = [weight for weight, _ in valid_loras]
                    multipliers = [str(mult) for _, mult in valid_loras]
                    cmd.extend(["--lora_weight"] + weights)
                    cmd.extend(["--lora_multiplier"] + multipliers)

                print(f"Running command: {' '.join(cmd)}")

                # Run the process
                p = subprocess.Popen(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    env=env,
                    text=True,
                    encoding='utf-8',
                    errors='replace',
                    bufsize=1
                )
                while True:
                    if stop_event.is_set():
                        p.terminate()
                        p.wait()
                        yield all_videos, "Generation stopped by user.", ""
                        return
                    line = p.stdout.readline()
                    if not line:
                        if p.poll() is not None:
                            break
                        continue
                    print(line, end='')
                    if '|' in line and '%' in line and '[' in line and ']' in line:
                        yield all_videos.copy(), f"Processing video {i+1} (seed: {current_seed})", line.strip()
                p.stdout.close()
                p.wait()

                # Collect the generated video
                save_path_abs = os.path.abspath(save_path)
                if os.path.exists(save_path_abs):
                    all_videos_files = sorted(
                        [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')],
                        key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)),
                        reverse=True
                    )
                    matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v]
                    if matching_videos:
                        video_path = os.path.join(save_path_abs, matching_videos[0])
                        all_videos.append((str(video_path), f"Seed: {current_seed}"))
            else:
                # For non-SkyReels models, use the regular process_single_video function
                num_lora_weights = 4
                lora_weights = lora_params[:num_lora_weights]
                lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2]

                single_video_args = [
                    prompt, local_width, local_height, 1, video_length, fps, infer_steps,
                    current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, embedded_cfg_scale,
                    output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn,
                    lora_folder
                ]
                single_video_args.extend(lora_weights)
                single_video_args.extend(lora_multipliers)
                single_video_args.extend([None, resized_image, None, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8])

                for videos, status, progress in process_single_video(*single_video_args):
                    if videos:
                        all_videos.extend(videos)
                    yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress

            # Clean up the temporary resized image
            try:
                if os.path.exists(resized_image):
                    os.remove(resized_image)
            except OSError:
                pass

            # Clear the CUDA cache between generations
            clear_cuda_cache()
            time.sleep(0.5)

        yield all_videos, "Batch complete", ""
    else:
        # Regular image input. When a SkyReels I2V model is used, take the same
        # direct command approach with --dit_in_channels 32 as in the folder branch.
        if is_skyreels_i2v:
            stop_event.clear()
            all_videos = []
            progress_text = "Starting generation..."
            yield [], "Preparing...", progress_text

            # Extract LoRA parameters
            num_lora_weights = 4
            lora_weights = lora_params[:num_lora_weights]
            lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2]
            extra_args = list(lora_params[num_lora_weights*2:]) if len(lora_params) > num_lora_weights*2 else []
            print(f"Extra args: {extra_args}")

            # The input image path is expected as the first extra argument
            # (skyreels_generate_btn.click passes skyreels_input there)
            image_path = None
            if len(extra_args) > 0 and extra_args[0] is not None:
                image_path = extra_args[0]
                print(f"Image path found in extra_args[0]: {image_path}")

            # Without an image path, SkyReels I2V cannot run
            if not image_path:
                print("No image path found in extra_args[0]")
                print(f"Full lora_params: {lora_params}")
                yield [], "Error: No input image provided", "An input image is required for SkyReels I2V models"
                return

            for i in range(batch_size):
                if stop_event.is_set():
                    yield all_videos, "Generation stopped by user", ""
                    return

                # Calculate the seed for this batch item
                current_seed = seed
                if seed == -1:
                    current_seed = random.randint(0, 2**32 - 1)
                elif batch_size > 1:
                    current_seed = seed + i

                batch_text = f"Generating video {i + 1} of {batch_size}"
                yield all_videos.copy(), batch_text, progress_text

                # Set up the environment
                env = os.environ.copy()
                env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "")
                env["PYTHONIOENCODING"] = "utf-8"
                model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model

                # Build the command with --dit_in_channels 32
                cmd = [
                    sys.executable,
                    "hv_generate_video.py",
                    "--dit", model_path,
                    "--vae", vae,
                    "--text_encoder1", te1,
                    "--text_encoder2", te2,
                    "--prompt", prompt,
                    "--video_size", str(height), str(width),
                    "--video_length", str(video_length),
                    "--fps", str(fps),
                    "--infer_steps", str(infer_steps),
                    "--save_path", save_path,
                    "--seed", str(current_seed),
                    "--flow_shift", str(flow_shift),
                    "--embedded_cfg_scale", str(embedded_cfg_scale),
                    "--output_type", output_type,
                    "--attn_mode", attn_mode,
                    "--blocks_to_swap", str(block_swap),
                    "--fp8_llm",
                    "--vae_chunk_size", "32",
                    "--vae_spatial_tile_sample_min_size", "128",
                    "--dit_in_channels", "32",  # Crucial for SkyReels I2V
                    "--image_path", image_path
                ]
                if use_fp8:
                    cmd.append("--fp8")
                if split_uncond:
                    cmd.append("--split_uncond")
                if use_split_attn:
                    cmd.append("--split_attn")
                if exclude_single_blocks:
                    cmd.append("--exclude_single_blocks")
                if negative_prompt:
                    cmd.extend(["--negative_prompt", negative_prompt])
                if guidance_scale is not None:
                    cmd.extend(["--guidance_scale", str(guidance_scale)])

                # Add LoRA weights and multipliers if provided
                valid_loras = []
                for weight, mult in zip(lora_weights, lora_multipliers):
                    if weight and weight != "None":
                        valid_loras.append((os.path.join(lora_folder, weight), mult))
                if valid_loras:
                    weights = [weight for weight, _ in valid_loras]
                    multipliers = [str(mult) for _, mult in valid_loras]
                    cmd.extend(["--lora_weight"] + weights)
                    cmd.extend(["--lora_multiplier"] + multipliers)

                print(f"Running command: {' '.join(cmd)}")

                # Run the process
                p = subprocess.Popen(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    env=env,
                    text=True,
                    encoding='utf-8',
                    errors='replace',
                    bufsize=1
                )
                while True:
                    if stop_event.is_set():
                        p.terminate()
                        p.wait()
                        yield all_videos, "Generation stopped by user.", ""
                        return
                    line = p.stdout.readline()
                    if not line:
                        if p.poll() is not None:
                            break
                        continue
                    print(line, end='')
                    if '|' in line and '%' in line and '[' in line and ']' in line:
                        yield all_videos.copy(), f"Processing (seed: {current_seed})", line.strip()
                p.stdout.close()
                p.wait()

                # Collect the generated video
                save_path_abs = os.path.abspath(save_path)
                if os.path.exists(save_path_abs):
                    all_videos_files = sorted(
                        [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')],
                        key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)),
                        reverse=True
                    )
                    matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v]
                    if matching_videos:
                        video_path = os.path.join(save_path_abs, matching_videos[0])
                        all_videos.append((str(video_path), f"Seed: {current_seed}"))

                # Clear the CUDA cache between generations
                clear_cuda_cache()
                time.sleep(0.5)

            yield all_videos, "Batch complete", ""
        else:
            # For regular non-SkyReels models, delegate to the original process_batch
            regular_args = [
                prompt, width, height, batch_size, video_length, fps, infer_steps,
                seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, guidance_scale,
                output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn,
                lora_folder
            ]
            yield from process_batch(*(regular_args + list(lora_params)))
def get_dit_models(dit_folder: str) -> List[str]:
    """Get the list of available DiT models in the specified folder."""
    if not os.path.exists(dit_folder):
        return ["mp_rank_00_model_states.pt"]
    models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')]
    models.sort(key=str.lower)
    return models if models else ["mp_rank_00_model_states.pt"]
def update_dit_and_lora_dropdowns(dit_folder: str, lora_folder: str, *current_values) -> List[gr.update]:
    """Update both the DiT and LoRA dropdowns."""
    # Get the model lists
    dit_models = get_dit_models(dit_folder)
    lora_choices = get_lora_options(lora_folder)

    # Validate the current DiT selection against the refreshed list
    dit_value = current_values[0]
    if dit_value not in dit_models:
        dit_value = dit_models[0] if dit_models else None

    weights = current_values[1:5]
    multipliers = current_values[5:9]
    results = [gr.update(choices=dit_models, value=dit_value)]

    # Add the LoRA dropdown updates
    for i in range(4):
        weight = weights[i] if i < len(weights) else "None"
        multiplier = multipliers[i] if i < len(multipliers) else 1.0
        if weight not in lora_choices:
            weight = "None"
        results.extend([
            gr.update(choices=lora_choices, value=weight),
            gr.update(value=multiplier)
        ])
    return results
def extract_video_metadata(video_path: str) -> Dict:
    """Extract generation metadata from a video file's comment tag using ffprobe."""
    cmd = [
        'ffprobe',
        '-v', 'quiet',
        '-print_format', 'json',
        '-show_format',
        video_path
    ]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
        metadata = json.loads(result.stdout.decode('utf-8'))
        if 'format' in metadata and 'tags' in metadata['format']:
            comment = metadata['format']['tags'].get('comment', '{}')
            return json.loads(comment)
        return {}
    except Exception as e:
        print(f"Metadata extraction failed: {str(e)}")
        return {}
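# Illustrative usage (hypothetical path): read back the parameters written by
# add_metadata_to_video below and map them onto a tab's components:
#     params = extract_video_metadata("outputs/clip_1234.mp4")
#     updates = create_parameter_transfer_map(params, target_tab="v2v")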
def create_parameter_transfer_map(metadata: Dict, target_tab: str) -> Dict:
    """Map metadata parameters to the Gradio components of the target tab."""
    mapping = {
        'common': {
            'prompt': ('prompt', 'v2v_prompt'),
            'width': ('width', 'v2v_width'),
            'height': ('height', 'v2v_height'),
            'batch_size': ('batch_size', 'v2v_batch_size'),
            'video_length': ('video_length', 'v2v_video_length'),
            'fps': ('fps', 'v2v_fps'),
            'infer_steps': ('infer_steps', 'v2v_infer_steps'),
            'seed': ('seed', 'v2v_seed'),
            'model': ('model', 'v2v_model'),
            'vae': ('vae', 'v2v_vae'),
            'te1': ('te1', 'v2v_te1'),
            'te2': ('te2', 'v2v_te2'),
            'save_path': ('save_path', 'v2v_save_path'),
            'flow_shift': ('flow_shift', 'v2v_flow_shift'),
            'cfg_scale': ('cfg_scale', 'v2v_cfg_scale'),
            'output_type': ('output_type', 'v2v_output_type'),
            'attn_mode': ('attn_mode', 'v2v_attn_mode'),
            'block_swap': ('block_swap', 'v2v_block_swap')
        },
        'lora': {
            'lora_weights': [(f'lora{i+1}', f'v2v_lora_weights[{i}]') for i in range(4)],
            'lora_multipliers': [(f'lora{i+1}_multiplier', f'v2v_lora_multipliers[{i}]') for i in range(4)]
        }
    }
    results = {}
    for param, value in metadata.items():
        # Handle common parameters: index 0 targets t2v, index 1 targets v2v
        if param in mapping['common']:
            target = mapping['common'][param][0 if target_tab == 't2v' else 1]
            results[target] = value
        # Handle LoRA parameters
        if param == 'lora_weights':
            for i, weight in enumerate(value[:4]):
                target = mapping['lora']['lora_weights'][i][1 if target_tab == 'v2v' else 0]
                results[target] = weight
        if param == 'lora_multipliers':
            for i, mult in enumerate(value[:4]):
                target = mapping['lora']['lora_multipliers'][i][1 if target_tab == 'v2v' else 0]
                results[target] = float(mult)
    return results
def add_metadata_to_video(video_path: str, parameters: dict) -> None:
    """Add generation parameters to the video's metadata using ffmpeg."""
    # Serialize the parameters to a JSON string
    params_json = json.dumps(parameters, indent=2)

    # Write to a temporary file, then replace the original
    temp_path = video_path.replace(".mp4", "_temp.mp4")

    # FFmpeg command to add metadata without re-encoding
    cmd = [
        'ffmpeg',
        '-y',  # overwrite a stale temp file if one was left behind
        '-i', video_path,
        '-metadata', f'comment={params_json}',
        '-codec', 'copy',
        temp_path
    ]
    try:
        # Execute the FFmpeg command
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Replace the original file with the metadata-enhanced version
        os.replace(temp_path, video_path)
    except subprocess.CalledProcessError as e:
        print(f"Failed to add metadata: {e.stderr.decode()}")
        if os.path.exists(temp_path):
            os.remove(temp_path)
    except Exception as e:
        print(f"Error: {str(e)}")
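# Illustrative round trip (hypothetical path and values):
#     add_metadata_to_video("outputs/clip_1234.mp4", {"prompt": "a red fox", "seed": 1234})
#     assert extract_video_metadata("outputs/clip_1234.mp4").get("seed") == 1234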
def count_prompt_tokens(prompt: str) -> int:
    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(prompt)
    return len(tokens)
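# Illustrative usage. cl100k_base is an OpenAI tokenizer, so this is a rough
# UI-side estimate rather than the exact count of the Hunyuan text encoders:
#     n = count_prompt_tokens("a cinematic shot of a red fox in the snow")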
def get_lora_options(lora_folder: str = "lora") -> List[str]:
    if not os.path.exists(lora_folder):
        return ["None"]
    lora_files = [f for f in os.listdir(lora_folder) if f.endswith('.safetensors') or f.endswith('.pt')]
    lora_files.sort(key=str.lower)
    return ["None"] + lora_files
def update_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]:
    new_choices = get_lora_options(lora_folder)
    weights = current_values[:4]
    multipliers = current_values[4:8]
    results = []
    for i in range(4):
        weight = weights[i] if i < len(weights) else "None"
        multiplier = multipliers[i] if i < len(multipliers) else 1.0
        if weight not in new_choices:
            weight = "None"
        results.extend([
            gr.update(choices=new_choices, value=weight),
            gr.update(value=multiplier)
        ])
    return results
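# The list above interleaves updates as [weight_dd1, mult1, weight_dd2, mult2, ...],
# so the refresh handler's outputs must be wired in that order. Sketch with
# hypothetical component names:
#     lora_folder_box.change(
#         fn=update_lora_dropdowns,
#         inputs=[lora_folder_box, *lora_dds, *lora_mults],
#         outputs=[c for pair in zip(lora_dds, lora_mults) for c in pair]
#     )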
def send_to_v2v(evt: gr.SelectData, gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str, int]:
    """Transfer the selected video and prompt to the Video2Video tab."""
    if not gallery or evt.index >= len(gallery):
        return None, "", selected_index.value

    selected_item = gallery[evt.index]
    # Handle the different gallery item formats
    if isinstance(selected_item, dict):
        video_path = selected_item.get("name", selected_item.get("data", None))
    elif isinstance(selected_item, (tuple, list)):
        video_path = selected_item[0]
    else:
        video_path = selected_item

    # Final cleanup for the Gradio Video component
    if isinstance(video_path, tuple):
        video_path = video_path[0]

    # Update the selected index
    selected_index.value = evt.index
    return str(video_path), prompt, evt.index
def send_selected_to_v2v(gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str]:
    """Send the currently selected video to the V2V tab."""
    if not gallery or selected_index.value is None or selected_index.value >= len(gallery):
        return None, ""

    selected_item = gallery[selected_index.value]
    # Handle the different gallery item formats
    if isinstance(selected_item, dict):
        video_path = selected_item.get("name", selected_item.get("data", None))
    elif isinstance(selected_item, (tuple, list)):
        video_path = selected_item[0]
    else:
        video_path = selected_item

    # Final cleanup for the Gradio Video component
    if isinstance(video_path, tuple):
        video_path = video_path[0]
    return str(video_path), prompt
def clear_cuda_cache():
    """Clear the CUDA cache if a GPU is available."""
    import torch  # imported lazily so the UI can start without touching torch
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # Synchronize to make sure the cache release has completed
        torch.cuda.synchronize()
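# Note: empty_cache() returns cached allocator blocks to the driver so the
# generation subprocesses can claim the VRAM; it does not free tensors that are
# still referenced. Illustrative check (assumes torch is installed):
#     import torch
#     if torch.cuda.is_available():
#         print(torch.cuda.memory_reserved())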
def wanx_batch_handler(
    use_random,
    prompt,
    negative_prompt,
    width,
    height,
    video_length,
    fps,
    infer_steps,
    flow_shift,
    guidance_scale,
    seed,
    batch_size,
    input_folder_path,
    task,
    dit_path,
    vae_path,
    t5_path,
    clip_path,
    save_path,
    output_type,
    sample_solver,
    exclude_single_blocks,
    attn_mode,
    block_swap,
    fp8,
    fp8_t5,
    lora_folder,
    *lora_params
):
    """Handle both folder-based batch processing and regular processing for WanX."""
    global stop_event
    if use_random:
        # Random-image-from-folder mode
        stop_event.clear()
        all_videos = []
        progress_text = "Starting generation..."
        yield [], "Preparing...", progress_text

        # Ensure batch_size is treated as an integer
        batch_size = int(batch_size)

        # Process each item in the batch separately
        for i in range(batch_size):
            if stop_event.is_set():
                yield all_videos, "Generation stopped by user", ""
                return

            batch_text = f"Generating video {i + 1} of {batch_size}"
            yield all_videos.copy(), batch_text, progress_text

            # Get a random image from the folder
            random_image, status = get_random_image_from_folder(input_folder_path)
            if random_image is None:
                yield all_videos, f"Error in batch {i+1}: {status}", ""
                continue

            # Resize the image
            resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height)
            if resized_image is None:
                yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", ""
                continue

            # Use the dimensions returned by the resize function
            local_width, local_height = width, height  # Default fallback
            if isinstance(size_info, tuple):
                local_width, local_height = size_info
                progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height} (maintaining aspect ratio)"
            else:
                progress_text = f"Using image: {os.path.basename(random_image)}"
            yield all_videos.copy(), batch_text, progress_text

            # Calculate the seed for this batch item
            current_seed = seed
            if seed == -1:
                current_seed = random.randint(0, 2**32 - 1)
            elif batch_size > 1:
                current_seed = seed + i

            # Extract the LoRA weights and multipliers
            num_lora_weights = 4
            lora_weights = lora_params[:num_lora_weights]
            lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2]

            # Generate the video for this image, one at a time
            for videos, status, progress in wanx_generate_video(
                prompt,
                negative_prompt,
                resized_image,
                local_width,
                local_height,
                video_length,
                fps,
                infer_steps,
                flow_shift,
                guidance_scale,
                current_seed,
                task,
                dit_path,
                vae_path,
                t5_path,
                clip_path,
                save_path,
                output_type,
                sample_solver,
                exclude_single_blocks,
                attn_mode,
                block_swap,
                fp8,
                fp8_t5,
                lora_folder,
                *lora_weights,
                *lora_multipliers
            ):
                if videos:
                    all_videos.extend(videos)
                yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress

            # Clean up the temporary resized image
            try:
                if os.path.exists(resized_image):
                    os.remove(resized_image)
            except OSError:
                pass

            # Clear the CUDA cache between generations
            clear_cuda_cache()
            time.sleep(0.5)

        yield all_videos, "Batch complete", ""
    else:
        # Non-random mode: with batch_size > 1, run multiple generations with the
        # same input image but different seeds
        if int(batch_size) > 1:
            stop_event.clear()
            all_videos = []
            progress_text = "Starting generation..."
            yield [], "Preparing...", progress_text

            # Extract the LoRA weights, multipliers, and input image
            num_lora_weights = 4
            lora_weights = lora_params[:num_lora_weights]
            lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2]
            input_image = lora_params[num_lora_weights*2] if len(lora_params) > num_lora_weights*2 else None

            # Process each batch item
            for i in range(int(batch_size)):
                if stop_event.is_set():
                    yield all_videos, "Generation stopped by user", ""
                    return

                # Calculate the seed for this batch item
                current_seed = seed
                if seed == -1:
                    current_seed = random.randint(0, 2**32 - 1)
                elif batch_size > 1:
                    current_seed = seed + i

                batch_text = f"Generating video {i + 1} of {batch_size}"
                yield all_videos.copy(), batch_text, progress_text

                # Generate a single video with the current seed
                for videos, status, progress in wanx_generate_video(
                    prompt,
                    negative_prompt,
                    input_image,
                    width,
                    height,
                    video_length,
                    fps,
                    infer_steps,
                    flow_shift,
                    guidance_scale,
                    current_seed,
                    task,
                    dit_path,
                    vae_path,
                    t5_path,
                    clip_path,
                    save_path,
                    output_type,
                    sample_solver,
                    exclude_single_blocks,
                    attn_mode,
                    block_swap,
                    fp8,
                    fp8_t5,
                    lora_folder,
                    *lora_weights,
                    *lora_multipliers
                ):
                    if videos:
                        all_videos.extend(videos)
                    yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress

                # Clear the CUDA cache between generations
                clear_cuda_cache()
                time.sleep(0.5)

            yield all_videos, "Batch complete", ""
        else:
            # Single image, single generation: delegate to the existing function
            num_lora_weights = 4
            lora_weights = lora_params[:num_lora_weights]
            lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2]
            input_image = lora_params[num_lora_weights*2] if len(lora_params) > num_lora_weights*2 else None
            yield from wanx_generate_video(
                prompt,
                negative_prompt,
                input_image,
                width,
                height,
                video_length,
                fps,
                infer_steps,
                flow_shift,
                guidance_scale,
                seed,
                task,
                dit_path,
                vae_path,
                t5_path,
                clip_path,
                save_path,
                output_type,
                sample_solver,
                exclude_single_blocks,
                attn_mode,
                block_swap,
                fp8,
                fp8_t5,
                lora_folder,
                *lora_weights,
                *lora_multipliers
            )
| def process_single_video( | |
| prompt: str, | |
| width: int, | |
| height: int, | |
| batch_size: int, | |
| video_length: int, | |
| fps: int, | |
| infer_steps: int, | |
| seed: int, | |
| dit_folder: str, | |
| model: str, | |
| vae: str, | |
| te1: str, | |
| te2: str, | |
| save_path: str, | |
| flow_shift: float, | |
| cfg_scale: float, | |
| output_type: str, | |
| attn_mode: str, | |
| block_swap: int, | |
| exclude_single_blocks: bool, | |
| use_split_attn: bool, | |
| lora_folder: str, | |
| lora1: str = "", | |
| lora2: str = "", | |
| lora3: str = "", | |
| lora4: str = "", | |
| lora1_multiplier: float = 1.0, | |
| lora2_multiplier: float = 1.0, | |
| lora3_multiplier: float = 1.0, | |
| lora4_multiplier: float = 1.0, | |
| video_path: Optional[str] = None, | |
| image_path: Optional[str] = None, | |
| strength: Optional[float] = None, | |
| negative_prompt: Optional[str] = None, | |
| embedded_cfg_scale: Optional[float] = None, | |
| split_uncond: Optional[bool] = None, | |
| guidance_scale: Optional[float] = None, | |
| use_fp8: bool = True | |
| ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
| """Generate a single video with the given parameters""" | |
| global stop_event | |
| if stop_event.is_set(): | |
| yield [], "", "" | |
| return | |
| # Determine if this is a SkyReels model and what type | |
| is_skyreels = "skyreels" in model.lower() | |
| is_skyreels_i2v = is_skyreels and "i2v" in model.lower() | |
| is_skyreels_t2v = is_skyreels and "t2v" in model.lower() | |
| if is_skyreels: | |
| # Force certain parameters for SkyReels | |
| if negative_prompt is None: | |
| negative_prompt = "" | |
| if embedded_cfg_scale is None: | |
| embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels | |
| if split_uncond is None: | |
| split_uncond = True | |
| if guidance_scale is None: | |
| guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided | |
| # Determine the input channels based on model type | |
| if is_skyreels_i2v: | |
| dit_in_channels = 32 # SkyReels I2V uses 32 channels | |
| else: | |
| dit_in_channels = 16 # SkyReels T2V uses 16 channels (same as regular models) | |
| else: | |
| dit_in_channels = 16 # Regular Hunyuan models use 16 channels | |
| embedded_cfg_scale = cfg_scale | |
| if os.path.isabs(model): | |
| model_path = model | |
| else: | |
| model_path = os.path.normpath(os.path.join(dit_folder, model)) | |
| env = os.environ.copy() | |
| env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") | |
| env["PYTHONIOENCODING"] = "utf-8" | |
| env["BATCH_RUN_ID"] = f"{time.time()}" | |
| if seed == -1: | |
| current_seed = random.randint(0, 2**32 - 1) | |
| else: | |
| batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) | |
| if batch_size > 1: # Only modify seed for batch generation | |
| current_seed = (seed + batch_id * 100003) % (2**32) | |
| else: | |
| current_seed = seed | |
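| # Illustrative example (assumed values): with seed=1000 and a BATCH_RUN_ID whose | |
| # fractional part is 123456, batch_id=123456 and | |
| # current_seed = (1000 + 123456 * 100003) % 2**32 = 3756036776, | |
| # so repeated batch items derive different seeds from the same base seed. | |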
| clear_cuda_cache() | |
| command = [ | |
| sys.executable, | |
| "hv_generate_video.py", | |
| "--dit", model_path, | |
| "--vae", vae, | |
| "--text_encoder1", te1, | |
| "--text_encoder2", te2, | |
| "--prompt", prompt, | |
| "--video_size", str(height), str(width), | |
| "--video_length", str(video_length), | |
| "--fps", str(fps), | |
| "--infer_steps", str(infer_steps), | |
| "--save_path", save_path, | |
| "--seed", str(current_seed), | |
| "--flow_shift", str(flow_shift), | |
| "--embedded_cfg_scale", str(cfg_scale), | |
| "--output_type", output_type, | |
| "--attn_mode", attn_mode, | |
| "--blocks_to_swap", str(block_swap), | |
| "--fp8_llm", | |
| "--vae_chunk_size", "32", | |
| "--vae_spatial_tile_sample_min_size", "128" | |
| ] | |
| if use_fp8: | |
| command.append("--fp8") | |
| # Add negative prompt and embedded cfg scale for SkyReels | |
| if is_skyreels: | |
| command.extend(["--dit_in_channels", str(dit_in_channels)]) | |
| command.extend(["--guidance_scale", str(guidance_scale)]) | |
| if negative_prompt: | |
| command.extend(["--negative_prompt", negative_prompt]) | |
| if split_uncond: | |
| command.append("--split_uncond") | |
| # Add LoRA weights and multipliers if provided | |
| valid_loras = [] | |
| for weight, mult in zip([lora1, lora2, lora3, lora4], | |
| [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): | |
| if weight and weight != "None": | |
| valid_loras.append((os.path.join(lora_folder, weight), mult)) | |
| if valid_loras: | |
| weights = [weight for weight, _ in valid_loras] | |
| multipliers = [str(mult) for _, mult in valid_loras] | |
| command.extend(["--lora_weight"] + weights) | |
| command.extend(["--lora_multiplier"] + multipliers) | |
| if exclude_single_blocks: | |
| command.append("--exclude_single_blocks") | |
| if use_split_attn: | |
| command.append("--split_attn") | |
| # Handle input paths | |
| if video_path: | |
| command.extend(["--video_path", video_path]) | |
| if strength is not None: | |
| command.extend(["--strength", str(strength)]) | |
| elif image_path: | |
| command.extend(["--image_path", image_path]) | |
| # Only add strength parameter for non-SkyReels I2V models | |
| # SkyReels I2V doesn't use strength parameter for image-to-video generation | |
| if strength is not None and not is_skyreels_i2v: | |
| command.extend(["--strength", str(strength)]) | |
| print(f"{command}") | |
| p = subprocess.Popen( | |
| command, | |
| stdout=subprocess.PIPE, | |
| stderr=subprocess.STDOUT, | |
| env=env, | |
| text=True, | |
| encoding='utf-8', | |
| errors='replace', | |
| bufsize=1 | |
| ) | |
| videos = [] | |
| while True: | |
| if stop_event.is_set(): | |
| p.terminate() | |
| p.wait() | |
| yield [], "", "Generation stopped by user." | |
| return | |
| line = p.stdout.readline() | |
| if not line: | |
| if p.poll() is not None: | |
| break | |
| continue | |
| print(line, end='') | |
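| # Heuristic: tqdm progress lines look like | |
| #   30%|███       | 9/30 [00:12<00:28,  1.35s/it] | |
| # so the presence of '|', '%', '[' and ']' together marks a progress update. | |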
| if '|' in line and '%' in line and '[' in line and ']' in line: | |
| yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() | |
| p.stdout.close() | |
| p.wait() | |
| clear_cuda_cache() | |
| time.sleep(0.5) | |
| # Collect generated video | |
| save_path_abs = os.path.abspath(save_path) | |
| if os.path.exists(save_path_abs): | |
| all_videos = sorted( | |
| [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], | |
| key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), | |
| reverse=True | |
| ) | |
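| # Assumption: hv_generate_video.py embeds the seed in the output filename | |
| # (e.g. "20250101-120000_3756036776.mp4"), so "_{seed}" identifies this run's file. | |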
| matching_videos = [v for v in all_videos if f"_{current_seed}" in v] | |
| if matching_videos: | |
| # Use a distinct name so the input video_path parameter is not shadowed | |
| generated_video_path = os.path.join(save_path_abs, matching_videos[0]) | |
| # Collect parameters for metadata | |
| parameters = { | |
| "prompt": prompt, | |
| "width": width, | |
| "height": height, | |
| "video_length": video_length, | |
| "fps": fps, | |
| "infer_steps": infer_steps, | |
| "seed": current_seed, | |
| "model": model, | |
| "vae": vae, | |
| "te1": te1, | |
| "te2": te2, | |
| "save_path": save_path, | |
| "flow_shift": flow_shift, | |
| "cfg_scale": cfg_scale, | |
| "output_type": output_type, | |
| "attn_mode": attn_mode, | |
| "block_swap": block_swap, | |
| "lora_weights": [lora1, lora2, lora3, lora4], | |
| "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], | |
| "input_video": video_path if video_path else None, | |
| "input_image": image_path if image_path else None, | |
| "strength": strength, | |
| "negative_prompt": negative_prompt if is_skyreels else None, | |
| "embedded_cfg_scale": embedded_cfg_scale if is_skyreels else None | |
| } | |
| add_metadata_to_video(generated_video_path, parameters) | |
| videos.append((str(generated_video_path), f"Seed: {current_seed}")) | |
| yield videos, f"Completed (seed: {current_seed})", "" | |
| def process_batch( | |
| prompt: str, | |
| width: int, | |
| height: int, | |
| batch_size: int, | |
| video_length: int, | |
| fps: int, | |
| infer_steps: int, | |
| seed: int, | |
| dit_folder: str, | |
| model: str, | |
| vae: str, | |
| te1: str, | |
| te2: str, | |
| save_path: str, | |
| flow_shift: float, | |
| cfg_scale: float, | |
| output_type: str, | |
| attn_mode: str, | |
| block_swap: int, | |
| exclude_single_blocks: bool, | |
| use_split_attn: bool, | |
| lora_folder: str, | |
| *args | |
| ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
| """Process a batch of videos using Gradio's queue""" | |
| global stop_event | |
| stop_event.clear() | |
| all_videos = [] | |
| progress_text = "Starting generation..." | |
| yield [], "Preparing...", progress_text | |
| # Extract additional arguments | |
| num_lora_weights = 4 | |
| lora_weights = args[:num_lora_weights] | |
| lora_multipliers = args[num_lora_weights:num_lora_weights*2] | |
| extra_args = args[num_lora_weights*2:] | |
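| # Expected extra_args layout (matching the parsing below): | |
| #   [input_path, strength, negative_prompt, guidance_scale, split_uncond, ..., use_fp8] | |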
| # Determine if this is a SkyReels model and what type | |
| is_skyreels = "skyreels" in model.lower() | |
| is_skyreels_i2v = is_skyreels and "i2v" in model.lower() | |
| is_skyreels_t2v = is_skyreels and "t2v" in model.lower() | |
| # Handle input paths and additional parameters | |
| input_path = extra_args[0] if extra_args else None | |
| strength = float(extra_args[1]) if len(extra_args) > 1 else None | |
| # Get use_fp8 flag (it should be the last parameter) | |
| use_fp8 = bool(extra_args[-1]) if len(extra_args) >= 3 else True | |
| # Get SkyReels specific parameters if applicable | |
| if is_skyreels: | |
| # Always set embedded_cfg_scale to 1.0 for SkyReels models | |
| embedded_cfg_scale = 1.0 | |
| negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else "" | |
| # Use cfg_scale for guidance_scale parameter | |
| guidance_scale = float(extra_args[3]) if len(extra_args) > 3 and extra_args[3] is not None else cfg_scale | |
| split_uncond = bool(extra_args[4]) if len(extra_args) > 4 else False | |
| else: | |
| negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else None | |
| guidance_scale = cfg_scale | |
| embedded_cfg_scale = cfg_scale | |
| split_uncond = bool(extra_args[4]) if len(extra_args) > 4 else None | |
| for i in range(batch_size): | |
| if stop_event.is_set(): | |
| break | |
| batch_text = f"Generating video {i + 1} of {batch_size}" | |
| yield all_videos.copy(), batch_text, progress_text | |
| # Handle different input types | |
| video_path = None | |
| image_path = None | |
| if input_path: | |
| # Check if it's an image file (common image extensions) | |
| is_image = False | |
| lower_path = input_path.lower() | |
| image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') | |
| is_image = any(lower_path.endswith(ext) for ext in image_extensions) | |
| # Only use image_path for SkyReels I2V models and actual image files | |
| if is_skyreels_i2v and is_image: | |
| image_path = input_path | |
| else: | |
| video_path = input_path | |
| # Prepare arguments for process_single_video | |
| single_video_args = [ | |
| prompt, width, height, batch_size, video_length, fps, infer_steps, | |
| seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, | |
| output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, | |
| lora_folder | |
| ] | |
| single_video_args.extend(lora_weights) | |
| single_video_args.extend(lora_multipliers) | |
| single_video_args.extend([video_path, image_path, strength, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) | |
| for videos, status, progress in process_single_video(*single_video_args): | |
| if videos: | |
| all_videos.extend(videos) | |
| yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress | |
| yield all_videos, "Batch complete", "" | |
| def update_wanx_image_dimensions(image): | |
| """Update dimensions from uploaded image""" | |
| if image is None: | |
| return "", gr.update(value=832), gr.update(value=480) | |
| with Image.open(image) as img: | |
| w, h = img.size | |
| w = (w // 32) * 32 | |
| h = (h // 32) * 32 | |
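| # e.g. a 1000x700 upload becomes 992x672 (rounded down to multiples of 32) | |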
| return f"{w}x{h}", w, h | |
| def calculate_wanx_width(height, original_dims): | |
| """Calculate width based on height maintaining aspect ratio""" | |
| if not original_dims: | |
| return gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| aspect_ratio = orig_w / orig_h | |
| new_width = math.floor((height * aspect_ratio) / 32) * 32 | |
| return gr.update(value=new_width) | |
| def calculate_wanx_height(width, original_dims): | |
| """Calculate height based on width maintaining aspect ratio""" | |
| if not original_dims: | |
| return gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| aspect_ratio = orig_w / orig_h | |
| new_height = math.floor((width / aspect_ratio) / 32) * 32 | |
| return gr.update(value=new_height) | |
| def update_wanx_from_scale(scale, original_dims): | |
| """Update dimensions based on scale percentage""" | |
| if not original_dims: | |
| return gr.update(), gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| new_w = math.floor((orig_w * scale / 100) / 32) * 32 | |
| new_h = math.floor((orig_h * scale / 100) / 32) * 32 | |
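| # e.g. 832x480 at scale=50 -> 416x224 (240 rounds down to the nearest multiple of 32) | |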
| return gr.update(value=new_w), gr.update(value=new_h) | |
| def recommend_wanx_flow_shift(width, height): | |
| """Get recommended flow shift value based on dimensions""" | |
| recommended_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 | |
| return gr.update(value=recommended_shift) | |
| def handle_wanx_gallery_select(evt: gr.SelectData, gallery) -> tuple: | |
| """Track selected index and video path when gallery item is clicked""" | |
| if gallery is None: | |
| return None, None | |
| if evt.index >= len(gallery): | |
| return None, None | |
| selected_item = gallery[evt.index] | |
| video_path = None | |
| # Extract the video path based on the item type | |
| if isinstance(selected_item, tuple): | |
| video_path = selected_item[0] | |
| elif isinstance(selected_item, dict): | |
| video_path = selected_item.get("name", selected_item.get("data", None)) | |
| else: | |
| video_path = selected_item | |
| return evt.index, video_path | |
| def wanx_generate_video( | |
| prompt, | |
| negative_prompt, | |
| input_image, | |
| width, | |
| height, | |
| video_length, | |
| fps, | |
| infer_steps, | |
| flow_shift, | |
| guidance_scale, | |
| seed, | |
| task, | |
| dit_path, | |
| vae_path, | |
| t5_path, | |
| clip_path, | |
| save_path, | |
| output_type, | |
| sample_solver, | |
| exclude_single_blocks, | |
| attn_mode, | |
| block_swap, | |
| fp8, | |
| fp8_t5, | |
| lora_folder, | |
| lora1="None", | |
| lora2="None", | |
| lora3="None", | |
| lora4="None", | |
| lora1_multiplier=1.0, | |
| lora2_multiplier=1.0, | |
| lora3_multiplier=1.0, | |
| lora4_multiplier=1.0 | |
| ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
| """Generate video with WanX model (supports both i2v and t2v)""" | |
| global stop_event | |
| if stop_event.is_set(): | |
| yield [], "", "" | |
| return | |
| if seed == -1: | |
| current_seed = random.randint(0, 2**32 - 1) | |
| else: | |
| current_seed = seed | |
| # Check if we need input image (required for i2v, not for t2v) | |
| if "i2v" in task and not input_image: | |
| yield [], "Error: No input image provided", "Please provide an input image for image-to-video generation" | |
| return | |
| # Prepare environment | |
| env = os.environ.copy() | |
| env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") | |
| env["PYTHONIOENCODING"] = "utf-8" | |
| clear_cuda_cache() | |
| command = [ | |
| sys.executable, | |
| "wan_generate_video.py", | |
| "--task", task, | |
| "--prompt", prompt, | |
| "--video_size", str(height), str(width), | |
| "--video_length", str(video_length), | |
| "--fps", str(fps), | |
| "--infer_steps", str(infer_steps), | |
| "--save_path", save_path, | |
| "--seed", str(current_seed), | |
| "--flow_shift", str(flow_shift), | |
| "--guidance_scale", str(guidance_scale), | |
| "--output_type", output_type, | |
| "--attn_mode", attn_mode, | |
| "--blocks_to_swap", str(block_swap), | |
| "--dit", dit_path, | |
| "--vae", vae_path, | |
| "--t5", t5_path, | |
| "--sample_solver", sample_solver | |
| ] | |
| # Add image path only for i2v task and if input image is provided | |
| if "i2v" in task and input_image: | |
| command.extend(["--image_path", input_image]) | |
| command.extend(["--clip", clip_path]) # CLIP is only needed for i2v | |
| if negative_prompt: | |
| command.extend(["--negative_prompt", negative_prompt]) | |
| if fp8: | |
| command.append("--fp8") | |
| if fp8_t5: | |
| command.append("--fp8_t5") | |
| if exclude_single_blocks: | |
| command.append("--exclude_single_blocks") | |
| # Add LoRA weights and multipliers if provided | |
| valid_loras = [] | |
| for weight, mult in zip([lora1, lora2, lora3, lora4], | |
| [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): | |
| if weight and weight != "None": | |
| valid_loras.append((os.path.join(lora_folder, weight), mult)) | |
| if valid_loras: | |
| weights = [weight for weight, _ in valid_loras] | |
| multipliers = [str(mult) for _, mult in valid_loras] | |
| command.extend(["--lora_weight"] + weights) | |
| command.extend(["--lora_multiplier"] + multipliers) | |
| print(f"Running: {' '.join(command)}") | |
| p = subprocess.Popen( | |
| command, | |
| stdout=subprocess.PIPE, | |
| stderr=subprocess.STDOUT, | |
| env=env, | |
| text=True, | |
| encoding='utf-8', | |
| errors='replace', | |
| bufsize=1 | |
| ) | |
| videos = [] | |
| while True: | |
| if stop_event.is_set(): | |
| p.terminate() | |
| p.wait() | |
| yield [], "", "Generation stopped by user." | |
| return | |
| line = p.stdout.readline() | |
| if not line: | |
| if p.poll() is not None: | |
| break | |
| continue | |
| print(line, end='') | |
| if '|' in line and '%' in line and '[' in line and ']' in line: | |
| yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() | |
| p.stdout.close() | |
| p.wait() | |
| clear_cuda_cache() | |
| time.sleep(0.5) | |
| # Collect generated video | |
| save_path_abs = os.path.abspath(save_path) | |
| if os.path.exists(save_path_abs): | |
| all_videos = sorted( | |
| [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], | |
| key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), | |
| reverse=True | |
| ) | |
| matching_videos = [v for v in all_videos if f"_{current_seed}" in v] | |
| if matching_videos: | |
| video_path = os.path.join(save_path_abs, matching_videos[0]) | |
| # Collect parameters for metadata | |
| parameters = { | |
| "prompt": prompt, | |
| "width": width, | |
| "height": height, | |
| "video_length": video_length, | |
| "fps": fps, | |
| "infer_steps": infer_steps, | |
| "seed": current_seed, | |
| "task": task, | |
| "flow_shift": flow_shift, | |
| "guidance_scale": guidance_scale, | |
| "output_type": output_type, | |
| "attn_mode": attn_mode, | |
| "block_swap": block_swap, | |
| "input_image": input_image if "i2v" in task else None | |
| } | |
| add_metadata_to_video(video_path, parameters) | |
| videos.append((str(video_path), f"Seed: {current_seed}")) | |
| yield videos, f"Completed (seed: {current_seed})", "" | |
| def send_wanx_to_v2v( | |
| gallery: list, | |
| prompt: str, | |
| selected_index: int, | |
| width: int, | |
| height: int, | |
| video_length: int, | |
| fps: int, | |
| infer_steps: int, | |
| seed: int, | |
| flow_shift: float, | |
| guidance_scale: float, | |
| negative_prompt: str | |
| ) -> Tuple: | |
| """Send the selected WanX video to Video2Video tab""" | |
| if gallery is None or not gallery: | |
| return (None, "", width, height, video_length, fps, infer_steps, seed, | |
| flow_shift, guidance_scale, negative_prompt) | |
| # If no selection made but we have videos, use the first one | |
| if selected_index is None and len(gallery) > 0: | |
| selected_index = 0 | |
| if selected_index is None or selected_index >= len(gallery): | |
| return (None, "", width, height, video_length, fps, infer_steps, seed, | |
| flow_shift, guidance_scale, negative_prompt) | |
| selected_item = gallery[selected_index] | |
| # Handle different gallery item formats | |
| if isinstance(selected_item, tuple): | |
| video_path = selected_item[0] | |
| elif isinstance(selected_item, dict): | |
| video_path = selected_item.get("name", selected_item.get("data", None)) | |
| else: | |
| video_path = selected_item | |
| # Clean up path for Video component | |
| if isinstance(video_path, tuple): | |
| video_path = video_path[0] | |
| # Make sure it's a string | |
| video_path = str(video_path) | |
| return (video_path, prompt, width, height, video_length, fps, infer_steps, seed, | |
| flow_shift, guidance_scale, negative_prompt) | |
| def wanx_generate_video_batch( | |
| prompt, | |
| negative_prompt, | |
| width, | |
| height, | |
| video_length, | |
| fps, | |
| infer_steps, | |
| flow_shift, | |
| guidance_scale, | |
| seed, | |
| task, | |
| dit_path, | |
| vae_path, | |
| t5_path, | |
| clip_path, | |
| save_path, | |
| output_type, | |
| sample_solver, | |
| exclude_single_blocks, | |
| attn_mode, | |
| block_swap, | |
| fp8, | |
| fp8_t5, | |
| lora_folder, | |
| lora1="None", | |
| lora2="None", | |
| lora3="None", | |
| lora4="None", | |
| lora1_multiplier=1.0, | |
| lora2_multiplier=1.0, | |
| lora3_multiplier=1.0, | |
| lora4_multiplier=1.0, | |
| batch_size=1, | |
| input_image=None # Make input_image optional and place it at the end | |
| ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
| """Generate videos with WanX with support for batches""" | |
| global stop_event | |
| stop_event.clear() | |
| all_videos = [] | |
| progress_text = "Starting generation..." | |
| yield [], "Preparing...", progress_text | |
| # Process each item in the batch | |
| for i in range(batch_size): | |
| if stop_event.is_set(): | |
| yield all_videos, "Generation stopped by user", "" | |
| return | |
| # Calculate seed for this batch item | |
| current_seed = seed | |
| if seed == -1: | |
| current_seed = random.randint(0, 2**32 - 1) | |
| elif batch_size > 1: | |
| current_seed = seed + i | |
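| # e.g. seed=500 with batch_size=3 yields seeds 500, 501, 502 | |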
| batch_text = f"Generating video {i + 1} of {batch_size}" | |
| yield all_videos.copy(), batch_text, progress_text | |
| # Generate a single video using the existing function | |
| for videos, status, progress in wanx_generate_video( | |
| prompt, negative_prompt, input_image, width, height, | |
| video_length, fps, infer_steps, flow_shift, guidance_scale, | |
| current_seed, task, dit_path, vae_path, t5_path, clip_path, | |
| save_path, output_type, sample_solver, exclude_single_blocks, | |
| attn_mode, block_swap, fp8, fp8_t5, | |
| lora_folder, | |
| lora1, | |
| lora2, | |
| lora3, | |
| lora4, | |
| lora1_multiplier, | |
| lora2_multiplier, | |
| lora3_multiplier, | |
| lora4_multiplier | |
| ): | |
| if videos: | |
| all_videos.extend(videos) | |
| yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress | |
| yield all_videos, "Batch complete", "" | |
| def update_wanx_t2v_dimensions(size): | |
| """Update width and height based on selected size""" | |
| width, height = map(int, size.split('*')) | |
| return gr.update(value=width), gr.update(value=height) | |
| def handle_wanx_t2v_gallery_select(evt: gr.SelectData) -> int: | |
| """Track selected index when gallery item is clicked""" | |
| return evt.index | |
| def send_wanx_t2v_to_v2v( | |
| gallery, prompt, selected_index, width, height, video_length, | |
| fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt | |
| ) -> Tuple: | |
| """Send the selected WanX T2V video to Video2Video tab""" | |
| if not gallery or selected_index is None or selected_index >= len(gallery): | |
| return (None, "", width, height, video_length, fps, infer_steps, seed, | |
| flow_shift, guidance_scale, negative_prompt) | |
| selected_item = gallery[selected_index] | |
| if isinstance(selected_item, dict): | |
| video_path = selected_item.get("name", selected_item.get("data", None)) | |
| elif isinstance(selected_item, (tuple, list)): | |
| video_path = selected_item[0] | |
| else: | |
| video_path = selected_item | |
| if isinstance(video_path, tuple): | |
| video_path = video_path[0] | |
| return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, | |
| flow_shift, guidance_scale, negative_prompt) | |
| def prepare_for_batch_extension(input_img, base_video, batch_size): | |
| """Prepare inputs for batch video extension""" | |
| if input_img is None: | |
| return None, None, batch_size, "No input image found", "" | |
| if base_video is None: | |
| return input_img, None, batch_size, "No base video selected for extension", "" | |
| return input_img, base_video, batch_size, "Preparing batch extension...", f"Will create {batch_size} variations of extended video" | |
| def concat_batch_videos(base_video_path, generated_videos, save_path, original_video_path=None): | |
| """Concatenate multiple generated videos with the base video""" | |
| if not base_video_path: | |
| return [], "No base video provided" | |
| if not generated_videos or len(generated_videos) == 0: | |
| return [], "No new videos generated" | |
| # Create output directory if it doesn't exist | |
| os.makedirs(save_path, exist_ok=True) | |
| # Track all extended videos | |
| extended_videos = [] | |
| # For each generated video, create an extended version | |
| for i, video_item in enumerate(generated_videos): | |
| try: | |
| # Extract video path from gallery item | |
| if isinstance(video_item, tuple): | |
| new_video_path = video_item[0] | |
| seed_info = video_item[1] if len(video_item) > 1 else "" | |
| elif isinstance(video_item, dict): | |
| new_video_path = video_item.get("name", video_item.get("data", None)) | |
| seed_info = "" | |
| else: | |
| new_video_path = video_item | |
| seed_info = "" | |
| if not new_video_path or not os.path.exists(new_video_path): | |
| print(f"Skipping missing video: {new_video_path}") | |
| continue | |
| # Create unique output filename | |
| timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") | |
| # Extract seed from seed_info if available | |
| seed_match = re.search(r"Seed: (\d+)", seed_info) | |
| seed_part = f"_seed{seed_match.group(1)}" if seed_match else f"_{i}" | |
| output_filename = f"extended_{timestamp}{seed_part}_{Path(base_video_path).stem}.mp4" | |
| output_path = os.path.join(save_path, output_filename) | |
| # Create a temporary file list for ffmpeg | |
| list_file = os.path.join(save_path, f"temp_list_{i}.txt") | |
| with open(list_file, "w") as f: | |
| f.write(f"file '{os.path.abspath(base_video_path)}'\n") | |
| f.write(f"file '{os.path.abspath(new_video_path)}'\n") | |
| # Run ffmpeg concatenation | |
| command = [ | |
| "ffmpeg", | |
| "-f", "concat", | |
| "-safe", "0", | |
| "-i", list_file, | |
| "-c", "copy", | |
| output_path | |
| ] | |
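| # "-c copy" concatenates without re-encoding, which assumes both clips share | |
| # codec, resolution and timebase (true for clips from the same generation settings). | |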
| subprocess.run(command, check=True, capture_output=True) | |
| # Clean up temporary file | |
| if os.path.exists(list_file): | |
| os.remove(list_file) | |
| # Add to extended videos list if successful | |
| if os.path.exists(output_path): | |
| seed_display = f"Extended {seed_info}" if seed_info else f"Extended video #{i+1}" | |
| extended_videos.append((output_path, seed_display)) | |
| except Exception as e: | |
| print(f"Error processing video {i}: {str(e)}") | |
| if not extended_videos: | |
| return [], "Failed to create any extended videos" | |
| return extended_videos, f"Successfully created {len(extended_videos)} extended videos" | |
| def handle_extend_generation(base_video_path: str, new_videos: list, save_path: str, current_gallery: list) -> tuple: | |
| """Combine generated video with base video and update gallery""" | |
| if not base_video_path: | |
| return current_gallery, "Extend failed: No base video provided" | |
| if not new_videos: | |
| return current_gallery, "Extend failed: No new video generated" | |
| # Ensure save path exists | |
| os.makedirs(save_path, exist_ok=True) | |
| # Get the first video from new_videos (gallery item) | |
| new_video_path = new_videos[0][0] if isinstance(new_videos[0], tuple) else new_videos[0] | |
| # Create a unique output filename | |
| timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") | |
| output_filename = f"extended_{timestamp}_{Path(base_video_path).stem}.mp4" | |
| output_path = str(Path(save_path) / output_filename) | |
| try: | |
| # Concatenate the videos using ffmpeg | |
| ( | |
| ffmpeg | |
| .input(base_video_path) | |
| .concat( | |
| ffmpeg.input(new_video_path) | |
| ) | |
| .output(output_path) | |
| .run(overwrite_output=True, quiet=True) | |
| ) | |
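| # Unlike the concat demuxer used in concat_batch_videos, the concat filter | |
| # re-encodes the output, so it also tolerates clips with differing codec parameters. | |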
| # Create a new gallery entry with the combined video | |
| updated_gallery = [(output_path, f"Extended video: {Path(output_path).stem}")] | |
| return updated_gallery, f"Successfully extended video to {Path(output_path).name}" | |
| except Exception as e: | |
| print(f"Error extending video: {str(e)}") | |
| return current_gallery, f"Failed to extend video: {str(e)}" | |
| # UI setup | |
| with gr.Blocks( | |
| theme=themes.Default( | |
| primary_hue=colors.Color( | |
| name="custom", | |
| c50="#E6F0FF", | |
| c100="#CCE0FF", | |
| c200="#99C1FF", | |
| c300="#66A3FF", | |
| c400="#3384FF", | |
| c500="#0060df", # This is your main color | |
| c600="#0052C2", | |
| c700="#003D91", | |
| c800="#002961", | |
| c900="#001430", | |
| c950="#000A18" | |
| ) | |
| ), | |
| css=""" | |
| .gallery-item:first-child { border: 2px solid #4CAF50 !important; } | |
| .gallery-item:first-child:hover { border-color: #45a049 !important; } | |
| .green-btn { | |
| background: linear-gradient(to bottom right, #2ecc71, #27ae60) !important; | |
| color: white !important; | |
| border: none !important; | |
| } | |
| .green-btn:hover { | |
| background: linear-gradient(to bottom right, #27ae60, #219651) !important; | |
| } | |
| .refresh-btn { | |
| max-width: 40px !important; | |
| min-width: 40px !important; | |
| height: 40px !important; | |
| border-radius: 50% !important; | |
| padding: 0 !important; | |
| display: flex !important; | |
| align-items: center !important; | |
| justify-content: center !important; | |
| } | |
| """, | |
| ) as demo: | |
| # Add state for tracking selected video indices in both tabs | |
| selected_index = gr.State(value=None) # For Text to Video | |
| v2v_selected_index = gr.State(value=None) # For Video to Video | |
| params_state = gr.State() #New addition | |
| i2v_selected_index = gr.State(value=None) | |
| skyreels_selected_index = gr.State(value=None) | |
| wanx_i2v_selected_index = gr.State(value=None) | |
| extended_videos = gr.State(value=[]) | |
| wanx_base_video = gr.State(value=None) | |
| wanx_sharpest_frame_number = gr.State(value=None) | |
| wanx_sharpest_frame_path = gr.State(value=None) | |
| wanx_trimmed_video_path = gr.State(value=None) | |
| demo.load(None, None, None, js=r""" | |
| () => { | |
| document.title = 'H1111'; | |
| function updateTitle(text) { | |
| if (text && text.trim()) { | |
| const progressMatch = text.match(/(\d+)%.*\[.*<(\d+:\d+),/); | |
| if (progressMatch) { | |
| const percentage = progressMatch[1]; | |
| const timeRemaining = progressMatch[2]; | |
| document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; | |
| } | |
| } | |
| } | |
| setTimeout(() => { | |
| const progressElements = document.querySelectorAll('textarea.scroll-hide'); | |
| progressElements.forEach(element => { | |
| if (element) { | |
| new MutationObserver(() => { | |
| updateTitle(element.value); | |
| }).observe(element, { | |
| attributes: true, | |
| childList: true, | |
| characterData: true | |
| }); | |
| } | |
| }); | |
| }, 1000); | |
| } | |
| """) | |
| with gr.Tabs() as tabs: | |
| # Text to Video Tab | |
| with gr.Tab(id=1, label="Hunyuan-t2v"): | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) | |
| with gr.Column(scale=1): | |
| token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| t2v_width = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Width") | |
| t2v_height = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Height") | |
| video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) | |
| fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) | |
| infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
| flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) | |
| cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg Scale", value=7.0) | |
| with gr.Column(): | |
| with gr.Row(): | |
| video_output = gr.Gallery( | |
| label="Generated Videos (Click to select)", | |
| columns=[2], | |
| rows=[2], | |
| object_fit="contain", | |
| height="auto", | |
| show_label=True, | |
| elem_id="gallery", | |
| allow_preview=True, | |
| preview=True | |
| ) | |
| with gr.Row(): | |
| send_t2v_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
| with gr.Row(): | |
| refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| lora_weights = [] | |
| lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
| model = gr.Dropdown( | |
| label="DiT Model", | |
| choices=get_dit_models("hunyuan"), | |
| value="mp_rank_00_model_states.pt", | |
| allow_custom_value=True, | |
| interactive=True | |
| ) | |
| vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
| te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
| te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
| save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
| use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
| attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
| #Image to Video Tab | |
| with gr.Tab(label="Hunyuan-i2v") as i2v_tab: | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| i2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) | |
| with gr.Column(scale=1): | |
| i2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| i2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| i2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| i2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| i2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| i2v_stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| i2v_input = gr.Image(label="Input Image", type="filepath") | |
| i2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") | |
| # Scale slider as percentage | |
| scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
| original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
| # Width and height inputs | |
| with gr.Row(): | |
| width = gr.Number(label="New Width", value=544, step=16) | |
| calc_height_btn = gr.Button("→") | |
| calc_width_btn = gr.Button("←") | |
| height = gr.Number(label="New Height", value=544, step=16) | |
| i2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) | |
| i2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) | |
| i2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
| i2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) | |
| i2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) | |
| with gr.Column(): | |
| i2v_output = gr.Gallery( | |
| label="Generated Videos (Click to select)", | |
| columns=[2], | |
| rows=[2], | |
| object_fit="contain", | |
| height="auto", | |
| show_label=True, | |
| elem_id="gallery", | |
| allow_preview=True, | |
| preview=True | |
| ) | |
| i2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
| # Add LoRA section for Image2Video | |
| i2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| i2v_lora_weights = [] | |
| i2v_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| i2v_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| i2v_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| i2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| i2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| i2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
| i2v_model = gr.Dropdown( | |
| label="DiT Model", | |
| choices=get_dit_models("hunyuan"), | |
| value="mp_rank_00_model_states.pt", | |
| allow_custom_value=True, | |
| interactive=True | |
| ) | |
| i2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
| i2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
| i2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
| i2v_save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| i2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| i2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| i2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
| i2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
| i2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| i2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
| # Video to Video Tab | |
| with gr.Tab(id=2, label="Hunyuan-v2v") as v2v_tab: | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| v2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) | |
| v2v_negative_prompt = gr.Textbox( | |
| scale=3, | |
| label="Negative Prompt (for SkyReels models)", | |
| value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", | |
| lines=3 | |
| ) | |
| with gr.Column(scale=1): | |
| v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| v2v_stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| v2v_input = gr.Video(label="Input Video", format="mp4") | |
| v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") | |
| v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
| v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
| # Width and Height Inputs | |
| with gr.Row(): | |
| v2v_width = gr.Number(label="New Width", value=544, step=16) | |
| v2v_calc_height_btn = gr.Button("→") | |
| v2v_calc_width_btn = gr.Button("←") | |
| v2v_height = gr.Number(label="New Height", value=544, step=16) | |
| v2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) | |
| v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) | |
| v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
| v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) | |
| v2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) | |
| with gr.Column(): | |
| v2v_output = gr.Gallery( | |
| label="Generated Videos", | |
| columns=[1], | |
| rows=[1], | |
| object_fit="contain", | |
| height="auto" | |
| ) | |
| v2v_send_to_input_btn = gr.Button("Send Selected to Input") # New button | |
| v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| v2v_lora_weights = [] | |
| v2v_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| v2v_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| v2v_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| v2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
| v2v_model = gr.Dropdown( | |
| label="DiT Model", | |
| choices=get_dit_models("hunyuan"), | |
| value="mp_rank_00_model_states.pt", | |
| allow_custom_value=True, | |
| interactive=True | |
| ) | |
| v2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
| v2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
| v2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
| v2v_save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| v2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
| v2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
| v2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| v2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
| v2v_split_uncond = gr.Checkbox(label="Split Unconditional (for SkyReels)", value=True) | |
| ### SKYREELS | |
| with gr.Tab(label="SkyReels-i2v") as skyreels_tab: | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| skyreels_prompt = gr.Textbox( | |
| scale=3, | |
| label="Enter your prompt", | |
| value="A person walking on a beach at sunset", | |
| lines=5 | |
| ) | |
| skyreels_negative_prompt = gr.Textbox( | |
| scale=3, | |
| label="Negative Prompt", | |
| value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", | |
| lines=3 | |
| ) | |
| with gr.Column(scale=1): | |
| skyreels_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| skyreels_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| skyreels_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| skyreels_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| skyreels_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| skyreels_stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| skyreels_input = gr.Image(label="Input Image (optional)", type="filepath") | |
| with gr.Row(): | |
| skyreels_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) | |
| skyreels_input_folder = gr.Textbox( | |
| label="Image Folder Path", | |
| placeholder="Path to folder containing images", | |
| visible=False | |
| ) | |
| skyreels_folder_status = gr.Textbox( | |
| label="Folder Status", | |
| placeholder="Status will appear here", | |
| interactive=False, | |
| visible=False | |
| ) | |
| skyreels_validate_folder_btn = gr.Button("Validate Folder", visible=False) | |
| skyreels_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") | |
| # Scale slider as percentage | |
| skyreels_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
| skyreels_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
| # Width and height inputs | |
| with gr.Row(): | |
| skyreels_width = gr.Number(label="New Width", value=544, step=16) | |
| skyreels_calc_height_btn = gr.Button("→") | |
| skyreels_calc_width_btn = gr.Button("←") | |
| skyreels_height = gr.Number(label="New Height", value=544, step=16) | |
| skyreels_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) | |
| skyreels_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) | |
| skyreels_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
| skyreels_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) | |
| skyreels_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=6.0) | |
| skyreels_embedded_cfg_scale = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, label="Embedded CFG Scale", value=1.0) | |
| with gr.Column(): | |
| skyreels_output = gr.Gallery( | |
| label="Generated Videos (Click to select)", | |
| columns=[2], | |
| rows=[2], | |
| object_fit="contain", | |
| height="auto", | |
| show_label=True, | |
| elem_id="gallery", | |
| allow_preview=True, | |
| preview=True | |
| ) | |
| skyreels_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
| # Add LoRA section for SKYREELS | |
| skyreels_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| skyreels_lora_weights = [] | |
| skyreels_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| skyreels_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| skyreels_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| skyreels_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| skyreels_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| skyreels_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
| skyreels_model = gr.Dropdown( | |
| label="DiT Model", | |
| choices=get_dit_models("skyreels"), | |
| value="skyreels_hunyuan_i2v_bf16.safetensors", | |
| allow_custom_value=True, | |
| interactive=True | |
| ) | |
| skyreels_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
| skyreels_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
| skyreels_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
| skyreels_save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| skyreels_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| skyreels_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| skyreels_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
| skyreels_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
| skyreels_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| skyreels_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
| skyreels_split_uncond = gr.Checkbox(label="Split Unconditional", value=True) | |
| # WanX Image to Video Tab | |
| with gr.Tab(id=4, label="WanX-i2v") as wanx_i2v_tab: | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| wanx_prompt = gr.Textbox( | |
| scale=3, | |
| label="Enter your prompt", | |
| value="A person walking on a beach at sunset", | |
| lines=5 | |
| ) | |
| wanx_negative_prompt = gr.Textbox( | |
| scale=3, | |
| label="Negative Prompt", | |
| value="", | |
| lines=3, | |
| info="Leave empty to use default negative prompt" | |
| ) | |
| with gr.Column(scale=1): | |
| wanx_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| wanx_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| wanx_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| wanx_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| wanx_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| wanx_stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| wanx_input = gr.Image(label="Input Image", type="filepath") | |
| with gr.Row(): | |
| wanx_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) | |
| wanx_input_folder = gr.Textbox( | |
| label="Image Folder Path", | |
| placeholder="Path to folder containing images", | |
| visible=False | |
| ) | |
| wanx_folder_status = gr.Textbox( | |
| label="Folder Status", | |
| placeholder="Status will appear here", | |
| interactive=False, | |
| visible=False | |
| ) | |
| wanx_validate_folder_btn = gr.Button("Validate Folder", visible=False) | |
| wanx_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
| wanx_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
| # Width and height display | |
| with gr.Row(): | |
| wanx_width = gr.Number(label="Width", value=832, interactive=True) | |
| wanx_calc_height_btn = gr.Button("→") | |
| wanx_calc_width_btn = gr.Button("←") | |
| wanx_height = gr.Number(label="Height", value=480, interactive=True) | |
| wanx_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") | |
| wanx_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) | |
| wanx_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) | |
| wanx_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) | |
| wanx_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=3.0, | |
| info="Recommended: 3.0 for 480p, 5.0 for others") | |
| wanx_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) | |
| with gr.Column(): | |
| wanx_output = gr.Gallery( | |
| label="Generated Videos (Click to select)", | |
| columns=[2], | |
| rows=[2], | |
| object_fit="contain", | |
| height="auto", | |
| show_label=True, | |
| elem_id="gallery", | |
| allow_preview=True, | |
| preview=True | |
| ) | |
| wanx_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan-v2v") | |
| wanx_send_last_frame_btn = gr.Button("Send Last Frame to Input") | |
| wanx_extend_btn = gr.Button("Extend Video") | |
| wanx_frames_to_check = gr.Slider(minimum=1, maximum=100, step=1, value=30, | |
| label="Frames to Check from End", | |
| info="Number of frames from the end to check for sharpness") | |
| wanx_send_sharpest_frame_btn = gr.Button("Extract Sharpest Frame") | |
| wanx_trim_and_extend_btn = gr.Button("Trim Video & Prepare for Extension") | |
| wanx_sharpest_frame_status = gr.Textbox(label="Status", interactive=False) | |
| # Add a new button for directly extending with the trimmed video | |
| wanx_extend_with_trimmed_btn = gr.Button("Extend with Trimmed Video") | |
| # Add LoRA section for WanX-i2v similar to other tabs | |
| wanx_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| wanx_lora_weights = [] | |
| wanx_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| wanx_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| wanx_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| wanx_task = gr.Dropdown( | |
| label="Task", | |
| choices=["i2v-14B"], | |
| value="i2v-14B", | |
| info="Currently only i2v-14B is supported" | |
| ) | |
| wanx_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_i2v_480p_14B_bf16.safetensors") | |
| wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") | |
| wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") | |
| wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") | |
| wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| wanx_save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") | |
| wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) | |
| wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) | |
| wanx_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) | |
| #WanX-t2v Tab | |
| # WanX Text to Video Tab | |
| with gr.Tab(id=5, label="WanX-t2v") as wanx_t2v_tab: | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| wanx_t2v_prompt = gr.Textbox( | |
| scale=3, | |
| label="Enter your prompt", | |
| value="A person walking on a beach at sunset", | |
| lines=5 | |
| ) | |
| wanx_t2v_negative_prompt = gr.Textbox( | |
| scale=3, | |
| label="Negative Prompt", | |
| value="", | |
| lines=3, | |
| info="Leave empty to use default negative prompt" | |
| ) | |
| with gr.Column(scale=1): | |
| wanx_t2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
| wanx_t2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
| with gr.Column(scale=2): | |
| wanx_t2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
| wanx_t2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
| with gr.Row(): | |
| wanx_t2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
| wanx_t2v_stop_btn = gr.Button("Stop Generation", variant="stop") | |
| with gr.Row(): | |
| with gr.Column(): | |
| with gr.Row(): | |
| wanx_t2v_width = gr.Number(label="Width", value=832, interactive=True, info="Should be divisible by 32") | |
| wanx_t2v_height = gr.Number(label="Height", value=480, interactive=True, info="Should be divisible by 32") | |
| wanx_t2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") | |
| wanx_t2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) | |
| wanx_t2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) | |
| wanx_t2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) | |
| wanx_t2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, | |
| info="Recommended: 3.0 for I2V with 480p, 5.0 for others") | |
| wanx_t2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) | |
| with gr.Column(): | |
| wanx_t2v_output = gr.Gallery( | |
| label="Generated Videos (Click to select)", | |
| columns=[2], | |
| rows=[2], | |
| object_fit="contain", | |
| height="auto", | |
| show_label=True, | |
| elem_id="gallery", | |
| allow_preview=True, | |
| preview=True | |
| ) | |
| wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
| # Add LoRA section for WanX-t2v | |
| wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| wanx_t2v_lora_weights = [] | |
| wanx_t2v_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| wanx_t2v_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| wanx_t2v_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
| wanx_t2v_task = gr.Dropdown( | |
| label="Task", | |
| choices=["t2v-1.3B", "t2v-14B", "t2i-14B"], | |
| value="t2v-14B", | |
| info="Select model size: t2v-1.3B is faster, t2v-14B has higher quality" | |
| ) | |
| wanx_t2v_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_t2v_14B_bf16.safetensors") | |
| wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") | |
| wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") | |
| wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") | |
| wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") | |
| with gr.Row(): | |
| wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
| wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") | |
| wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
| wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, | |
| info="Max 39 for 14B model, 29 for 1.3B model") | |
| wanx_t2v_fp8 = gr.Checkbox(label="Use FP8", value=True) | |
| wanx_t2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) | |
| # Video Info tab | |
| with gr.Tab("Video Info") as video_info_tab: | |
| with gr.Row(): | |
| video_input = gr.Video(label="Upload Video", interactive=True) | |
| metadata_output = gr.JSON(label="Generation Parameters") | |
| with gr.Row(): | |
| send_to_t2v_btn = gr.Button("Send to Text2Video", variant="primary") | |
| send_to_v2v_btn = gr.Button("Send to Video2Video", variant="primary") | |
| send_to_wanx_i2v_btn = gr.Button("Send to WanX-i2v", variant="primary") | |
| send_to_wanx_t2v_btn = gr.Button("Send to WanX-t2v", variant="primary") | |
| with gr.Row(): | |
| status = gr.Textbox(label="Status", interactive=False) | |
| # LoRA conversion and model merging tabs | |
| with gr.Tab("Convert LoRA") as convert_lora_tab: | |
| def suggest_output_name(file_obj) -> str: | |
| """Generate suggested output name from input file""" | |
| if not file_obj: | |
| return "" | |
| # Get input filename without extension and add MUSUBI | |
| base_name = os.path.splitext(os.path.basename(file_obj.name))[0] | |
| return f"{base_name}_MUSUBI" | |
| def convert_lora(input_file, output_name: str, target_format: str) -> str: | |
| """Convert LoRA file to specified format""" | |
| try: | |
| if not input_file: | |
| return "Error: No input file selected" | |
| # Ensure output directory exists | |
| os.makedirs("lora", exist_ok=True) | |
| # Construct output path | |
| output_path = os.path.join("lora", f"{output_name}.safetensors") | |
| # Build command | |
| cmd = [ | |
| sys.executable, | |
| "convert_lora.py", | |
| "--input", input_file.name, | |
| "--output", output_path, | |
| "--target", target_format | |
| ] | |
| print(f"Converting {input_file.name} to {output_path}") | |
| # Execute conversion | |
| result = subprocess.run( | |
| cmd, | |
| capture_output=True, | |
| text=True, | |
| check=True | |
| ) | |
| if os.path.exists(output_path): | |
| return f"Successfully converted LoRA to {output_path}" | |
| else: | |
| return "Error: Output file not created" | |
| except subprocess.CalledProcessError as e: | |
| return f"Error during conversion: {e.stderr}" | |
| except Exception as e: | |
| return f"Error: {str(e)}" | |
| with gr.Row(): | |
| input_file = gr.File(label="Input LoRA File", file_types=[".safetensors"]) | |
| output_name = gr.Textbox(label="Output Name", placeholder="Output filename (without extension)") | |
| format_radio = gr.Radio( | |
| choices=["default", "other"], | |
| value="default", | |
| label="Target Format", | |
| info="Choose 'default' for H1111/MUSUBI format or 'other' for diffusion pipe format" | |
| ) | |
| with gr.Row(): | |
| convert_btn = gr.Button("Convert LoRA", variant="primary") | |
| status_output = gr.Textbox(label="Status", interactive=False) | |
| # Automatically update output name when file is selected | |
| input_file.change( | |
| fn=suggest_output_name, | |
| inputs=[input_file], | |
| outputs=[output_name] | |
| ) | |
| # Handle conversion | |
| convert_btn.click( | |
| fn=convert_lora, | |
| inputs=[input_file, output_name, format_radio], | |
| outputs=status_output | |
| ) | |
| with gr.Tab("Model Merging") as model_merge_tab: | |
| with gr.Row(): | |
| with gr.Column(): | |
| # Model selection | |
| dit_model = gr.Dropdown( | |
| label="Base DiT Model", | |
| choices=["mp_rank_00_model_states.pt"], | |
| value="mp_rank_00_model_states.pt", | |
| allow_custom_value=True, | |
| interactive=True | |
| ) | |
| merge_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
| with gr.Row(): | |
| with gr.Column(): | |
| # Output model name | |
| output_model = gr.Textbox(label="Output Model Name", value="merged_model.safetensors") | |
| exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
| merge_btn = gr.Button("Merge Models", variant="primary") | |
| merge_status = gr.Textbox(label="Status", interactive=False) | |
| with gr.Row(): | |
| # LoRA selection section (similar to Text2Video) | |
| merge_lora_weights = [] | |
| merge_lora_multipliers = [] | |
| for i in range(4): | |
| with gr.Column(): | |
| merge_lora_weights.append(gr.Dropdown( | |
| label=f"LoRA {i+1}", | |
| choices=get_lora_options(), | |
| value="None", | |
| allow_custom_value=True, | |
| interactive=True | |
| )) | |
| merge_lora_multipliers.append(gr.Slider( | |
| label=f"Multiplier", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.0 | |
| )) | |
| with gr.Row(): | |
| merge_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
| dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
| # Video extension handlers | |
| wanx_send_last_frame_btn.click( | |
| fn=send_last_frame_handler, | |
| inputs=[wanx_output, wanx_i2v_selected_index], | |
| outputs=[wanx_input, wanx_base_video] | |
| ) | |
| wanx_extend_btn.click( | |
| fn=prepare_for_batch_extension, | |
| inputs=[wanx_input, wanx_base_video, wanx_batch_size], | |
| outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] | |
| ).then( | |
| fn=wanx_batch_handler, | |
| inputs=[ | |
| gr.Checkbox(value=False), # Not using random folder | |
| wanx_prompt, wanx_negative_prompt, | |
| wanx_width, wanx_height, wanx_video_length, | |
| wanx_fps, wanx_infer_steps, wanx_flow_shift, | |
| wanx_guidance_scale, wanx_seed, wanx_batch_size, | |
| wanx_input_folder, # Not used but needed for function signature | |
| wanx_task, | |
| wanx_dit_path, wanx_vae_path, wanx_t5_path, | |
| wanx_clip_path, wanx_save_path, wanx_output_type, | |
| wanx_sample_solver, wanx_exclude_single_blocks, | |
| wanx_attn_mode, wanx_block_swap, wanx_fp8, | |
| wanx_fp8_t5, wanx_lora_folder, *wanx_lora_weights, | |
| *wanx_lora_multipliers, wanx_input # Include input image | |
| ], | |
| outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] | |
| ).then( | |
| fn=concat_batch_videos, | |
| inputs=[wanx_base_video, wanx_output, wanx_save_path], | |
| outputs=[wanx_output, wanx_progress_text] | |
| ) | |
| # Extract and send sharpest frame to input | |
| wanx_send_sharpest_frame_btn.click( | |
| fn=send_sharpest_frame_handler, | |
| inputs=[wanx_output, wanx_i2v_selected_index, wanx_frames_to_check], | |
| outputs=[wanx_input, wanx_base_video, wanx_sharpest_frame_number, wanx_sharpest_frame_status] | |
| ) | |
| # Trim video to sharpest frame and prepare for extension | |
| wanx_trim_and_extend_btn.click( | |
| fn=trim_and_prepare_for_extension, | |
| inputs=[wanx_base_video, wanx_sharpest_frame_number, wanx_save_path], | |
| outputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status] | |
| ).then( | |
| fn=lambda path, status: (path, status if "Failed" in status else "Video trimmed successfully and ready for extension"), | |
| inputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status], | |
| outputs=[wanx_base_video, wanx_sharpest_frame_status] | |
| ) | |
| # Event handler for extending with the trimmed video | |
| wanx_extend_with_trimmed_btn.click( | |
| fn=prepare_for_batch_extension, | |
| inputs=[wanx_input, wanx_trimmed_video_path, wanx_batch_size], | |
| outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] | |
| ).then( | |
| fn=wanx_batch_handler, | |
| inputs=[ | |
| gr.Checkbox(value=False), # Not using random folder | |
| wanx_prompt, wanx_negative_prompt, | |
| wanx_width, wanx_height, wanx_video_length, | |
| wanx_fps, wanx_infer_steps, wanx_flow_shift, | |
| wanx_guidance_scale, wanx_seed, wanx_batch_size, | |
| wanx_input_folder, # Not used but needed for function signature | |
| wanx_task, | |
| wanx_dit_path, wanx_vae_path, wanx_t5_path, | |
| wanx_clip_path, wanx_save_path, wanx_output_type, | |
| wanx_sample_solver, wanx_exclude_single_blocks, | |
| wanx_attn_mode, wanx_block_swap, wanx_fp8, | |
| wanx_fp8_t5, wanx_lora_folder, *wanx_lora_weights, | |
| *wanx_lora_multipliers, wanx_input # Include input image | |
| ], | |
| outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] | |
| ).then( | |
| fn=concat_batch_videos, | |
| inputs=[wanx_trimmed_video_path, wanx_output, wanx_save_path], | |
| outputs=[wanx_output, wanx_progress_text] | |
| ) | |
| # Video Info: send-to-tab handlers | |
| def handle_send_to_wanx_tab(metadata, target_tab): | |
| """Common handler for sending video parameters to WanX tabs""" | |
| if not metadata: | |
| return "No parameters to send", {} | |
| # Tab names for clearer messages | |
| tab_names = { | |
| 'wanx_i2v': 'WanX-i2v', | |
| 'wanx_t2v': 'WanX-t2v' | |
| } | |
| # Just pass through all parameters - we'll use them in the .then() function | |
| return f"Parameters ready for {tab_names.get(target_tab, target_tab)}", metadata | |
| def change_to_wanx_i2v_tab(): | |
| return gr.Tabs(selected=4) # WanX-i2v tab index | |
| def change_to_wanx_t2v_tab(): | |
| return gr.Tabs(selected=5) # WanX-t2v tab index | |
| send_to_wanx_i2v_btn.click( | |
| fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_i2v'), | |
| inputs=[metadata_output], | |
| outputs=[status, params_state] | |
| ).then( | |
| lambda params: [ | |
| params.get("prompt", ""), | |
| params.get("width", 832), | |
| params.get("height", 480), | |
| params.get("video_length", 81), | |
| params.get("fps", 16), | |
| params.get("infer_steps", 40), | |
| params.get("seed", -1), | |
| params.get("flow_shift", 3.0), | |
| params.get("guidance_scale", 5.0), | |
| params.get("attn_mode", "sdpa"), | |
| params.get("block_swap", 0), | |
| params.get("task", "i2v-14B") | |
| ] if params else [gr.update()]*12, | |
| inputs=params_state, | |
| outputs=[ | |
| wanx_prompt, | |
| wanx_width, | |
| wanx_height, | |
| wanx_video_length, | |
| wanx_fps, | |
| wanx_infer_steps, | |
| wanx_seed, | |
| wanx_flow_shift, | |
| wanx_guidance_scale, | |
| wanx_attn_mode, | |
| wanx_block_swap, | |
| wanx_task | |
| ] | |
| ).then( | |
| fn=change_to_wanx_i2v_tab, inputs=None, outputs=[tabs] | |
| ) | |
| # WanX-t2v button handler | |
| send_to_wanx_t2v_btn.click( | |
| fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_t2v'), | |
| inputs=[metadata_output], | |
| outputs=[status, params_state] | |
| ).then( | |
| lambda params: [ | |
| params.get("prompt", ""), | |
| params.get("width", 832), | |
| params.get("height", 480), | |
| params.get("video_length", 81), | |
| params.get("fps", 16), | |
| params.get("infer_steps", 50), | |
| params.get("seed", -1), | |
| params.get("flow_shift", 5.0), | |
| params.get("guidance_scale", 5.0), | |
| params.get("attn_mode", "sdpa"), | |
| params.get("block_swap", 0) | |
| ] if params else [gr.update()]*11, | |
| inputs=params_state, | |
| outputs=[ | |
| wanx_t2v_prompt, | |
| wanx_t2v_width, | |
| wanx_t2v_height, | |
| wanx_t2v_video_length, | |
| wanx_t2v_fps, | |
| wanx_t2v_infer_steps, | |
| wanx_t2v_seed, | |
| wanx_t2v_flow_shift, | |
| wanx_t2v_guidance_scale, | |
| wanx_t2v_attn_mode, | |
| wanx_t2v_block_swap | |
| ] | |
| ).then( | |
| fn=change_to_wanx_t2v_tab, inputs=None, outputs=[tabs] | |
| ) | |
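| # For illustration, a metadata dict as consumed above might look like | |
| # this (keys mirror the params.get() calls; values are hypothetical): | |
| #   {"prompt": "A person walking on a beach at sunset", "width": 832, | |
| #    "height": 480, "video_length": 81, "fps": 16, "infer_steps": 50, | |
| #    "seed": 42, "flow_shift": 5.0, "guidance_scale": 5.0, | |
| #    "attn_mode": "sdpa", "block_swap": 0} | |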
| # Tab navigation helpers | |
| def change_to_tab_one(): | |
| return gr.Tabs(selected=1)  # Navigate to Text2Video | |
| def change_to_tab_two(): | |
| return gr.Tabs(selected=2)  # Navigate to Video2Video | |
| def change_to_skyreels_tab(): | |
| return gr.Tabs(selected=3)  # Navigate to SkyReels | |
| # SkyReels tab: dimension state management and event handlers | |
| def sync_skyreels_dimensions(width, height): | |
| return gr.update(value=width), gr.update(value=height) | |
| # Update the LoRA dropdowns in the SkyReels tab | |
| def update_skyreels_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: | |
| new_choices = get_lora_options(lora_folder) | |
| weights = current_values[:4] | |
| multipliers = current_values[4:8] | |
| results = [] | |
| for i in range(4): | |
| weight = weights[i] if i < len(weights) else "None" | |
| multiplier = multipliers[i] if i < len(multipliers) else 1.0 | |
| if weight not in new_choices: | |
| weight = "None" | |
| results.extend([ | |
| gr.update(choices=new_choices, value=weight), | |
| gr.update(value=multiplier) | |
| ]) | |
| return results | |
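| # The returned list interleaves dropdown and slider updates, e.g. for | |
| # four rows: [update(choices, value), update(value), ...] -- 8 updates | |
| # in total, matching the outputs wiring of the refresh button below. | |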
| # Update the DiT model dropdown in the SkyReels tab | |
| def update_skyreels_model_dropdown(dit_folder: str) -> Dict: | |
| models = get_dit_models(dit_folder) | |
| return gr.update(choices=models, value=models[0] if models else None) | |
| # Add event handler for model dropdown refresh | |
| skyreels_dit_folder.change( | |
| fn=update_skyreels_model_dropdown, | |
| inputs=[skyreels_dit_folder], | |
| outputs=[skyreels_model] | |
| ) | |
| # Refresh the LoRA dropdowns (a second click handler below also refreshes the DiT model dropdown) | |
| skyreels_refresh_btn.click( | |
| fn=update_skyreels_lora_dropdowns, | |
| inputs=[skyreels_lora_folder] + skyreels_lora_weights + skyreels_lora_multipliers, | |
| outputs=[comp for i in range(4) for comp in (skyreels_lora_weights[i], skyreels_lora_multipliers[i])] | |
| ) | |
| # Skyreels dimension handling | |
| def calculate_skyreels_width(height, original_dims): | |
| if not original_dims: | |
| return gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| aspect_ratio = orig_w / orig_h | |
| new_width = math.floor((height * aspect_ratio) / 16) * 16 | |
| return gr.update(value=new_width) | |
| def calculate_skyreels_height(width, original_dims): | |
| if not original_dims: | |
| return gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| aspect_ratio = orig_w / orig_h | |
| new_height = math.floor((width / aspect_ratio) / 16) * 16 | |
| return gr.update(value=new_height) | |
| def update_skyreels_from_scale(scale, original_dims): | |
| if not original_dims: | |
| return gr.update(), gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| new_w = math.floor((orig_w * scale / 100) / 16) * 16 | |
| new_h = math.floor((orig_h * scale / 100) / 16) * 16 | |
| return gr.update(value=new_w), gr.update(value=new_h) | |
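| # Worked example (hypothetical source): a 1920x1080 image at height 544 | |
| # gives width = floor((544 * 1920/1080) / 16) * 16 = floor(967.1 / 16) * 16 | |
| # = 960, preserving the aspect ratio while staying divisible by 16. | |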
| def update_skyreels_dimensions(image): | |
| if image is None: | |
| return "", gr.update(value=544), gr.update(value=544) | |
| img = Image.open(image) | |
| w, h = img.size | |
| w = (w // 16) * 16 | |
| h = (h // 16) * 16 | |
| return f"{w}x{h}", w, h | |
| def handle_skyreels_gallery_select(evt: gr.SelectData) -> int: | |
| return evt.index | |
| def send_skyreels_to_v2v( | |
| gallery: list, | |
| prompt: str, | |
| selected_index: int, | |
| width: int, | |
| height: int, | |
| video_length: int, | |
| fps: int, | |
| infer_steps: int, | |
| seed: int, | |
| flow_shift: float, | |
| cfg_scale: float, | |
| lora1: str, | |
| lora2: str, | |
| lora3: str, | |
| lora4: str, | |
| lora1_multiplier: float, | |
| lora2_multiplier: float, | |
| lora3_multiplier: float, | |
| lora4_multiplier: float, | |
| negative_prompt: str = "" # Add this parameter | |
| ) -> Tuple: | |
| if not gallery or selected_index is None or selected_index >= len(gallery): | |
| return (None, "", width, height, video_length, fps, infer_steps, seed, | |
| flow_shift, cfg_scale, lora1, lora2, lora3, lora4, | |
| lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, | |
| negative_prompt) | |
| selected_item = gallery[selected_index] | |
| if isinstance(selected_item, dict): | |
| video_path = selected_item.get("name", selected_item.get("data", None)) | |
| elif isinstance(selected_item, (tuple, list)): | |
| video_path = selected_item[0] | |
| else: | |
| video_path = selected_item | |
| if isinstance(video_path, tuple): | |
| video_path = video_path[0] | |
| return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, | |
| flow_shift, cfg_scale, lora1, lora2, lora3, lora4, | |
| lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, | |
| negative_prompt) | |
| # Event handlers for the SkyReels tab | |
| skyreels_prompt.change(fn=count_prompt_tokens, inputs=skyreels_prompt, outputs=skyreels_token_counter) | |
| skyreels_stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
| # Image input handling | |
| skyreels_input.change( | |
| fn=update_skyreels_dimensions, | |
| inputs=[skyreels_input], | |
| outputs=[skyreels_original_dims, skyreels_width, skyreels_height] | |
| ) | |
| skyreels_scale_slider.change( | |
| fn=update_skyreels_from_scale, | |
| inputs=[skyreels_scale_slider, skyreels_original_dims], | |
| outputs=[skyreels_width, skyreels_height] | |
| ) | |
| skyreels_calc_width_btn.click( | |
| fn=calculate_skyreels_width, | |
| inputs=[skyreels_height, skyreels_original_dims], | |
| outputs=[skyreels_width] | |
| ) | |
| skyreels_calc_height_btn.click( | |
| fn=calculate_skyreels_height, | |
| inputs=[skyreels_width, skyreels_original_dims], | |
| outputs=[skyreels_height] | |
| ) | |
| # Handle checkbox visibility toggling | |
| skyreels_use_random_folder.change( | |
| fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), | |
| inputs=[skyreels_use_random_folder], | |
| outputs=[skyreels_input_folder, skyreels_folder_status, skyreels_input] | |
| ) | |
| # Validate folder button click handler | |
| skyreels_validate_folder_btn.click( | |
| fn=lambda folder: get_random_image_from_folder(folder)[1], | |
| inputs=[skyreels_input_folder], | |
| outputs=[skyreels_folder_status] | |
| ) | |
| skyreels_use_random_folder.change( | |
| fn=lambda x: gr.update(visible=x), | |
| inputs=[skyreels_use_random_folder], | |
| outputs=[skyreels_validate_folder_btn] | |
| ) | |
| # Generate button: uses the random-image batch path when folder mode is enabled | |
| skyreels_generate_btn.click( | |
| fn=batch_handler, | |
| inputs=[ | |
| skyreels_use_random_folder, | |
| # Rest of the arguments | |
| skyreels_prompt, | |
| skyreels_negative_prompt, | |
| skyreels_width, | |
| skyreels_height, | |
| skyreels_video_length, | |
| skyreels_fps, | |
| skyreels_infer_steps, | |
| skyreels_seed, | |
| skyreels_flow_shift, | |
| skyreels_guidance_scale, | |
| skyreels_embedded_cfg_scale, | |
| skyreels_batch_size, | |
| skyreels_input_folder, | |
| skyreels_dit_folder, | |
| skyreels_model, | |
| skyreels_vae, | |
| skyreels_te1, | |
| skyreels_te2, | |
| skyreels_save_path, | |
| skyreels_output_type, | |
| skyreels_attn_mode, | |
| skyreels_block_swap, | |
| skyreels_exclude_single_blocks, | |
| skyreels_use_split_attn, | |
| skyreels_use_fp8, | |
| skyreels_split_uncond, | |
| skyreels_lora_folder, | |
| *skyreels_lora_weights, | |
| *skyreels_lora_multipliers, | |
| skyreels_input # Add the input image path | |
| ], | |
| outputs=[skyreels_output, skyreels_batch_progress, skyreels_progress_text], | |
| queue=True | |
| ).then( | |
| fn=lambda batch_size: 0 if batch_size == 1 else None, | |
| inputs=[skyreels_batch_size], | |
| outputs=skyreels_selected_index | |
| ) | |
| # Gallery selection handling | |
| skyreels_output.select( | |
| fn=handle_skyreels_gallery_select, | |
| outputs=skyreels_selected_index | |
| ) | |
| # Send to Video2Video handler | |
| skyreels_send_to_v2v_btn.click( | |
| fn=send_skyreels_to_v2v, | |
| inputs=[ | |
| skyreels_output, skyreels_prompt, skyreels_selected_index, | |
| skyreels_width, skyreels_height, skyreels_video_length, | |
| skyreels_fps, skyreels_infer_steps, skyreels_seed, | |
| skyreels_flow_shift, skyreels_guidance_scale | |
| ] + skyreels_lora_weights + skyreels_lora_multipliers + [skyreels_negative_prompt],  # skyreels_negative_prompt is a Gradio component, so it is a valid input | |
| outputs=[ | |
| v2v_input, v2v_prompt, v2v_width, v2v_height, | |
| v2v_video_length, v2v_fps, v2v_infer_steps, | |
| v2v_seed, v2v_flow_shift, v2v_cfg_scale | |
| ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] | |
| ).then( | |
| fn=change_to_tab_two, | |
| inputs=None, | |
| outputs=[tabs] | |
| ) | |
| # Refresh button handler | |
| skyreels_refresh_outputs = [skyreels_model] | |
| for i in range(4): | |
| skyreels_refresh_outputs.extend([skyreels_lora_weights[i], skyreels_lora_multipliers[i]]) | |
| skyreels_refresh_btn.click( | |
| fn=update_dit_and_lora_dropdowns, | |
| inputs=[skyreels_dit_folder, skyreels_lora_folder, skyreels_model] + skyreels_lora_weights + skyreels_lora_multipliers, | |
| outputs=skyreels_refresh_outputs | |
| ) | |
| def calculate_v2v_width(height, original_dims): | |
| if not original_dims: | |
| return gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| aspect_ratio = orig_w / orig_h | |
| new_width = math.floor((height * aspect_ratio) / 16) * 16 # Ensure divisible by 16 | |
| return gr.update(value=new_width) | |
| def calculate_v2v_height(width, original_dims): | |
| if not original_dims: | |
| return gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| aspect_ratio = orig_w / orig_h | |
| new_height = math.floor((width / aspect_ratio) / 16) * 16 # Ensure divisible by 16 | |
| return gr.update(value=new_height) | |
| def update_v2v_from_scale(scale, original_dims): | |
| if not original_dims: | |
| return gr.update(), gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Ensure divisible by 16 | |
| new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Ensure divisible by 16 | |
| return gr.update(value=new_w), gr.update(value=new_h) | |
| def update_v2v_dimensions(video): | |
| if video is None: | |
| return "", gr.update(value=544), gr.update(value=544) | |
| cap = cv2.VideoCapture(video) | |
| w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |
| h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |
| cap.release() | |
| # Make dimensions divisible by 16 | |
| w = (w // 16) * 16 | |
| h = (h // 16) * 16 | |
| return f"{w}x{h}", w, h | |
| # Event Handlers for Video to Video Tab | |
| v2v_input.change( | |
| fn=update_v2v_dimensions, | |
| inputs=[v2v_input], | |
| outputs=[v2v_original_dims, v2v_width, v2v_height] | |
| ) | |
| v2v_scale_slider.change( | |
| fn=update_v2v_from_scale, | |
| inputs=[v2v_scale_slider, v2v_original_dims], | |
| outputs=[v2v_width, v2v_height] | |
| ) | |
| v2v_calc_width_btn.click( | |
| fn=calculate_v2v_width, | |
| inputs=[v2v_height, v2v_original_dims], | |
| outputs=[v2v_width] | |
| ) | |
| v2v_calc_height_btn.click( | |
| fn=calculate_v2v_height, | |
| inputs=[v2v_width, v2v_original_dims], | |
| outputs=[v2v_height] | |
| ) | |
| # Image2Video dimension logic | |
| def calculate_width(height, original_dims): | |
| if not original_dims: | |
| return gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| aspect_ratio = orig_w / orig_h | |
| new_width = math.floor((height * aspect_ratio) / 16) * 16  # Ensure divisible by 16 | |
| return gr.update(value=new_width) | |
| def calculate_height(width, original_dims): | |
| if not original_dims: | |
| return gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| aspect_ratio = orig_w / orig_h | |
| new_height = math.floor((width / aspect_ratio) / 16) * 16  # Ensure divisible by 16 | |
| return gr.update(value=new_height) | |
| def update_from_scale(scale, original_dims): | |
| if not original_dims: | |
| return gr.update(), gr.update() | |
| orig_w, orig_h = map(int, original_dims.split('x')) | |
| new_w = math.floor((orig_w * scale / 100) / 16) * 16  # Ensure divisible by 16 | |
| new_h = math.floor((orig_h * scale / 100) / 16) * 16  # Ensure divisible by 16 | |
| return gr.update(value=new_w), gr.update(value=new_h) | |
| def update_dimensions(image): | |
| if image is None: | |
| return "", gr.update(value=544), gr.update(value=544) | |
| img = Image.open(image) | |
| w, h = img.size | |
| # Make dimensions divisible by 16 | |
| w = (w // 16) * 16 | |
| h = (h // 16) * 16 | |
| return f"{w}x{h}", w, h | |
| i2v_input.change( | |
| fn=update_dimensions, | |
| inputs=[i2v_input], | |
| outputs=[original_dims, width, height] | |
| ) | |
| scale_slider.change( | |
| fn=update_from_scale, | |
| inputs=[scale_slider, original_dims], | |
| outputs=[width, height] | |
| ) | |
| calc_width_btn.click( | |
| fn=calculate_width, | |
| inputs=[height, original_dims], | |
| outputs=[width] | |
| ) | |
| calc_height_btn.click( | |
| fn=calculate_height, | |
| inputs=[width, original_dims], | |
| outputs=[height] | |
| ) | |
| # Function to get available DiT models | |
| def get_dit_models(dit_folder: str) -> List[str]: | |
| if not os.path.exists(dit_folder): | |
| return ["mp_rank_00_model_states.pt"] | |
| models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] | |
| models.sort(key=str.lower) | |
| return models if models else ["mp_rank_00_model_states.pt"] | |
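| # Example (hypothetical folder contents): a folder holding | |
| # ["model_b.safetensors", "Model_A.pt"] is returned sorted | |
| # case-insensitively as ["Model_A.pt", "model_b.safetensors"]. | |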
| # Function to perform model merging | |
| def merge_models( | |
| dit_folder: str, | |
| dit_model: str, | |
| output_model: str, | |
| exclude_single_blocks: bool, | |
| merge_lora_folder: str, | |
| *lora_params # Will contain both weights and multipliers | |
| ) -> str: | |
| try: | |
| # Separate weights and multipliers | |
| num_loras = len(lora_params) // 2 | |
| weights = list(lora_params[:num_loras]) | |
| multipliers = list(lora_params[num_loras:]) | |
| # Filter out "None" selections | |
| valid_loras = [] | |
| for weight, mult in zip(weights, multipliers): | |
| if weight and weight != "None": | |
| valid_loras.append((os.path.join(merge_lora_folder, weight), mult)) | |
| if not valid_loras: | |
| return "No LoRA models selected for merging" | |
| # Create output path in the dit folder | |
| os.makedirs(dit_folder, exist_ok=True) | |
| output_path = os.path.join(dit_folder, output_model) | |
| # Prepare command | |
| cmd = [ | |
| sys.executable, | |
| "merge_lora.py", | |
| "--dit", os.path.join(dit_folder, dit_model), | |
| "--save_merged_model", output_path | |
| ] | |
| # Add LoRA weights and multipliers | |
| weights = [weight for weight, _ in valid_loras] | |
| multipliers = [str(mult) for _, mult in valid_loras] | |
| cmd.extend(["--lora_weight"] + weights) | |
| cmd.extend(["--lora_multiplier"] + multipliers) | |
| if exclude_single_blocks: | |
| cmd.append("--exclude_single_blocks") | |
| # Execute merge operation | |
| result = subprocess.run( | |
| cmd, | |
| capture_output=True, | |
| text=True, | |
| check=True | |
| ) | |
| if os.path.exists(output_path): | |
| return f"Successfully merged model and saved to {output_path}" | |
| else: | |
| return "Error: Output file not created" | |
| except subprocess.CalledProcessError as e: | |
| return f"Error during merging: {e.stderr}" | |
| except Exception as e: | |
| return f"Error: {str(e)}" | |
| # Update DiT model dropdown | |
| def update_dit_dropdown(dit_folder: str) -> Dict: | |
| models = get_dit_models(dit_folder) | |
| return gr.update(choices=models, value=models[0] if models else None) | |
| # Connect events | |
| merge_btn.click( | |
| fn=merge_models, | |
| inputs=[ | |
| dit_folder, | |
| dit_model, | |
| output_model, | |
| exclude_single_blocks, | |
| merge_lora_folder, | |
| *merge_lora_weights, | |
| *merge_lora_multipliers | |
| ], | |
| outputs=merge_status | |
| ) | |
| # Refresh buttons for both DiT and LoRA dropdowns | |
| merge_refresh_btn.click( | |
| fn=lambda f: update_dit_dropdown(f), | |
| inputs=[dit_folder], | |
| outputs=[dit_model] | |
| ) | |
| # LoRA refresh handling | |
| merge_refresh_outputs = [] | |
| for i in range(4): | |
| merge_refresh_outputs.extend([merge_lora_weights[i], merge_lora_multipliers[i]]) | |
| merge_refresh_btn.click( | |
| fn=update_lora_dropdowns, | |
| inputs=[merge_lora_folder] + merge_lora_weights + merge_lora_multipliers, | |
| outputs=merge_refresh_outputs | |
| ) | |
| # Event handlers | |
| prompt.change(fn=count_prompt_tokens, inputs=prompt, outputs=token_counter) | |
| v2v_prompt.change(fn=count_prompt_tokens, inputs=v2v_prompt, outputs=v2v_token_counter) | |
| stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
| v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
| #Image_to_Video | |
| def image_to_video(image_path, output_path, width, height, frames=240): | |
| img = Image.open(image_path) | |
| # Resize to the specified dimensions | |
| img_resized = img.resize((width, height), Image.LANCZOS) | |
| temp_image_path = os.path.join(os.path.dirname(output_path), "temp_resized_image.png") | |
| img_resized.save(temp_image_path) | |
| # Loop the resized still into a fixed-rate video with ffmpeg | |
| frame_rate = 24 | |
| duration = frames / frame_rate | |
| command = [ | |
| "ffmpeg", "-loop", "1", "-i", temp_image_path, "-c:v", "libx264", | |
| "-t", str(duration), "-pix_fmt", "yuv420p", | |
| "-vf", f"fps={frame_rate}", output_path | |
| ] | |
| try: | |
| subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) | |
| print(f"Video saved to {output_path}") | |
| return True | |
| except subprocess.CalledProcessError as e: | |
| print(f"An error occurred while creating the video: {e}") | |
| return False | |
| finally: | |
| # Clean up the temporary image file | |
| if os.path.exists(temp_image_path): | |
| os.remove(temp_image_path) | |
| img.close() # Make sure to close the image file explicitly | |
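| # Duration arithmetic for the ffmpeg call above: with the fixed | |
| # frame_rate of 24, the default frames=240 yields a 10-second clip | |
| # (duration = frames / frame_rate = 240 / 24 = 10.0). | |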
| def generate_from_image( | |
| image_path, | |
| prompt, width, height, video_length, fps, infer_steps, | |
| seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, | |
| output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, | |
| lora_folder, strength, batch_size, *lora_params | |
| ): | |
| """Generate video from input image with progressive updates""" | |
| global stop_event | |
| stop_event.clear() | |
| # Create temporary video path | |
| temp_video_path = os.path.join(save_path, f"temp_{os.path.basename(image_path)}.mp4") | |
| try: | |
| # Convert image to video | |
| if not image_to_video(image_path, temp_video_path, width, height, frames=video_length): | |
| yield [], "Failed to create temporary video", "Error in video creation" | |
| return | |
| # Ensure video is fully written before proceeding | |
| time.sleep(1) | |
| if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0: | |
| yield [], "Failed to create temporary video", "Temporary video file is empty or missing" | |
| return | |
| # Get video dimensions | |
| try: | |
| probe = ffmpeg.probe(temp_video_path) | |
| video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None) | |
| if video_stream is None: | |
| raise ValueError("No video stream found") | |
| width = int(video_stream['width']) | |
| height = int(video_stream['height']) | |
| except Exception as e: | |
| yield [], f"Error reading video dimensions: {str(e)}", "Video processing error" | |
| return | |
| # Generate the video using the temporary file | |
| try: | |
| generator = process_single_video( | |
| prompt, width, height, batch_size, video_length, fps, infer_steps, | |
| seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, | |
| output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, | |
| lora_folder, *lora_params, video_path=temp_video_path, strength=strength | |
| ) | |
| # Forward all generator updates | |
| for videos, batch_text, progress_text in generator: | |
| yield videos, batch_text, progress_text | |
| except Exception as e: | |
| yield [], f"Error in video generation: {str(e)}", "Generation error" | |
| return | |
| except Exception as e: | |
| yield [], f"Unexpected error: {str(e)}", "Error occurred" | |
| return | |
| finally: | |
| # Clean up temporary file | |
| try: | |
| if os.path.exists(temp_video_path): | |
| os.remove(temp_video_path) | |
| except Exception: | |
| pass # Ignore cleanup errors | |
| # Add event handlers | |
| i2v_prompt.change(fn=count_prompt_tokens, inputs=i2v_prompt, outputs=i2v_token_counter) | |
| i2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
| def handle_i2v_gallery_select(evt: gr.SelectData) -> int: | |
| """Track selected index when I2V gallery item is clicked""" | |
| return evt.index | |
| def send_i2v_to_v2v( | |
| gallery: list, | |
| prompt: str, | |
| selected_index: int, | |
| width: int, | |
| height: int, | |
| video_length: int, | |
| fps: int, | |
| infer_steps: int, | |
| seed: int, | |
| flow_shift: float, | |
| cfg_scale: float, | |
| lora1: str, | |
| lora2: str, | |
| lora3: str, | |
| lora4: str, | |
| lora1_multiplier: float, | |
| lora2_multiplier: float, | |
| lora3_multiplier: float, | |
| lora4_multiplier: float | |
| ) -> Tuple[Optional[str], str, int, int, int, int, int, int, float, float, str, str, str, str, float, float, float, float]: | |
| """Send the selected video and parameters from Image2Video tab to Video2Video tab""" | |
| if not gallery or selected_index is None or selected_index >= len(gallery): | |
| return None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, \ | |
| lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier | |
| selected_item = gallery[selected_index] | |
| # Handle different gallery item formats | |
| if isinstance(selected_item, dict): | |
| video_path = selected_item.get("name", selected_item.get("data", None)) | |
| elif isinstance(selected_item, (tuple, list)): | |
| video_path = selected_item[0] | |
| else: | |
| video_path = selected_item | |
| # Final cleanup for Gradio Video component | |
| if isinstance(video_path, tuple): | |
| video_path = video_path[0] | |
| # Use the original width and height without doubling | |
| return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, | |
| flow_shift, cfg_scale, lora1, lora2, lora3, lora4, | |
| lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier) | |
| # Generate button handler | |
| i2v_generate_btn.click( | |
| fn=process_batch, | |
| inputs=[ | |
| i2v_prompt, width, height, | |
| i2v_batch_size, i2v_video_length, | |
| i2v_fps, i2v_infer_steps, i2v_seed, i2v_dit_folder, i2v_model, i2v_vae, i2v_te1, i2v_te2, | |
| i2v_save_path, i2v_flow_shift, i2v_cfg_scale, i2v_output_type, i2v_attn_mode, | |
| i2v_block_swap, i2v_exclude_single_blocks, i2v_use_split_attn, i2v_lora_folder, | |
| *i2v_lora_weights, *i2v_lora_multipliers, i2v_input, i2v_strength, i2v_use_fp8 | |
| ], | |
| outputs=[i2v_output, i2v_batch_progress, i2v_progress_text], | |
| queue=True | |
| ).then( | |
| fn=lambda batch_size: 0 if batch_size == 1 else None, | |
| inputs=[i2v_batch_size], | |
| outputs=i2v_selected_index | |
| ) | |
| # Send to Video2Video | |
| i2v_output.select( | |
| fn=handle_i2v_gallery_select, | |
| outputs=i2v_selected_index | |
| ) | |
| i2v_send_to_v2v_btn.click( | |
| fn=send_i2v_to_v2v, | |
| inputs=[ | |
| i2v_output, i2v_prompt, i2v_selected_index, | |
| width, height, | |
| i2v_video_length, i2v_fps, i2v_infer_steps, | |
| i2v_seed, i2v_flow_shift, i2v_cfg_scale | |
| ] + i2v_lora_weights + i2v_lora_multipliers, | |
| outputs=[ | |
| v2v_input, v2v_prompt, | |
| v2v_width, v2v_height, | |
| v2v_video_length, v2v_fps, v2v_infer_steps, | |
| v2v_seed, v2v_flow_shift, v2v_cfg_scale | |
| ] + v2v_lora_weights + v2v_lora_multipliers | |
| ).then( | |
| fn=change_to_tab_two, inputs=None, outputs=[tabs] | |
| ) | |
| #Video Info | |
| def clean_video_path(video_path) -> str: | |
| """Extract clean video path from Gradio's various return formats""" | |
| print(f"Input video_path: {video_path}, type: {type(video_path)}") | |
| if isinstance(video_path, dict): | |
| path = video_path.get("name", "") | |
| elif isinstance(video_path, (tuple, list)): | |
| path = video_path[0] | |
| elif isinstance(video_path, str): | |
| path = video_path | |
| else: | |
| path = "" | |
| print(f"Cleaned path: {path}") | |
| return path | |
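| # Gradio may hand back any of these shapes; all normalize to a path | |
| # (paths are illustrative): | |
| #   {"name": "/tmp/a.mp4"}       -> "/tmp/a.mp4" | |
| #   ("/tmp/a.mp4", "caption")    -> "/tmp/a.mp4" | |
| #   "/tmp/a.mp4"                 -> "/tmp/a.mp4" | |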
| def handle_video_upload(video_path: str) -> Dict: | |
| """Handle video upload and metadata extraction""" | |
| if not video_path: | |
| return {}, "No video uploaded" | |
| metadata = extract_video_metadata(video_path) | |
| if not metadata: | |
| return {}, "No metadata found in video" | |
| return metadata, "Metadata extracted successfully" | |
| def get_video_info(video_path: str) -> dict: | |
| try: | |
| probe = ffmpeg.probe(video_path) | |
| video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video') | |
| width = int(video_info['width']) | |
| height = int(video_info['height']) | |
| num, den = video_info['r_frame_rate'].split('/') | |
| fps = float(num) / float(den)  # Parses '30/1' to 30.0 without eval() | |
| # Calculate total frames | |
| duration = float(probe['format']['duration']) | |
| total_frames = int(duration * fps) | |
| # Ensure video length does not exceed 201 frames | |
| if total_frames > 201: | |
| total_frames = 201 | |
| duration = total_frames / fps # Adjust duration accordingly | |
| return { | |
| 'width': width, | |
| 'height': height, | |
| 'fps': fps, | |
| 'total_frames': total_frames, | |
| 'duration': duration  # Seconds, adjusted after the frame cap | |
| } | |
| except Exception as e: | |
| print(f"Error extracting video info: {e}") | |
| return {} | |
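| # Worked example (hypothetical stream): r_frame_rate "30000/1001" parses | |
| # to ~29.97 fps; a 10 s clip gives int(10 * 29.97) = 299 frames, which | |
| # the cap above trims to 201 (duration adjusted to 201 / 29.97 ~ 6.7 s). | |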
| def extract_video_details(video_path: str) -> Tuple[dict, str]: | |
| metadata = extract_video_metadata(video_path) | |
| video_details = get_video_info(video_path) | |
| # Combine metadata with video details | |
| for key, value in video_details.items(): | |
| if key not in metadata: | |
| metadata[key] = value | |
| # Ensure video length does not exceed 201 frames | |
| if 'video_length' in metadata: | |
| metadata['video_length'] = min(metadata['video_length'], 201) | |
| else: | |
| metadata['video_length'] = min(video_details.get('total_frames', 0), 201) | |
| # Return both the updated metadata and a status message | |
| return metadata, "Video details extracted successfully" | |
| def send_parameters_to_tab(metadata: Dict, target_tab: str) -> Tuple[str, Dict]: | |
| """Create parameter mapping for target tab""" | |
| if not metadata: | |
| return "No parameters to send", {} | |
| tab_name = "Text2Video" if target_tab == "t2v" else "Video2Video" | |
| try: | |
| mapping = create_parameter_transfer_map(metadata, target_tab) | |
| return f"Parameters ready for {tab_name}", mapping | |
| except Exception as e: | |
| return f"Error: {str(e)}", {} | |
| video_input.upload( | |
| fn=extract_video_details, | |
| inputs=video_input, | |
| outputs=[metadata_output, status] | |
| ) | |
| send_to_t2v_btn.click( | |
| fn=lambda m: send_parameters_to_tab(m, "t2v"), | |
| inputs=metadata_output, | |
| outputs=[status, params_state] | |
| ).then( | |
| fn=change_to_tab_one, inputs=None, outputs=[tabs] | |
| ).then( | |
| lambda params: [ | |
| params.get("prompt", ""), | |
| params.get("width", 544), | |
| params.get("height", 544), | |
| params.get("batch_size", 1), | |
| params.get("video_length", 25), | |
| params.get("fps", 24), | |
| params.get("infer_steps", 30), | |
| params.get("seed", -1), | |
| params.get("model", "hunyuan/mp_rank_00_model_states.pt"), | |
| params.get("vae", "hunyuan/pytorch_model.pt"), | |
| params.get("te1", "hunyuan/llava_llama3_fp16.safetensors"), | |
| params.get("te2", "hunyuan/clip_l.safetensors"), | |
| params.get("save_path", "outputs"), | |
| params.get("flow_shift", 11.0), | |
| params.get("cfg_scale", 7.0), | |
| params.get("output_type", "video"), | |
| params.get("attn_mode", "sdpa"), | |
| params.get("block_swap", "0"), | |
| *[params.get(f"lora{i+1}", "") for i in range(4)], | |
| *[params.get(f"lora{i+1}_multiplier", 1.0) for i in range(4)] | |
| ] if params else [gr.update()]*26, | |
| inputs=params_state, | |
| outputs=[prompt, width, height, batch_size, video_length, fps, infer_steps, seed, | |
| model, vae, te1, te2, save_path, flow_shift, cfg_scale, | |
| output_type, attn_mode, block_swap] + lora_weights + lora_multipliers | |
| ) | |
| # Text to Video generation | |
| generate_btn.click( | |
| fn=process_batch, | |
| inputs=[ | |
| prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, | |
| seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, | |
| output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, | |
| lora_folder, *lora_weights, *lora_multipliers, gr.Textbox(visible=False), gr.Number(visible=False), use_fp8 | |
| ], | |
| outputs=[video_output, batch_progress, progress_text], | |
| queue=True | |
| ).then( | |
| fn=lambda batch_size: 0 if batch_size == 1 else None, | |
| inputs=[batch_size], | |
| outputs=selected_index | |
| ) | |
| # Update gallery selection handling | |
| def handle_gallery_select(evt: gr.SelectData) -> int: | |
| return evt.index | |
| # Track selected index when gallery item is clicked | |
| video_output.select( | |
| fn=handle_gallery_select, | |
| outputs=selected_index | |
| ) | |
| # Track selected index when Video2Video gallery item is clicked | |
| def handle_v2v_gallery_select(evt: gr.SelectData) -> int: | |
| """Handle gallery selection without automatically updating the input""" | |
| return evt.index | |
| # Update the gallery selection event | |
| v2v_output.select( | |
| fn=handle_v2v_gallery_select, | |
| outputs=v2v_selected_index | |
| ) | |
| # Send button handler with gallery selection | |
| def handle_send_button( | |
| gallery: list, | |
| prompt: str, | |
| idx: int, | |
| width: int, | |
| height: int, | |
| batch_size: int, | |
| video_length: int, | |
| fps: int, | |
| infer_steps: int, | |
| seed: int, | |
| flow_shift: float, | |
| cfg_scale: float, | |
| lora1: str, | |
| lora2: str, | |
| lora3: str, | |
| lora4: str, | |
| lora1_multiplier: float, | |
| lora2_multiplier: float, | |
| lora3_multiplier: float, | |
| lora4_multiplier: float | |
| ) -> tuple: | |
| # Auto-select the only item when no explicit selection was made; | |
| # this must run before the guard below, or it would never trigger | |
| if idx is None and gallery and len(gallery) == 1: | |
| idx = 0 | |
| if not gallery or idx is None or idx >= len(gallery): | |
| return (None, "", width, height, batch_size, video_length, fps, infer_steps, | |
| seed, flow_shift, cfg_scale, | |
| lora1, lora2, lora3, lora4, | |
| lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, | |
| "")  # Empty negative_prompt placeholder | |
| selected_item = gallery[idx] | |
| # Handle different gallery item formats | |
| if isinstance(selected_item, dict): | |
| video_path = selected_item.get("name", selected_item.get("data", None)) | |
| elif isinstance(selected_item, (tuple, list)): | |
| video_path = selected_item[0] | |
| else: | |
| video_path = selected_item | |
| # Final cleanup for Gradio Video component | |
| if isinstance(video_path, tuple): | |
| video_path = video_path[0] | |
| return ( | |
| str(video_path), | |
| prompt, | |
| width, | |
| height, | |
| batch_size, | |
| video_length, | |
| fps, | |
| infer_steps, | |
| seed, | |
| flow_shift, | |
| cfg_scale, | |
| lora1, | |
| lora2, | |
| lora3, | |
| lora4, | |
| lora1_multiplier, | |
| lora2_multiplier, | |
| lora3_multiplier, | |
| lora4_multiplier, | |
| "" # Add empty string for negative_prompt | |
| ) | |
| send_t2v_to_v2v_btn.click( | |
| fn=handle_send_button, | |
| inputs=[ | |
| video_output, prompt, selected_index, | |
| t2v_width, t2v_height, batch_size, video_length, | |
| fps, infer_steps, seed, flow_shift, cfg_scale | |
| ] + lora_weights + lora_multipliers, | |
| outputs=[ | |
| v2v_input, | |
| v2v_prompt, | |
| v2v_width, | |
| v2v_height, | |
| v2v_batch_size, | |
| v2v_video_length, | |
| v2v_fps, | |
| v2v_infer_steps, | |
| v2v_seed, | |
| v2v_flow_shift, | |
| v2v_cfg_scale | |
| ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] | |
| ).then( | |
| fn=change_to_tab_two, inputs=None, outputs=[tabs] | |
| ) | |
| def handle_send_to_v2v(metadata: dict, video_path: str) -> Tuple[str, dict, str]: | |
| """Handle both parameters and video transfer""" | |
| status_msg, params = send_parameters_to_tab(metadata, "v2v") | |
| return status_msg, params, video_path | |
| def handle_info_to_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: | |
| """Handle both parameters and video transfer from Video Info to V2V tab""" | |
| if not video_path: | |
| return "No video selected", {}, None | |
| status_msg, params = send_parameters_to_tab(metadata, "v2v") | |
| # Just return the path directly | |
| return status_msg, params, video_path | |
| # Send button click handler | |
| send_to_v2v_btn.click( | |
| fn=handle_info_to_v2v, | |
| inputs=[metadata_output, video_input], | |
| outputs=[status, params_state, v2v_input] | |
| ).then( | |
| lambda params: [ | |
| params.get("v2v_prompt", ""), | |
| params.get("v2v_width", 544), | |
| params.get("v2v_height", 544), | |
| params.get("v2v_batch_size", 1), | |
| params.get("v2v_video_length", 25), | |
| params.get("v2v_fps", 24), | |
| params.get("v2v_infer_steps", 30), | |
| params.get("v2v_seed", -1), | |
| params.get("v2v_model", "hunyuan/mp_rank_00_model_states.pt"), | |
| params.get("v2v_vae", "hunyuan/pytorch_model.pt"), | |
| params.get("v2v_te1", "hunyuan/llava_llama3_fp16.safetensors"), | |
| params.get("v2v_te2", "hunyuan/clip_l.safetensors"), | |
| params.get("v2v_save_path", "outputs"), | |
| params.get("v2v_flow_shift", 11.0), | |
| params.get("v2v_cfg_scale", 7.0), | |
| params.get("v2v_output_type", "video"), | |
| params.get("v2v_attn_mode", "sdpa"), | |
| params.get("v2v_block_swap", "0"), | |
| *[params.get(f"v2v_lora_weights[{i}]", "") for i in range(4)], | |
| *[params.get(f"v2v_lora_multipliers[{i}]", 1.0) for i in range(4)] | |
| ] if params else [gr.update()] * 26, | |
| inputs=params_state, | |
| outputs=[ | |
| v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, | |
| v2v_fps, v2v_infer_steps, v2v_seed, v2v_model, v2v_vae, v2v_te1, | |
| v2v_te2, v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, | |
| v2v_attn_mode, v2v_block_swap | |
| ] + v2v_lora_weights + v2v_lora_multipliers | |
| ).then( | |
| fn=change_to_tab_two, inputs=None, outputs=[tabs] | |
| ) | |
| # Handler for sending selected video from Video2Video gallery to input | |
| def handle_v2v_send_button(gallery: list, prompt: str, idx: int) -> Tuple[Optional[str], str]: | |
| """Send the currently selected video in V2V gallery to V2V input""" | |
| if not gallery or idx is None or idx >= len(gallery): | |
| return None, "" | |
| selected_item = gallery[idx] | |
| video_path = None | |
| # Handle different gallery item formats | |
| if isinstance(selected_item, tuple): | |
| video_path = selected_item[0] # Gallery returns (path, caption) | |
| elif isinstance(selected_item, dict): | |
| video_path = selected_item.get("name", selected_item.get("data", None)) | |
| elif isinstance(selected_item, str): | |
| video_path = selected_item | |
| if not video_path: | |
| return None, "" | |
| # Check if the file exists and is accessible | |
| if not os.path.exists(video_path): | |
| print(f"Warning: Video file not found at {video_path}") | |
| return None, "" | |
| return video_path, prompt | |
| v2v_send_to_input_btn.click( | |
| fn=handle_v2v_send_button, | |
| inputs=[v2v_output, v2v_prompt, v2v_selected_index], | |
| outputs=[v2v_input, v2v_prompt] | |
| ).then( | |
| lambda: gr.update(visible=True), # Ensure the video input is visible | |
| outputs=v2v_input | |
| ) | |
| # Video to Video generation | |
| v2v_generate_btn.click( | |
| fn=process_batch, | |
| inputs=[ | |
| v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, | |
| v2v_fps, v2v_infer_steps, v2v_seed, v2v_dit_folder, v2v_model, v2v_vae, v2v_te1, v2v_te2, | |
| v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, v2v_attn_mode, | |
| v2v_block_swap, v2v_exclude_single_blocks, v2v_use_split_attn, v2v_lora_folder, | |
| *v2v_lora_weights, *v2v_lora_multipliers, v2v_input, v2v_strength, | |
| v2v_negative_prompt, v2v_cfg_scale, v2v_split_uncond, v2v_use_fp8 | |
| ], | |
| outputs=[v2v_output, v2v_batch_progress, v2v_progress_text], | |
| queue=True | |
| ).then( | |
| fn=lambda batch_size: 0 if batch_size == 1 else None, | |
| inputs=[v2v_batch_size], | |
| outputs=v2v_selected_index | |
| ) | |
| refresh_outputs = [model] # Add model dropdown to outputs | |
| for i in range(4): | |
| refresh_outputs.extend([lora_weights[i], lora_multipliers[i]]) | |
| refresh_btn.click( | |
| fn=update_dit_and_lora_dropdowns, | |
| inputs=[dit_folder, lora_folder, model] + lora_weights + lora_multipliers, | |
| outputs=refresh_outputs | |
| ) | |
| # Image2Video refresh | |
| i2v_refresh_outputs = [i2v_model] # Add model dropdown to outputs | |
| for i in range(4): | |
| i2v_refresh_outputs.extend([i2v_lora_weights[i], i2v_lora_multipliers[i]]) | |
| i2v_refresh_btn.click( | |
| fn=update_dit_and_lora_dropdowns, | |
| inputs=[i2v_dit_folder, i2v_lora_folder, i2v_model] + i2v_lora_weights + i2v_lora_multipliers, | |
| outputs=i2v_refresh_outputs | |
| ) | |
| # Video2Video refresh | |
| v2v_refresh_outputs = [v2v_model] # Add model dropdown to outputs | |
| for i in range(4): | |
| v2v_refresh_outputs.extend([v2v_lora_weights[i], v2v_lora_multipliers[i]]) | |
| v2v_refresh_btn.click( | |
| fn=update_dit_and_lora_dropdowns, | |
| inputs=[v2v_dit_folder, v2v_lora_folder, v2v_model] + v2v_lora_weights + v2v_lora_multipliers, | |
| outputs=v2v_refresh_outputs | |
| ) | |
| # WanX-i2v tab connections | |
| wanx_prompt.change(fn=count_prompt_tokens, inputs=wanx_prompt, outputs=wanx_token_counter) | |
| wanx_stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
| # Image input handling for WanX-i2v | |
| wanx_input.change( | |
| fn=update_wanx_image_dimensions, | |
| inputs=[wanx_input], | |
| outputs=[wanx_original_dims, wanx_width, wanx_height] | |
| ) | |
| # Scale slider handling for WanX-i2v | |
| wanx_scale_slider.change( | |
| fn=update_wanx_from_scale, | |
| inputs=[wanx_scale_slider, wanx_original_dims], | |
| outputs=[wanx_width, wanx_height] | |
| ) | |
| # Width/height calculation buttons for WanX-i2v | |
| wanx_calc_width_btn.click( | |
| fn=calculate_wanx_width, | |
| inputs=[wanx_height, wanx_original_dims], | |
| outputs=[wanx_width] | |
| ) | |
| wanx_calc_height_btn.click( | |
| fn=calculate_wanx_height, | |
| inputs=[wanx_width, wanx_original_dims], | |
| outputs=[wanx_height] | |
| ) | |
| # Add visibility toggle for the folder input components | |
| wanx_use_random_folder.change( | |
| fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), | |
| inputs=[wanx_use_random_folder], | |
| outputs=[wanx_input_folder, wanx_folder_status, wanx_validate_folder_btn, wanx_input] | |
| ) | |
| # Validate folder button handler | |
| wanx_validate_folder_btn.click( | |
| fn=lambda folder: get_random_image_from_folder(folder)[1], | |
| inputs=[wanx_input_folder], | |
| outputs=[wanx_folder_status] | |
| ) | |
| # Flow shift recommendation buttons | |
| wanx_recommend_flow_btn.click( | |
| fn=recommend_wanx_flow_shift, | |
| inputs=[wanx_width, wanx_height], | |
| outputs=[wanx_flow_shift] | |
| ) | |
| wanx_t2v_recommend_flow_btn.click( | |
| fn=recommend_wanx_flow_shift, | |
| inputs=[wanx_t2v_width, wanx_t2v_height], | |
| outputs=[wanx_t2v_flow_shift] | |
| ) | |
| # Generate button handler | |
| wanx_generate_btn.click( | |
| fn=wanx_batch_handler, | |
| inputs=[ | |
| wanx_use_random_folder, | |
| wanx_prompt, | |
| wanx_negative_prompt, | |
| wanx_width, | |
| wanx_height, | |
| wanx_video_length, | |
| wanx_fps, | |
| wanx_infer_steps, | |
| wanx_flow_shift, | |
| wanx_guidance_scale, | |
| wanx_seed, | |
| wanx_batch_size, | |
| wanx_input_folder, | |
| wanx_task, | |
| wanx_dit_path, | |
| wanx_vae_path, | |
| wanx_t5_path, | |
| wanx_clip_path, | |
| wanx_save_path, | |
| wanx_output_type, | |
| wanx_sample_solver, | |
| wanx_exclude_single_blocks, | |
| wanx_attn_mode, | |
| wanx_block_swap, | |
| wanx_fp8, | |
| wanx_fp8_t5, | |
| wanx_lora_folder, | |
| *wanx_lora_weights, | |
| *wanx_lora_multipliers, | |
| wanx_input # Include input image path for non-batch mode | |
| ], | |
| outputs=[wanx_output, wanx_batch_progress, wanx_progress_text], | |
| queue=True | |
| ).then( | |
| fn=lambda batch_size: 0 if batch_size == 1 else None, | |
| inputs=[wanx_batch_size], | |
| outputs=wanx_i2v_selected_index | |
| ) | |
| # Add refresh button handler for WanX-i2v tab | |
| wanx_refresh_outputs = [] | |
| for i in range(4): | |
| wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]]) | |
| wanx_refresh_btn.click( | |
| fn=update_lora_dropdowns, | |
| inputs=[wanx_lora_folder] + wanx_lora_weights + wanx_lora_multipliers, | |
| outputs=wanx_refresh_outputs | |
| ) | |
| # Gallery selection handling | |
| wanx_output.select( | |
| fn=handle_wanx_gallery_select, | |
| inputs=[wanx_output], | |
| outputs=[wanx_i2v_selected_index, wanx_base_video] | |
| ) | |
| # Send to Video2Video handler | |
| wanx_send_to_v2v_btn.click( | |
| fn=send_wanx_to_v2v, | |
| inputs=[ | |
| wanx_output, # Gallery with videos | |
| wanx_prompt, # Prompt text | |
| wanx_i2v_selected_index, # Use the correct selected index state | |
| wanx_width, | |
| wanx_height, | |
| wanx_video_length, | |
| wanx_fps, | |
| wanx_infer_steps, | |
| wanx_seed, | |
| wanx_flow_shift, | |
| wanx_guidance_scale, | |
| wanx_negative_prompt | |
| ], | |
| outputs=[ | |
| v2v_input, # Video input in V2V tab | |
| v2v_prompt, # Prompt in V2V tab | |
| v2v_width, | |
| v2v_height, | |
| v2v_video_length, | |
| v2v_fps, | |
| v2v_infer_steps, | |
| v2v_seed, | |
| v2v_flow_shift, | |
| v2v_cfg_scale, | |
| v2v_negative_prompt | |
| ] | |
| ).then( | |
| fn=change_to_tab_two, # Function to switch to Video2Video tab | |
| inputs=None, | |
| outputs=[tabs] | |
| ) | |
| # Add state for T2V tab selected index | |
| wanx_t2v_selected_index = gr.State(value=None) | |
| # Connect prompt token counter | |
| wanx_t2v_prompt.change(fn=count_prompt_tokens, inputs=wanx_t2v_prompt, outputs=wanx_t2v_token_counter) | |
| # Stop button handler | |
| wanx_t2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
| # Flow shift recommendation button | |
| wanx_t2v_recommend_flow_btn.click( | |
| fn=recommend_wanx_flow_shift, | |
| inputs=[wanx_t2v_width, wanx_t2v_height], | |
| outputs=[wanx_t2v_flow_shift] | |
| ) | |
| # Task change handler to update CLIP visibility and path | |
| def update_clip_visibility(task): | |
| is_i2v = "i2v" in task | |
| return gr.update(visible=is_i2v) | |
| wanx_t2v_task.change( | |
| fn=update_clip_visibility, | |
| inputs=[wanx_t2v_task], | |
| outputs=[wanx_t2v_clip_path] | |
| ) | |
| # Generate button handler for T2V | |
| wanx_t2v_generate_btn.click( | |
| fn=wanx_generate_video_batch, | |
| inputs=[ | |
| wanx_t2v_prompt, | |
| wanx_t2v_negative_prompt, | |
| wanx_t2v_width, | |
| wanx_t2v_height, | |
| wanx_t2v_video_length, | |
| wanx_t2v_fps, | |
| wanx_t2v_infer_steps, | |
| wanx_t2v_flow_shift, | |
| wanx_t2v_guidance_scale, | |
| wanx_t2v_seed, | |
| wanx_t2v_task, | |
| wanx_t2v_dit_path, | |
| wanx_t2v_vae_path, | |
| wanx_t2v_t5_path, | |
| wanx_t2v_clip_path, | |
| wanx_t2v_save_path, | |
| wanx_t2v_output_type, | |
| wanx_t2v_sample_solver, | |
| wanx_t2v_exclude_single_blocks, | |
| wanx_t2v_attn_mode, | |
| wanx_t2v_block_swap, | |
| wanx_t2v_fp8, | |
| wanx_t2v_fp8_t5, | |
| wanx_t2v_lora_folder, | |
| *wanx_t2v_lora_weights, | |
| *wanx_t2v_lora_multipliers, | |
| wanx_t2v_batch_size, | |
| # input_image is optional and omitted here for pure text-to-video | |
| ], | |
| outputs=[wanx_t2v_output, wanx_t2v_batch_progress, wanx_t2v_progress_text], | |
| queue=True | |
| ).then( | |
| fn=lambda batch_size: 0 if batch_size == 1 else None, | |
| inputs=[wanx_t2v_batch_size], | |
| outputs=wanx_t2v_selected_index | |
| ) | |
| # Add refresh button handler for WanX-t2v tab | |
| wanx_t2v_refresh_outputs = [] | |
| for i in range(4): | |
| wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]]) | |
| wanx_t2v_refresh_btn.click( | |
| fn=update_lora_dropdowns, | |
| inputs=[wanx_t2v_lora_folder] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers, | |
| outputs=wanx_t2v_refresh_outputs | |
| ) | |
| # Gallery selection handling | |
| wanx_t2v_output.select( | |
| fn=handle_wanx_t2v_gallery_select, | |
| outputs=wanx_t2v_selected_index | |
| ) | |
| # Send to Video2Video handler | |
| wanx_t2v_send_to_v2v_btn.click( | |
| fn=send_wanx_t2v_to_v2v, | |
| inputs=[ | |
| wanx_t2v_output, | |
| wanx_t2v_prompt, | |
| wanx_t2v_selected_index, | |
| wanx_t2v_width, | |
| wanx_t2v_height, | |
| wanx_t2v_video_length, | |
| wanx_t2v_fps, | |
| wanx_t2v_infer_steps, | |
| wanx_t2v_seed, | |
| wanx_t2v_flow_shift, | |
| wanx_t2v_guidance_scale, | |
| wanx_t2v_negative_prompt | |
| ], | |
| outputs=[ | |
| v2v_input, | |
| v2v_prompt, | |
| v2v_width, | |
| v2v_height, | |
| v2v_video_length, | |
| v2v_fps, | |
| v2v_infer_steps, | |
| v2v_seed, | |
| v2v_flow_shift, | |
| v2v_cfg_scale, | |
| v2v_negative_prompt | |
| ] | |
| ).then( | |
| fn=change_to_tab_two, | |
| inputs=None, | |
| outputs=[tabs] | |
| ) | |
| demo.queue().launch(server_name="0.0.0.0", share=False) |