#!/usr/bin/env python3
"""
MIMO - Complete HuggingFace Spaces Implementation
Controllable Character Video Synthesis with Spatial Decomposed Modeling

Complete features matching README_SETUP.md:
- Character Image Animation (run_animate.py functionality)
- Video Character Editing (run_edit.py functionality)
- Real motion templates from assets/video_template/
- Auto GPU detection (T4/A10G/A100)
- Auto model downloading
- Human segmentation and background processing
- Pose-guided video generation with occlusion handling
"""
# CRITICAL: Import spaces FIRST, before any torch/CUDA operations.
# This prevents CUDA initialization errors on HuggingFace Spaces ZeroGPU.
try:
    import spaces
    HAS_SPACES = True
    print("✅ Spaces library loaded - ZeroGPU mode enabled")
except ImportError:
    HAS_SPACES = False
    print("⚠️ Spaces library not available - running in local mode")
import sys
import os
import glob
import json
import time
import traceback
from pathlib import Path
from typing import List, Optional, Dict, Tuple

import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
import imageio
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download, hf_hub_download
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
# Add src to path for imports
sys.path.append('./src')

from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d_edit_bkfill import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long_edit_bkfill_roiclip import Pose2VideoPipeline
from src.utils.util import get_fps, read_frames
# Optional: human segmenter (requires tensorflow)
try:
    from tools.human_segmenter import human_segmenter
    HAS_SEGMENTER = True
except ImportError:
    print("⚠️ TensorFlow not available, human_segmenter disabled (will use fallback)")
    human_segmenter = None
    HAS_SEGMENTER = False
from tools.util import (
    load_mask_list, crop_img, pad_img, crop_human,
    crop_human_clip_auto_context, get_mask, load_video_fixed_fps,
    recover_bk, all_file
)
# Global variables
# CRITICAL: For HF Spaces ZeroGPU, keep the device as "cpu" initially.
# Models are moved to GPU only inside @spaces.GPU() decorated functions.
DEVICE = "cpu"  # Don't initialize CUDA in the main process
MODEL_CACHE = "./models"
ASSETS_CACHE = "./assets"

# CRITICAL: Set PyTorch memory optimization to avoid fragmentation.
# This helps ZeroGPU handle memory more efficiently.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
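# Note: PYTORCH_CUDA_ALLOC_CONF is read lazily, at the first CUDA allocation,
# so setting it here (after `import torch` but before any .to("cuda") call)
# should still take effect in the ZeroGPU worker.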
class CompleteMIMO:
    """Complete MIMO implementation matching README_SETUP.md functionality"""

    def __init__(self):
        self.pipe = None
        self.is_loaded = False
        self.segmenter = None
        self.mask_list = None
        self.weight_dtype = torch.float32
        self._model_cache_valid = False  # Track if models are loaded

        # Create cache directories
        os.makedirs(MODEL_CACHE, exist_ok=True)
        os.makedirs(ASSETS_CACHE, exist_ok=True)
        os.makedirs("./output", exist_ok=True)

        print(f"🚀 MIMO initializing on {DEVICE}")
        if DEVICE == "cuda":
            print(f"🎮 GPU: {torch.cuda.get_device_name()}")
            print(f"💾 VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

        # Check if models are already loaded from a previous session
        self._check_existing_models()
    def _check_existing_models(self):
        """Check if models are already downloaded and show status"""
        try:
            # Use the same path detection logic as load_model.
            # This accounts for the HuggingFace cache structure
            # (models--org--name/snapshots/<hash>/).

            # Check if any model directories exist (either simple or HF cache structure)
            model_dirs = [
                Path(f"{MODEL_CACHE}/stable-diffusion-v1-5"),
                Path(f"{MODEL_CACHE}/sd-vae-ft-mse"),
                Path(f"{MODEL_CACHE}/mimo_weights"),
                Path(f"{MODEL_CACHE}/image_encoder_full")
            ]

            # Also check for the HuggingFace cache structure
            cache_patterns = [
                "models--runwayml--stable-diffusion-v1-5",
                "models--stabilityai--sd-vae-ft-mse",
                "models--menyifang--MIMO",
                "models--lambdalabs--sd-image-variations-diffusers"
            ]

            models_found = 0
            for pattern in cache_patterns:
                # Check if any directory matches this pattern
                for cache_dir in Path(MODEL_CACHE).rglob(pattern):
                    if cache_dir.is_dir():
                        models_found += 1
                        break

            # Also check the simple paths
            for model_dir in model_dirs:
                if model_dir.exists() and model_dir.is_dir():
                    models_found += 1

            if models_found >= 3:  # At least 3 major components found
                print(f"✅ Found {models_found} model components in cache - models persist across restarts!")
                self._model_cache_valid = True
                if not self.is_loaded:
                    print("💡 Models available - click 'Load Model' to activate")
                return True
            else:
                print(f"⚠️ Only found {models_found} model components - click 'Setup Models' to download")
                self._model_cache_valid = False
                return False
        except Exception as e:
            print(f"⚠️ Could not check existing models: {e}")
            traceback.print_exc()
            self._model_cache_valid = False
            return False
    def download_models(self, progress_callback=None):
        """Download all required models matching README_SETUP.md requirements"""
        # CRITICAL: Disable hf_transfer to avoid download errors on HF Spaces.
        # The hf_transfer backend can be problematic in the Spaces environment.
        os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '0'

        def update_progress(msg):
            if progress_callback:
                progress_callback(msg)
            print(f"📥 {msg}")

        update_progress("🔧 Disabled hf_transfer for stable downloads")

        downloaded_count = 0
        total_steps = 7

        try:
            # 1. Download MIMO models (main weights) - CRITICAL
            try:
                update_progress("Downloading MIMO main models...")
                snapshot_download(
                    repo_id="menyifang/MIMO",
                    cache_dir=f"{MODEL_CACHE}/mimo_weights",
                    allow_patterns=["*.pth", "*.json", "*.md"],
                    token=None
                )
                downloaded_count += 1
                update_progress(f"✅ MIMO models downloaded ({downloaded_count}/{total_steps})")
            except Exception as e:
                update_progress(f"⚠️ MIMO models download failed: {str(e)[:100]}")
                print(f"Error details: {e}")
            # 2. Download Stable Diffusion v1.5 (base model) - CRITICAL
            try:
                update_progress("Downloading Stable Diffusion v1.5...")
                snapshot_download(
                    repo_id="runwayml/stable-diffusion-v1-5",
                    cache_dir=f"{MODEL_CACHE}/stable-diffusion-v1-5",
                    allow_patterns=["**/*.json", "**/*.bin", "**/*.safetensors", "**/*.txt"],
                    ignore_patterns=["*.msgpack", "*.h5", "*.ot"],
                    token=None
                )
                downloaded_count += 1
                update_progress(f"✅ SD v1.5 downloaded ({downloaded_count}/{total_steps})")
            except Exception as e:
                update_progress(f"⚠️ SD v1.5 download failed: {str(e)[:100]}")
                print(f"Error details: {e}")

            # 3. Download VAE (improved autoencoder) - CRITICAL
            try:
                update_progress("Downloading sd-vae-ft-mse...")
                snapshot_download(
                    repo_id="stabilityai/sd-vae-ft-mse",
                    cache_dir=f"{MODEL_CACHE}/sd-vae-ft-mse",
                    token=None
                )
                downloaded_count += 1
                update_progress(f"✅ VAE downloaded ({downloaded_count}/{total_steps})")
            except Exception as e:
                update_progress(f"⚠️ VAE download failed: {str(e)[:100]}")
                print(f"Error details: {e}")

            # 4. Download the image encoder (for reference image processing) - CRITICAL
            try:
                update_progress("Downloading image encoder...")
                snapshot_download(
                    repo_id="lambdalabs/sd-image-variations-diffusers",
                    cache_dir=f"{MODEL_CACHE}/image_encoder_full",
                    allow_patterns=["image_encoder/**"],
                    token=None
                )
                downloaded_count += 1
                update_progress(f"✅ Image encoder downloaded ({downloaded_count}/{total_steps})")
            except Exception as e:
                update_progress(f"⚠️ Image encoder download failed: {str(e)[:100]}")
                print(f"Error details: {e}")
            # 5. Download the human segmenter (for background separation) - OPTIONAL
            try:
                update_progress("Downloading human segmenter...")
                os.makedirs(ASSETS_CACHE, exist_ok=True)
                if not os.path.exists(f"{ASSETS_CACHE}/matting_human.pb"):
                    hf_hub_download(
                        repo_id="menyifang/MIMO",
                        filename="matting_human.pb",
                        cache_dir=ASSETS_CACHE,
                        local_dir=ASSETS_CACHE,
                        token=None
                    )
                downloaded_count += 1
                update_progress(f"✅ Human segmenter downloaded ({downloaded_count}/{total_steps})")
            except Exception as e:
                update_progress(f"⚠️ Human segmenter download failed (optional): {str(e)[:100]}")
                print(f"Will use fallback segmentation. Error: {e}")
            # 6. Set up the video templates directory - OPTIONAL
            # Note: templates are not available in the HuggingFace MIMO repo;
            # users need to upload them manually or use a reference image only.
            try:
                update_progress("Setting up video templates...")
                os.makedirs("./assets/video_template", exist_ok=True)

                # Check if any templates already exist (manually uploaded)
                existing_templates = []
                try:
                    for item in os.listdir("./assets/video_template"):
                        template_path = os.path.join("./assets/video_template", item)
                        if os.path.isdir(template_path) and os.path.exists(os.path.join(template_path, "sdc.mp4")):
                            existing_templates.append(item)
                except Exception:
                    pass

                if existing_templates:
                    update_progress(f"✅ Found {len(existing_templates)} existing templates")
                    downloaded_count += 1
                else:
                    update_progress("ℹ️ No video templates found (optional - see TEMPLATES_SETUP.md)")
                    print("💡 Templates are optional. You can:")
                    print("   1. Use a reference image only (no template needed)")
                    print("   2. Manually upload templates to assets/video_template/")
                    print("   3. See TEMPLATES_SETUP.md for instructions")
            except Exception as e:
                update_progress(f"⚠️ Template setup warning: {str(e)[:100]}")
                print("💡 Templates are optional - the app will work without them")
            # 7. Create the necessary directories
            try:
                update_progress("Setting up directories...")
                os.makedirs("./assets/masks", exist_ok=True)
                os.makedirs("./output", exist_ok=True)
                downloaded_count += 1
                update_progress(f"✅ Directories created ({downloaded_count}/{total_steps})")
            except Exception as e:
                print(f"Directory creation warning: {e}")

            # Check if we have the minimum requirements
            if downloaded_count >= 4:  # At least MIMO, SD, VAE, and the image encoder
                update_progress(f"✅ Setup complete! ({downloaded_count}/{total_steps} steps successful)")
                # Update the cache validity flag after a successful download
                self._model_cache_valid = True
                print("✅ Model cache is now valid - the 'Load Model' button will work")
                return True
            else:
                update_progress(f"⚠️ Partial download ({downloaded_count}/{total_steps}). Some features may not work.")
                # Still mark the cache valid if we got some models
                if downloaded_count > 0:
                    self._model_cache_valid = True
                return downloaded_count > 0  # True if at least something downloaded

        except Exception as e:
            error_msg = f"❌ Download failed: {str(e)}"
            update_progress(error_msg)
            print(f"\n{'='*60}")
            print("ERROR DETAILS:")
            traceback.print_exc()
            print(f"{'='*60}\n")
            return False
    def load_model(self, progress_callback=None):
        """Load the MIMO model with complete functionality"""
        def update_progress(msg):
            if progress_callback:
                progress_callback(msg)
            print(f"🔄 {msg}")

        try:
            if self.is_loaded:
                update_progress("✅ Model already loaded")
                return True

            # Check that the model files exist and find their actual paths
            update_progress("Checking model files...")

            # Helper to find a model in the cache
            def find_model_path(primary_path, model_name, search_patterns=None):
                """Find a model in the cache, checking multiple possible locations"""
                # Check the primary path first
                if os.path.exists(primary_path):
                    # Verify it's a valid model directory (has config.json or model files)
                    try:
                        has_config = os.path.exists(os.path.join(primary_path, "config.json"))
                        has_model_files = any(
                            f.endswith(('.bin', '.safetensors', '.pth'))
                            for f in os.listdir(primary_path)
                            if os.path.isfile(os.path.join(primary_path, f))
                        )
                        if has_config or has_model_files:
                            update_progress(f"✅ Found {model_name} at primary path")
                            return primary_path
                        else:
                            # The primary path exists but may be a cache directory - check inside
                            update_progress("⚠️ Primary path exists but appears to be a cache directory, searching inside...")
                            # Check if it contains a models--org--name subdirectory
                            if search_patterns:
                                for pattern in search_patterns:
                                    # Extract just the directory name from the pattern
                                    cache_dir_name = pattern.split('/')[-1] if '/' in pattern else pattern
                                    cache_subdir = os.path.join(primary_path, cache_dir_name)
                                    if os.path.exists(cache_subdir):
                                        update_progress(f"  Found cache subdir: {cache_dir_name}")
                                        # Check in snapshots
                                        snap_path = os.path.join(cache_subdir, "snapshots")
                                        if os.path.exists(snap_path):
                                            try:
                                                snapshot_dirs = [d for d in os.listdir(snap_path) if os.path.isdir(os.path.join(snap_path, d))]
                                                if snapshot_dirs:
                                                    full_path = os.path.join(snap_path, snapshot_dirs[0])
                                                    update_progress(f"  Checking snapshot: {snapshot_dirs[0]}")
                                                    # Check if this is a valid model directory.
                                                    # SD models may have subdirectories (unet, vae, etc.)
                                                    has_config = os.path.exists(os.path.join(full_path, "config.json"))
                                                    has_model_index = os.path.exists(os.path.join(full_path, "model_index.json"))
                                                    has_subdirs = any(os.path.isdir(os.path.join(full_path, d)) for d in os.listdir(full_path))
                                                    has_model_files = any(
                                                        f.endswith(('.bin', '.safetensors', '.pth'))
                                                        for f in os.listdir(full_path)
                                                        if os.path.isfile(os.path.join(full_path, f))
                                                    )
                                                    if has_config or has_model_index or has_model_files or has_subdirs:
                                                        update_progress(f"✅ Found {model_name} in snapshot: {full_path}")
                                                        return full_path
                                                    else:
                                                        update_progress("  ⚠️ Snapshot exists but appears empty or invalid")
                                            except Exception as e:
                                                update_progress(f"⚠️ Error in snapshot: {e}")
                    except Exception as e:
                        update_progress(f"⚠️ Error checking primary path: {e}")

                # Check the HF cache structure in the MODEL_CACHE root
                if search_patterns:
                    for pattern in search_patterns:
                        alt_path = os.path.join(MODEL_CACHE, pattern)
                        if os.path.exists(alt_path):
                            update_progress(f"  Checking cache: {pattern}")
                            # Check in the snapshots subdirectory
                            snap_path = os.path.join(alt_path, "snapshots")
                            if os.path.exists(snap_path):
                                try:
                                    snapshot_dirs = [d for d in os.listdir(snap_path) if os.path.isdir(os.path.join(snap_path, d))]
                                    if snapshot_dirs:
                                        full_path = os.path.join(snap_path, snapshot_dirs[0])
                                        # Check for various indicators of a valid model
                                        has_config = os.path.exists(os.path.join(full_path, "config.json"))
                                        has_model_index = os.path.exists(os.path.join(full_path, "model_index.json"))
                                        has_subdirs = any(os.path.isdir(os.path.join(full_path, d)) for d in os.listdir(full_path))
                                        has_model_files = any(
                                            f.endswith(('.bin', '.safetensors', '.pth'))
                                            for f in os.listdir(full_path)
                                            if os.path.isfile(os.path.join(full_path, f))
                                        )
                                        if has_config or has_model_index or has_model_files or has_subdirs:
                                            update_progress(f"✅ Found {model_name} in snapshot: {full_path}")
                                            return full_path
                                except Exception as e:
                                    update_progress(f"⚠️ Error searching snapshots: {e}")

                update_progress(f"⚠️ Could not find {model_name} in any location")
                return None

            # Find actual model paths
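            # (For reference, the HF hub cache layout probed above is
            # models--{org}--{name}/snapshots/{revision}/...; e.g. a VAE download
            # lands under models/sd-vae-ft-mse/models--stabilityai--sd-vae-ft-mse/
            # snapshots/<hash>/config.json, where <hash> is the commit revision.)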
            vae_path = find_model_path(
                f"{MODEL_CACHE}/sd-vae-ft-mse",
                "VAE",
                ["models--stabilityai--sd-vae-ft-mse"]
            )

            sd_path = find_model_path(
                f"{MODEL_CACHE}/stable-diffusion-v1-5",
                "SD v1.5",
                ["models--runwayml--stable-diffusion-v1-5"]
            )

            # Find the image encoder - handle the HF cache structure
            encoder_path = None
            update_progress("🔍 Searching for Image Encoder...")

            # Primary search: check if image_encoder_full contains the HF cache structure
            image_encoder_base = f"{MODEL_CACHE}/image_encoder_full"
            if os.path.exists(image_encoder_base):
                try:
                    contents = os.listdir(image_encoder_base)
                    update_progress(f"  📂 image_encoder_full contains: {contents}")

                    # Look for models--lambdalabs--sd-image-variations-diffusers
                    hf_cache_dir = os.path.join(image_encoder_base, "models--lambdalabs--sd-image-variations-diffusers")
                    if os.path.exists(hf_cache_dir):
                        update_progress("  ✅ Found HF cache directory")
                        # Navigate into snapshots
                        snapshots_dir = os.path.join(hf_cache_dir, "snapshots")
                        if os.path.exists(snapshots_dir):
                            snapshot_dirs = [d for d in os.listdir(snapshots_dir) if os.path.isdir(os.path.join(snapshots_dir, d))]
                            if snapshot_dirs:
                                snapshot_path = os.path.join(snapshots_dir, snapshot_dirs[0])
                                update_progress(f"  ✅ Found snapshot: {snapshot_dirs[0]}")
                                # Check for an image_encoder subfolder
                                img_enc_path = os.path.join(snapshot_path, "image_encoder")
                                if os.path.exists(img_enc_path) and os.path.exists(os.path.join(img_enc_path, "config.json")):
                                    encoder_path = img_enc_path
                                    update_progress(f"✅ Found Image Encoder at: {img_enc_path}")
                                elif os.path.exists(os.path.join(snapshot_path, "config.json")):
                                    encoder_path = snapshot_path
                                    update_progress(f"✅ Found Image Encoder at: {snapshot_path}")
                except Exception as e:
                    update_progress(f"  ⚠️ Error navigating cache: {e}")

            # Fallback: try direct paths
            if not encoder_path:
                fallback_paths = [
                    f"{MODEL_CACHE}/image_encoder_full/image_encoder",
                    f"{MODEL_CACHE}/models--lambdalabs--sd-image-variations-diffusers/snapshots/*/image_encoder",
                ]
                for path_pattern in fallback_paths:
                    if '*' in path_pattern:
                        matches = glob.glob(path_pattern)
                        if matches and os.path.exists(os.path.join(matches[0], "config.json")):
                            encoder_path = matches[0]
                            update_progress(f"✅ Found Image Encoder via glob: {encoder_path}")
                            break
                    elif os.path.exists(path_pattern) and os.path.exists(os.path.join(path_pattern, "config.json")):
                        encoder_path = path_pattern
                        update_progress(f"✅ Found Image Encoder at: {path_pattern}")
                        break

            mimo_weights_path = find_model_path(
                f"{MODEL_CACHE}/mimo_weights",
                "MIMO Weights",
                ["models--menyifang--MIMO"]
            )
            # Validate the required paths
            missing = []
            if not vae_path:
                missing.append("VAE")
                update_progress("❌ VAE path not found")
            if not sd_path:
                missing.append("SD v1.5")
                update_progress("❌ SD v1.5 path not found")
            if not encoder_path:
                missing.append("Image Encoder")
                update_progress("❌ Image Encoder path not found")
            if not mimo_weights_path:
                missing.append("MIMO Weights")
                update_progress("❌ MIMO Weights path not found")

            if missing:
                error_msg = f"Missing required models: {', '.join(missing)}. Please run 'Setup Models' first."
                update_progress(f"❌ {error_msg}")
                # List what's actually in MODEL_CACHE to help debugging
                try:
                    cache_contents = os.listdir(MODEL_CACHE) if os.path.exists(MODEL_CACHE) else []
                    update_progress(f"📁 MODEL_CACHE contents: {cache_contents[:15]}")
                except Exception:
                    pass
                return False
            update_progress("✅ All required models found")

            # Determine the optimal settings
            if DEVICE == "cuda":
                try:
                    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
                    self.weight_dtype = torch.float16 if gpu_memory > 10 else torch.float32
                    update_progress(f"Using {'FP16' if self.weight_dtype == torch.float16 else 'FP32'} on GPU ({gpu_memory:.1f}GB)")
                except Exception as e:
                    update_progress(f"⚠️ GPU detection failed: {e}, using FP32")
                    self.weight_dtype = torch.float32
            else:
                self.weight_dtype = torch.float32
                update_progress("Using FP32 on CPU")
            # Load the VAE (keep on CPU for ZeroGPU)
            try:
                update_progress("Loading VAE...")
                vae = AutoencoderKL.from_pretrained(
                    vae_path,
                    torch_dtype=self.weight_dtype
                )  # Don't move to GPU yet
                update_progress("✅ VAE loaded (on CPU)")
            except Exception as e:
                update_progress(f"❌ VAE loading failed: {str(e)[:100]}")
                raise

            # Load the 2D UNet (reference) - keep on CPU for ZeroGPU
            try:
                update_progress("Loading Reference UNet...")
                reference_unet = UNet2DConditionModel.from_pretrained(
                    sd_path,
                    subfolder="unet",
                    torch_dtype=self.weight_dtype
                )  # Don't move to GPU yet
                update_progress("✅ Reference UNet loaded (on CPU)")
            except Exception as e:
                update_progress(f"❌ Reference UNet loading failed: {str(e)[:100]}")
                raise

            # Load the inference config
            config_path = "./configs/inference/inference_v2.yaml"
            if os.path.exists(config_path):
                infer_config = OmegaConf.load(config_path)
                update_progress("✅ Loaded inference config")
            else:
                # Create a complete fallback config matching the original implementation
                update_progress("Creating fallback inference config...")
                infer_config = OmegaConf.create({
                    "unet_additional_kwargs": {
                        "use_inflated_groupnorm": True,
                        "unet_use_cross_frame_attention": False,
                        "unet_use_temporal_attention": False,
                        "use_motion_module": True,
                        "motion_module_resolutions": [1, 2, 4, 8],
                        "motion_module_mid_block": True,
                        "motion_module_decoder_only": False,
                        "motion_module_type": "Vanilla",
                        "motion_module_kwargs": {
                            "num_attention_heads": 8,
                            "num_transformer_block": 1,
                            "attention_block_types": ["Temporal_Self", "Temporal_Self"],
                            "temporal_position_encoding": True,
                            "temporal_position_encoding_max_len": 32,
                            "temporal_attention_dim_div": 1
                        }
                    },
                    "noise_scheduler_kwargs": {
                        "beta_start": 0.00085,
                        "beta_end": 0.012,
                        "beta_schedule": "scaled_linear",
                        "clip_sample": False,
                        "steps_offset": 1,
                        "prediction_type": "v_prediction",
                        "rescale_betas_zero_snr": True,
                        "timestep_spacing": "trailing"
                    }
                })
            # Load the 3D UNet (denoising) - keep on CPU for ZeroGPU.
            # NOTE: from_pretrained_2d is a custom MIMO method that doesn't accept torch_dtype.
            try:
                update_progress("Loading Denoising UNet (3D)...")
                denoising_unet = UNet3DConditionModel.from_pretrained_2d(
                    sd_path,
                    "",  # motion_module_path loaded separately
                    subfolder="unet",
                    unet_additional_kwargs=infer_config.unet_additional_kwargs
                )
                # Convert the dtype after loading, since from_pretrained_2d doesn't accept torch_dtype
                denoising_unet = denoising_unet.to(dtype=self.weight_dtype)
                update_progress("✅ Denoising UNet loaded (on CPU)")
            except Exception as e:
                update_progress(f"❌ Denoising UNet loading failed: {str(e)[:100]}")
                raise

            # Load the pose guider - keep on CPU for ZeroGPU
            try:
                update_progress("Loading Pose Guider...")
                pose_guider = PoseGuider(
                    320,
                    conditioning_channels=3,
                    block_out_channels=(16, 32, 96, 256)
                ).to(dtype=self.weight_dtype)  # Don't move to GPU yet
                update_progress("✅ Pose Guider initialized (on CPU)")
            except Exception as e:
                update_progress(f"❌ Pose Guider loading failed: {str(e)[:100]}")
                raise

            # Load the image encoder - keep on CPU for ZeroGPU
            try:
                update_progress("Loading CLIP Image Encoder...")
                image_enc = CLIPVisionModelWithProjection.from_pretrained(
                    encoder_path,
                    torch_dtype=self.weight_dtype
                )  # Don't move to GPU yet
                update_progress("✅ Image Encoder loaded (on CPU)")
            except Exception as e:
                update_progress(f"❌ Image Encoder loading failed: {str(e)[:100]}")
                raise
            # Load the scheduler
            update_progress("Loading Scheduler...")
            sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
            scheduler = DDIMScheduler(**sched_kwargs)
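            # Note: v_prediction + "trailing" timestep spacing + rescale_betas_zero_snr
            # is the zero-terminal-SNR recipe from "Common Diffusion Noise Schedules
            # and Sample Steps are Flawed" (Lin et al., 2023); diffusers' DDIMScheduler
            # supports it via exactly these kwargs.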
            # Load the pretrained MIMO weights
            update_progress("Loading MIMO pretrained weights...")
            weight_files = list(Path(mimo_weights_path).rglob("*.pth"))
            if not weight_files:
                error_msg = f"No MIMO weight files (.pth) found at {mimo_weights_path}. Please run 'Setup Models' to download them."
                update_progress(f"❌ {error_msg}")
                return False

            update_progress(f"Found {len(weight_files)} weight files")
            weights_loaded = 0
            for weight_file in weight_files:
                try:
                    weight_name = weight_file.name
                    if "denoising_unet" in weight_name:
                        state_dict = torch.load(weight_file, map_location="cpu")
                        denoising_unet.load_state_dict(state_dict, strict=False)
                        update_progress(f"✅ Loaded {weight_name}")
                        weights_loaded += 1
                    elif "reference_unet" in weight_name:
                        state_dict = torch.load(weight_file, map_location="cpu")
                        reference_unet.load_state_dict(state_dict)
                        update_progress(f"✅ Loaded {weight_name}")
                        weights_loaded += 1
                    elif "pose_guider" in weight_name:
                        state_dict = torch.load(weight_file, map_location="cpu")
                        pose_guider.load_state_dict(state_dict)
                        update_progress(f"✅ Loaded {weight_name}")
                        weights_loaded += 1
                    elif "motion_module" in weight_name:
                        # Load the motion module into denoising_unet
                        state_dict = torch.load(weight_file, map_location="cpu")
                        denoising_unet.load_state_dict(state_dict, strict=False)
                        update_progress(f"✅ Loaded {weight_name}")
                        weights_loaded += 1
                except Exception as e:
                    update_progress(f"⚠️ Failed to load {weight_file.name}: {str(e)[:100]}")
                    print(f"Full error for {weight_file.name}: {e}")

            if weights_loaded == 0:
                error_msg = "No MIMO weights were successfully loaded"
                update_progress(f"❌ {error_msg}")
                return False

            update_progress(f"✅ Loaded {weights_loaded}/{len(weight_files)} weight files")
            # Create the pipeline - keep on CPU for ZeroGPU
            try:
                update_progress("Creating MIMO pipeline...")
                self.pipe = Pose2VideoPipeline(
                    vae=vae,
                    image_encoder=image_enc,
                    reference_unet=reference_unet,
                    denoising_unet=denoising_unet,
                    pose_guider=pose_guider,
                    scheduler=scheduler,
                ).to(dtype=self.weight_dtype)  # Keep on CPU; moved to GPU during inference

                # Enable memory-efficient attention for ZeroGPU
                if HAS_SPACES:
                    try:
                        # Enable gradient checkpointing to save memory
                        if hasattr(denoising_unet, 'enable_gradient_checkpointing'):
                            denoising_unet.enable_gradient_checkpointing()
                        if hasattr(reference_unet, 'enable_gradient_checkpointing'):
                            reference_unet.enable_gradient_checkpointing()
                        # Try to enable xformers for memory efficiency
                        try:
                            self.pipe.enable_xformers_memory_efficient_attention()
                            update_progress("✅ Memory-efficient attention enabled")
                        except Exception:
                            update_progress("⚠️ xformers not available, using standard attention")
                    except Exception as e:
                        update_progress(f"⚠️ Could not enable memory optimizations: {str(e)[:50]}")

                update_progress("✅ Pipeline created (on CPU - will use GPU during generation)")
            except Exception as e:
                update_progress(f"❌ Pipeline creation failed: {str(e)[:100]}")
                raise
            # Load the human segmenter
            update_progress("Loading human segmenter...")
            if HAS_SEGMENTER:
                seg_path = f"{ASSETS_CACHE}/matting_human.pb"
                if os.path.exists(seg_path):
                    try:
                        self.segmenter = human_segmenter(model_path=seg_path)
                        update_progress("✅ Human segmenter loaded")
                    except Exception as e:
                        update_progress(f"⚠️ Segmenter load failed: {e}, using fallback")
                        self.segmenter = None
                else:
                    update_progress("⚠️ Segmenter model not found, using fallback")
                    self.segmenter = None
            else:
                update_progress("⚠️ TensorFlow not available, using fallback segmentation")
                self.segmenter = None

            # Load the mask templates
            update_progress("Loading mask templates...")
            mask_path = f"{ASSETS_CACHE}/masks/alpha2.png"
            if os.path.exists(mask_path):
                self.mask_list = load_mask_list(mask_path)
                update_progress("✅ Mask templates loaded")
            else:
                # Create fallback masks
                update_progress("Creating fallback masks...")
                os.makedirs(f"{ASSETS_CACHE}/masks", exist_ok=True)
                fallback_mask = np.ones((512, 512), dtype=np.uint8) * 255
                self.mask_list = [fallback_mask]

            self.is_loaded = True
            update_progress("🎉 MIMO model loaded successfully!")
            return True

        except Exception as e:
            update_progress(f"❌ Model loading failed: {e}")
            traceback.print_exc()
            return False
    def process_image(self, image):
        """Process the input image with human segmentation (matching run_edit.py/run_animate.py)"""
        if self.segmenter is None:
            # Fallback: just resize and center
            image = np.array(image)
            image = cv2.resize(image, (512, 512))
            return Image.fromarray(image), None

        try:
            img_array = np.array(image)
            # Use BGR for the segmenter (as in the original code)
            rgba = self.segmenter.run(img_array[..., ::-1])
            mask = rgba[:, :, 3]
            color = rgba[:, :, :3]
            # Composite onto white: out = fg * alpha + white * (1 - alpha)
            alpha = mask / 255
            bk = np.ones_like(color) * 255
            color = color * alpha[:, :, np.newaxis] + bk * (1 - alpha[:, :, np.newaxis])
            color = color.astype(np.uint8)
            # Convert back to RGB
            color = color[..., ::-1]
            # Crop and pad like the original code
            color = crop_img(color, mask)
            color, _ = pad_img(color, [255, 255, 255])
            return Image.fromarray(color), mask
        except Exception as e:
            print(f"⚠️ Segmentation failed, using original image: {e}")
            return image, None
    def get_available_templates(self):
        """Get the list of available video templates"""
        template_dir = "./assets/video_template"

        # Create the directory if it doesn't exist
        if not os.path.exists(template_dir):
            os.makedirs(template_dir, exist_ok=True)
            print(f"⚠️ Video template directory created: {template_dir}")
            print("💡 Tip: Download templates from the HuggingFace repo or use the 'Setup Models' button")
            return []

        templates = []
        try:
            for item in os.listdir(template_dir):
                template_path = os.path.join(template_dir, item)
                if os.path.isdir(template_path):
                    # Check if it has the required files
                    sdc_file = os.path.join(template_path, "sdc.mp4")
                    if os.path.exists(sdc_file):  # At minimum we need the pose video
                        templates.append(item)
        except Exception as e:
            print(f"⚠️ Error scanning templates: {e}")
            return []

        if not templates:
            print("⚠️ No video templates found. Click 'Setup Models' to download.")
        return sorted(templates)
    def load_template(self, template_path: str) -> Dict:
        """Load template metadata (matching run_edit.py logic)"""
        try:
            video_path = os.path.join(template_path, 'vid.mp4')
            pose_video_path = os.path.join(template_path, 'sdc.mp4')
            bk_video_path = os.path.join(template_path, 'bk.mp4')
            occ_video_path = os.path.join(template_path, 'occ.mp4')

            # Check occlusion masks
            if not os.path.exists(occ_video_path):
                occ_video_path = None

            # Load the config if available
            config_file = os.path.join(template_path, 'config.json')
            if os.path.exists(config_file):
                with open(config_file) as f:
                    template_data = json.load(f)
                return {
                    'video_path': video_path,
                    'pose_video_path': pose_video_path,
                    'bk_video_path': bk_video_path if os.path.exists(bk_video_path) else None,
                    'occ_video_path': occ_video_path,
                    'target_fps': template_data.get('fps', 30),
                    'time_crop': template_data.get('time_crop', {'start_idx': 0, 'end_idx': -1}),
                    'frame_crop': template_data.get('frame_crop', {}),
                    'layer_recover': template_data.get('layer_recover', True)
                }
            else:
                # Fallback for templates without a config
                return {
                    'video_path': video_path if os.path.exists(video_path) else None,
                    'pose_video_path': pose_video_path,
                    'bk_video_path': bk_video_path if os.path.exists(bk_video_path) else None,
                    'occ_video_path': occ_video_path,
                    'target_fps': 30,
                    'time_crop': {'start_idx': 0, 'end_idx': -1},
                    'frame_crop': {},
                    'layer_recover': True
                }
        except Exception as e:
            print(f"⚠️ Failed to load template config: {e}")
            return None
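    # The expected on-disk layout of one template, as read by load_template()
    # (only sdc.mp4 is mandatory; everything else is optional):
    #
    #   assets/video_template/<name>/
    #     sdc.mp4      # pose driving video (required)
    #     vid.mp4      # original RGB video, used for blending in edit mode
    #     bk.mp4       # background plate video
    #     occ.mp4      # occlusion masks for layered compositing
    #     config.json  # fps, time_crop, frame_crop, layer_recover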
    def generate_animation(self, input_image, template_name, mode="edit", progress_callback=None):
        """Generate a video animation (implementing both run_edit.py and run_animate.py logic)"""
        def update_progress(msg):
            if progress_callback:
                progress_callback(msg)
            print(f"🎬 {msg}")

        try:
            if not self.is_loaded:
                update_progress("Loading model first...")
                if not self.load_model(progress_callback):
                    return None, "❌ Model loading failed"

            # Move the pipeline to GPU if using ZeroGPU (only during inference)
            if HAS_SPACES and torch.cuda.is_available():
                update_progress("Moving models to GPU...")
                self.pipe = self.pipe.to("cuda")
                update_progress("✅ Models on GPU")

            # Process the input image
            update_progress("Processing input image...")
            processed_image, mask = self.process_image(input_image)

            # Load the template
            template_path = f"./assets/video_template/{template_name}"
            if not os.path.exists(template_path):
                return None, f"❌ Template '{template_name}' not found"

            template_info = self.load_template(template_path)
            if template_info is None:
                return None, f"❌ Failed to load template '{template_name}'"

            update_progress(f"Loaded template: {template_name}")

            # Load the video components
            target_fps = template_info['target_fps']
            pose_video_path = template_info['pose_video_path']
            if not os.path.exists(pose_video_path):
                return None, f"❌ Pose video not found: {pose_video_path}"
            # Load the pose sequence
            update_progress("Loading motion sequence...")
            pose_images = load_video_fixed_fps(pose_video_path, target_fps=target_fps)

            # Load the background if available
            bk_video_path = template_info['bk_video_path']
            if bk_video_path and os.path.exists(bk_video_path):
                bk_images = load_video_fixed_fps(bk_video_path, target_fps=target_fps)
                update_progress("✅ Loaded background video")
            else:
                # Create a white background
                n_frame = len(pose_images)
                tw, th = pose_images[0].size
                bk_images = []
                for _ in range(n_frame):
                    bk_img = Image.new('RGB', (tw, th), (255, 255, 255))
                    bk_images.append(bk_img)
                update_progress("✅ Created white background")

            # Load occlusion masks if available (for advanced editing)
            occ_video_path = template_info['occ_video_path']
            if occ_video_path and os.path.exists(occ_video_path) and mode == "edit":
                occ_mask_images = load_video_fixed_fps(occ_video_path, target_fps=target_fps)
                update_progress("✅ Loaded occlusion masks")
            else:
                occ_mask_images = None

            # Apply time cropping
            time_crop = template_info['time_crop']
            start_idx = max(0, int(target_fps * time_crop['start_idx'] / 30)) if time_crop['start_idx'] >= 0 else 0
            end_idx = min(len(pose_images), int(target_fps * time_crop['end_idx'] / 30)) if time_crop['end_idx'] >= 0 else len(pose_images)
            pose_images = pose_images[start_idx:end_idx]
            bk_images = bk_images[start_idx:end_idx]
            if occ_mask_images:
                occ_mask_images = occ_mask_images[start_idx:end_idx]
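            # The config's start_idx/end_idx appear to be frame indices at a
            # nominal 30 fps, rescaled here to the template's target_fps.
            # Worked example: start_idx=60 with target_fps=24 crops from frame
            # int(24 * 60 / 30) = 48; end_idx=-1 means "until the last frame".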
            # Limit max frames for memory - REDUCED for ZeroGPU (22GB limit).
            # ZeroGPU has limited memory, so we reduce from 150 to 100 frames.
            MAX_FRAMES = 100 if HAS_SPACES else 150
            if len(pose_images) > MAX_FRAMES:
                update_progress(f"⚠️ Limiting to {MAX_FRAMES} frames to fit in GPU memory")
                pose_images = pose_images[:MAX_FRAMES]
                bk_images = bk_images[:MAX_FRAMES]
                if occ_mask_images:
                    occ_mask_images = occ_mask_images[:MAX_FRAMES]

            num_frames = len(pose_images)
            update_progress(f"Processing {num_frames} frames...")

            if mode == "animate":
                # Simple animation mode (run_animate.py logic)
                pose_list = []
                vid_bk_list = []

                # Crop the pose with human-centering
                pose_images, _, bk_images = crop_human(pose_images, pose_images.copy(), bk_images)

                for frame_idx in range(len(pose_images)):
                    pose_image = np.array(pose_images[frame_idx])
                    pose_image, _ = pad_img(pose_image, color=[0, 0, 0])
                    pose_list.append(Image.fromarray(pose_image))

                    vid_bk = np.array(bk_images[frame_idx])
                    vid_bk, _ = pad_img(vid_bk, color=[255, 255, 255])
                    vid_bk_list.append(Image.fromarray(vid_bk))

                # Generate the video
                update_progress("Generating animation...")
                width, height = 512, 512  # Optimized for HF
                steps = 20  # Balanced quality/speed
                cfg = 3.5
                generator = torch.Generator(device=DEVICE).manual_seed(42)
                video = self.pipe(
                    processed_image,
                    pose_list,
                    vid_bk_list,
                    width,
                    height,
                    num_frames,
                    steps,
                    cfg,
                    generator=generator,
                ).videos[0]

                # Convert to the output format
                update_progress("Post-processing video...")
                res_images = []
                for video_idx in range(num_frames):
                    image = video[:, video_idx, :, :].permute(1, 2, 0).cpu().numpy()
                    res_image_pil = Image.fromarray((image * 255).astype(np.uint8))
                    res_images.append(res_image_pil)
            else:
                # Advanced editing mode (run_edit.py logic)
                update_progress("Advanced video editing mode...")

                # Load the original video for blending
                video_path = template_info['video_path']
                if video_path and os.path.exists(video_path):
                    vid_images = load_video_fixed_fps(video_path, target_fps=target_fps)
                    vid_images = vid_images[start_idx:end_idx][:MAX_FRAMES]
                else:
                    vid_images = pose_images.copy()

                # Advanced crop with context for seamless blending
                overlay = 4
                pose_images, vid_images, bk_images, bbox_clip, context_list, bbox_clip_list = crop_human_clip_auto_context(
                    pose_images, vid_images, bk_images, overlay)

                # Process each frame
                clip_pad_list_context = []
                clip_padv_list_context = []
                pose_list_context = []
                vid_bk_list_context = []

                for frame_idx in range(len(pose_images)):
                    pose_image = np.array(pose_images[frame_idx])
                    pose_image, _ = pad_img(pose_image, color=[0, 0, 0])
                    pose_list_context.append(Image.fromarray(pose_image))

                    vid_bk = np.array(bk_images[frame_idx])
                    vid_bk, padding_v = pad_img(vid_bk, color=[255, 255, 255])
                    pad_h, pad_w, _ = vid_bk.shape
                    clip_pad_list_context.append([pad_h, pad_w])
                    clip_padv_list_context.append(padding_v)
                    vid_bk_list_context.append(Image.fromarray(vid_bk))

                # Generate the video with advanced settings
                width, height = 784, 784  # Higher resolution for editing
                steps = 25  # Higher quality
                cfg = 3.5
                generator = torch.Generator(device=DEVICE).manual_seed(42)
                video = self.pipe(
                    processed_image,
                    pose_list_context,
                    vid_bk_list_context,
                    width,
                    height,
                    len(pose_list_context),
                    steps,
                    cfg,
                    generator=generator,
                ).videos[0]

                # Advanced post-processing with blending and occlusion
                update_progress("Advanced post-processing...")
                vid_images_ori = vid_images.copy()
                bk_images_ori = bk_images.copy()

                video_idx = 0
                res_images = [None for _ in range(len(pose_images))]
                for k, context in enumerate(context_list):
                    start_i = context[0]
                    bbox = bbox_clip_list[k]
                    for i in context:
                        bk_image_pil_ori = bk_images_ori[i]
                        vid_image_pil_ori = vid_images_ori[i]
                        occ_mask = occ_mask_images[i] if occ_mask_images else None

                        canvas = Image.new("RGB", bk_image_pil_ori.size, "white")

                        pad_h, pad_w = clip_pad_list_context[video_idx]
                        padding_v = clip_padv_list_context[video_idx]

                        image = video[:, video_idx, :, :].permute(1, 2, 0).cpu().numpy()
                        res_image_pil = Image.fromarray((image * 255).astype(np.uint8))
                        res_image_pil = res_image_pil.resize((pad_w, pad_h))

                        top, bottom, left, right = padding_v
                        res_image_pil = res_image_pil.crop((left, top, pad_w - right, pad_h - bottom))

                        w_min, w_max, h_min, h_max = bbox
                        canvas.paste(res_image_pil, (w_min, h_min))
                        # Apply mask blending with bounds checking
                        mask_full = np.zeros((bk_image_pil_ori.size[1], bk_image_pil_ori.size[0]), dtype=np.float32)
                        mask = get_mask(self.mask_list, bbox, bk_image_pil_ori)
                        mask = cv2.resize(mask, res_image_pil.size, interpolation=cv2.INTER_AREA)

                        # Clip the mask to fit within the canvas bounds
                        canvas_h, canvas_w = mask_full.shape
                        mask_h, mask_w = mask.shape

                        # Calculate the actual region that fits
                        h_end = min(h_min + mask_h, canvas_h)
                        w_end = min(w_min + mask_w, canvas_w)

                        # Clip the mask if it exceeds the bounds
                        actual_h = h_end - h_min
                        actual_w = w_end - w_min
                        mask_full[h_min:h_end, w_min:w_end] = mask[:actual_h, :actual_w]

                        res_image = np.array(canvas)
                        bk_image = np.array(bk_image_pil_ori)
                        res_image = res_image * mask_full[:, :, np.newaxis] + bk_image * (1 - mask_full[:, :, np.newaxis])

                        # Apply occlusion masks if available
                        if occ_mask is not None:
                            vid_image = np.array(vid_image_pil_ori)
                            occ_mask_array = np.array(occ_mask)[:, :, 0].astype(np.uint8)
                            occ_mask_array = occ_mask_array / 255.0
                            # Resize the occlusion mask to match res_image dimensions
                            if occ_mask_array.shape[:2] != res_image.shape[:2]:
                                occ_mask_array = cv2.resize(occ_mask_array, (res_image.shape[1], res_image.shape[0]), interpolation=cv2.INTER_LINEAR)
                            # Also resize vid_image to match res_image dimensions
                            if vid_image.shape[:2] != res_image.shape[:2]:
                                vid_image = cv2.resize(vid_image, (res_image.shape[1], res_image.shape[0]), interpolation=cv2.INTER_LINEAR)
                            res_image = res_image * (1 - occ_mask_array[:, :, np.newaxis]) + vid_image * occ_mask_array[:, :, np.newaxis]

                        # Blend overlapping regions between consecutive clips
                        if res_images[i] is None:
                            res_images[i] = res_image
                        else:
                            factor = (i - start_i + 1) / (overlay + 1)
                            res_images[i] = res_images[i] * (1 - factor) + res_image * factor
                        res_images[i] = res_images[i].astype(np.uint8)
                        video_idx += 1
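                # Cross-fade note: consecutive clips share `overlay` frames, and
                # the linear factor above ramps the new clip in over the overlap.
                # With overlay=4, the shared frames blend at weights
                # 1/5, 2/5, 3/5, 4/5 (the old clip gets the complement), so there
                # is no hard seam where one clip hands off to the next.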
            # Ensure all frames have even dimensions (required for H.264 encoding:
            # the default yuv420p pixel format needs width and height divisible by 2)
            update_progress("Finalizing video encoding...")
            for i, frame in enumerate(res_images):
                if frame is not None:
                    frame = np.asarray(frame)  # animate mode stores PIL Images, edit mode ndarrays
                    h, w = frame.shape[:2]
                    # Make dimensions even by cropping 1 pixel if odd
                    new_h = h if h % 2 == 0 else h - 1
                    new_w = w if w % 2 == 0 else w - 1
                    if new_h != h or new_w != w:
                        res_images[i] = frame[:new_h, :new_w]

            # Save the output video with error handling
            output_path = f"./output/mimo_output_{int(time.time())}.mp4"
            try:
                imageio.mimsave(output_path, res_images, fps=target_fps, quality=8, macro_block_size=1)
            except (OSError, BrokenPipeError) as e:
                # FFMPEG encoding failed; retry with more compatible settings
                update_progress("⚠️ Retrying with compatible encoding settings...")
                try:
                    # Fall back to an animated GIF (more reliable here).
                    # Note: the GIF writer takes fps OR a per-frame duration, not both.
                    gif_path = output_path.replace('.mp4', '.gif')
                    imageio.mimsave(gif_path, res_images, fps=target_fps)
                    output_path = gif_path
                    update_progress("✅ Saved as GIF (FFMPEG encoding failed)")
                except Exception as gif_error:
                    raise Exception(f"Video encoding failed: {str(e)}. GIF fallback also failed: {str(gif_error)}")
            # CRITICAL: Move the pipeline back to CPU and clear the GPU cache for ZeroGPU
            if HAS_SPACES and torch.cuda.is_available():
                update_progress("Cleaning up GPU memory...")
                self.pipe = self.pipe.to("cpu")
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                update_progress("✅ GPU memory released")

            update_progress("✅ Video generated successfully!")
            return output_path, f"🎉 Generated {len(res_images)} frames at {target_fps}fps using {mode} mode!"

        except Exception as e:
            # CRITICAL: Always clean up GPU memory on error
            if HAS_SPACES and torch.cuda.is_available():
                try:
                    self.pipe = self.pipe.to("cpu")
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
                    print("✅ GPU memory cleaned up after error")
                except Exception:
                    pass
            error_msg = f"❌ Generation failed: {e}"
            update_progress(error_msg)
            traceback.print_exc()
            return None, error_msg
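# A minimal programmatic usage sketch, bypassing the UI (the image path and
# template name below are illustrative, and models must already be downloaded):
#
#     model = CompleteMIMO()
#     model.download_models()          # one-time, ~8GB
#     model.load_model()
#     out_path, msg = model.generate_animation(
#         Image.open("assets/test_image/sugar.jpg"),
#         "dance_indoor_1",
#         mode="animate",
#     )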
# Initialize the global model
mimo_model = CompleteMIMO()
def gradio_interface():
    """Create the complete Gradio interface matching README_SETUP.md functionality"""

    def setup_models(progress=gr.Progress()):
        """Set up models with progress tracking"""
        try:
            # Download models
            progress(0.1, desc="Starting download...")
            download_success = mimo_model.download_models(lambda msg: progress(0.3, desc=msg))
            if not download_success:
                return "⚠️ Some downloads failed. Check the logs for details. You may still be able to use the app with partial functionality."

            # Load models immediately after download
            progress(0.6, desc="Loading models...")
            load_success = mimo_model.load_model(lambda msg: progress(0.8, desc=msg))
            if not load_success:
                return "❌ Model loading failed. Please check the logs and try again."

            progress(1.0, desc="✅ Ready!")
            return "🎉 MIMO is ready! Models loaded successfully. Upload an image and select a template to start."
        except Exception as e:
            error_details = str(e)
            print(f"Setup error: {error_details}")
            traceback.print_exc()
            return f"❌ Setup failed: {error_details[:200]}"
    # Decorate with @spaces.GPU for ZeroGPU support
    if HAS_SPACES:
        # Request up to 120 seconds on the GPU per call
        @spaces.GPU(duration=120)
        def generate_video_gradio(input_image, template_name, mode, progress=gr.Progress()):
            """Gradio wrapper for video generation"""
            if input_image is None:
                return None, "Please upload an image first"
            if not template_name:
                return None, "Please select a motion template"
            try:
                progress(0.1, desc="Starting generation...")

                def progress_callback(msg):
                    progress(0.5, desc=msg)

                output_path, message = mimo_model.generate_animation(
                    input_image,
                    template_name,
                    mode,
                    progress_callback
                )
                progress(1.0, desc="Complete!")
                return output_path, message
            except Exception as e:
                return None, f"❌ Generation failed: {e}"
    else:
        # Local mode without the GPU decorator
        def generate_video_gradio(input_image, template_name, mode, progress=gr.Progress()):
            """Gradio wrapper for video generation"""
            if input_image is None:
                return None, "Please upload an image first"
            if not template_name:
                return None, "Please select a motion template"
            try:
                progress(0.1, desc="Starting generation...")

                def progress_callback(msg):
                    progress(0.5, desc=msg)

                output_path, message = mimo_model.generate_animation(
                    input_image,
                    template_name,
                    mode,
                    progress_callback
                )
                progress(1.0, desc="Complete!")
                return output_path, message
            except Exception as e:
                return None, f"❌ Generation failed: {e}"
    def refresh_templates():
        """Refresh the available templates"""
        templates = mimo_model.get_available_templates()
        return gr.Dropdown(choices=templates, value=templates[0] if templates else None)
    # Create the Gradio blocks
    with gr.Blocks(
        title="MIMO - Complete Character Video Synthesis",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1400px;
            margin: auto;
        }
        .header {
            text-align: center;
            margin-bottom: 2rem;
            color: #1a1a1a !important;
        }
        .header h1 {
            color: #2c3e50 !important;
            margin-bottom: 0.5rem;
            font-weight: 700;
        }
        .header p {
            color: #34495e !important;
            margin: 0.5rem 0;
            font-weight: 500;
        }
        .header a {
            color: #3498db !important;
            text-decoration: none;
            margin: 0 0.5rem;
            font-weight: 600;
        }
        .header a:hover {
            text-decoration: underline;
            color: #2980b9 !important;
        }
        .mode-info {
            padding: 1rem;
            margin: 1rem 0;
            border-radius: 8px;
            color: #2c3e50 !important;
        }
        .mode-info h4 {
            margin-top: 0;
            color: #2c3e50 !important;
            font-weight: 700;
        }
        .mode-info p {
            margin: 0.5rem 0;
            color: #34495e !important;
            font-weight: 500;
        }
        .mode-info strong {
            color: #1a1a1a !important;
            font-weight: 700;
        }
        .mode-animate {
            background: #e8f5e8;
            border-left: 4px solid #4caf50;
        }
        .mode-edit {
            background: #e3f2fd;
            border-left: 4px solid #2196f3;
        }
        .warning-box {
            padding: 1rem;
            background: #fff3cd;
            border-left: 4px solid #ffc107;
            margin: 1rem 0;
            border-radius: 4px;
        }
        .warning-box b {
            color: #856404 !important;
            font-weight: 700;
        }
        .warning-box, .warning-box * {
            color: #856404 !important;
        }
        .instructions-box {
            margin-top: 2rem;
            padding: 1.5rem;
            background: #f8f9fa;
            border-radius: 8px;
            border: 1px solid #dee2e6;
        }
        .instructions-box h4 {
            color: #2c3e50 !important;
            margin-top: 1rem;
            margin-bottom: 0.5rem;
            font-weight: 700;
        }
        .instructions-box h4:first-child {
            margin-top: 0;
        }
        .instructions-box ol {
            color: #495057 !important;
            line-height: 1.8;
        }
        .instructions-box ol li {
            margin: 0.5rem 0;
            color: #495057 !important;
        }
        .instructions-box ol li strong {
            color: #1a1a1a !important;
            font-weight: 700;
        }
        .instructions-box p {
            color: #495057 !important;
            margin: 0.3rem 0;
            line-height: 1.6;
        }
        .instructions-box p strong {
            color: #1a1a1a !important;
            font-weight: 700;
        }
        """
    ) as demo:
        gr.HTML("""
        <div class="header">
            <h1>🎬 MIMO - Complete Character Video Synthesis</h1>
            <p>Full implementation matching the original research paper - Character Animation & Video Editing</p>
            <p>
                <a href="https://menyifang.github.io/projects/MIMO/index.html">🌐 Project Page</a> |
                <a href="https://github.com/menyifang/MIMO">💻 GitHub</a> |
                <a href="https://arxiv.org/abs/2409.16160">📄 Paper</a>
            </p>
        </div>
        """)
        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("<h3>🖼️ Input Configuration</h3>")

                input_image = gr.Image(
                    label="Character Image",
                    type="pil",
                    height=400
                )

                mode = gr.Radio(
                    label="Generation Mode",
                    choices=[
                        ("🏃 Character Animation", "animate"),
                        ("🎬 Video Character Editing", "edit")
                    ],
                    value="edit"
                )

                # Dynamic template loading
                templates = mimo_model.get_available_templates()
                if not templates:
                    gr.HTML("""
                    <div class="warning-box">
                        <b>⚠️ No Motion Templates Found</b><br/>
                        Click the <b>"🔧 Setup Models"</b> button below to download video templates.<br/>
                        Templates will be downloaded to: <code>./assets/video_template/</code>
                    </div>
                    """)

                motion_template = gr.Dropdown(
                    label="Motion Template",
                    choices=templates if templates else ["No templates - Upload manually or use reference image only"],
                    value=templates[0] if templates else None,
                    info="Templates provide motion guidance. Not required for basic image animation."
                )
                with gr.Row():
                    setup_btn = gr.Button("🔧 Setup Models", variant="secondary", scale=1)
                    load_btn = gr.Button("⚡ Load Model", variant="secondary", scale=1)
                with gr.Row():
                    refresh_btn = gr.Button("🔄 Refresh Templates", variant="secondary", scale=1)
                    generate_btn = gr.Button("🎬 Generate Video", variant="primary", scale=2)

            with gr.Column(scale=1):
                gr.HTML("<h3>🎥 Output</h3>")

                output_video = gr.Video(
                    label="Generated Video",
                    height=400
                )

                status_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    lines=4
                )
        # Mode information
        gr.HTML("""
        <div class="mode-info mode-animate">
            <h4>🏃 Character Animation Mode</h4>
            <p><strong>Features:</strong> Character image + motion template → animated video</p>
            <p><strong>Use case:</strong> Animate static characters with predefined motions</p>
            <p><strong>Based on:</strong> run_animate.py functionality</p>
        </div>
        <div class="mode-info mode-edit">
            <h4>🎬 Video Character Editing Mode</h4>
            <p><strong>Features:</strong> Advanced editing with background blending and occlusion handling</p>
            <p><strong>Use case:</strong> Replace characters in existing videos while preserving backgrounds</p>
            <p><strong>Based on:</strong> run_edit.py functionality</p>
        </div>
        """)
        gr.HTML("""
        <div class="instructions-box">
            <h4>📋 Instructions:</h4>
            <ol>
                <li><strong>First Time Setup:</strong> Click "🔧 Setup Models" to download MIMO (~8GB, one-time)</li>
                <li><strong>Load Model:</strong> Click "⚡ Load Model" to activate the model (required once per session)</li>
                <li><strong>Upload Image:</strong> Upload a character image (clear, front-facing works best)</li>
                <li><strong>Select Mode:</strong> Choose between Animation (simpler) or Editing (advanced)</li>
                <li><strong>Pick Template:</strong> Select a motion template from the dropdown (or refresh to see new ones)</li>
                <li><strong>Generate:</strong> Click "🎬 Generate Video" and wait for processing</li>
            </ol>
            <h4>🎯 Available Templates (13 total):</h4>
            <p><strong>Sports:</strong> basketball_gym, nba_dunk, nba_pass, football</p>
            <p><strong>Action:</strong> kungfu_desert, kungfu_match, parkour_climbing, BruceLee</p>
            <p><strong>Dance:</strong> dance_indoor, irish_dance</p>
            <p><strong>Synthetic:</strong> syn_basketball, syn_dancing, syn_football</p>
            <p><strong>💡 Model Persistence:</strong> Downloaded models persist across page refreshes! Just click "Load Model" to reactivate.</p>
            <p><strong>⚠️ Timing:</strong> First setup takes 5-10 minutes. Model loading takes 30-60 seconds. Generation takes 2-5 minutes per video.</p>
        </div>
        """)
        # Event handlers
        def load_model_only(progress=gr.Progress()):
            """Load models without downloading (if already cached)"""
            try:
                # First check if already loaded
                if mimo_model.is_loaded:
                    return "✅ Model already loaded and ready! You can generate videos now."

                # Re-check cache validity (in case models were just downloaded)
                mimo_model._check_existing_models()

                if not mimo_model._model_cache_valid:
                    return "⚠️ Models not found in cache. Please click '🔧 Setup Models' first to download (~8GB)."

                progress(0.3, desc="Loading models from cache...")
                load_success = mimo_model.load_model(lambda msg: progress(0.7, desc=msg))
                if load_success:
                    progress(1.0, desc="✅ Ready!")
                    return "✅ Model loaded successfully! Ready to generate videos. Upload an image and select a template."
                else:
                    return "❌ Model loading failed. Check the logs for details or try the 'Setup Models' button."
            except Exception as e:
                traceback.print_exc()
                return f"❌ Load failed: {str(e)[:200]}"
        setup_btn.click(
            fn=setup_models,
            outputs=[status_text]
        )

        load_btn.click(
            fn=load_model_only,
            outputs=[status_text]
        )

        refresh_btn.click(
            fn=refresh_templates,
            outputs=[motion_template]
        )

        generate_btn.click(
            fn=generate_video_gradio,
            inputs=[input_image, motion_template, mode],
            outputs=[output_video, status_text]
        )
        # Load the examples (only if the files exist)
        example_files = [
            ["./assets/test_image/sugar.jpg", "sports_basketball_gym", "animate"],
            ["./assets/test_image/avatar.jpg", "dance_indoor_1", "animate"],
            ["./assets/test_image/cartoon1.png", "shorts_kungfu_desert1", "edit"],
            ["./assets/test_image/actorhq_A7S1.png", "syn_basketball_06_13", "edit"],
        ]
        # Filter the examples to only include files that exist
        valid_examples = [ex for ex in example_files if os.path.exists(ex[0])]

        if valid_examples:
            gr.Examples(
                examples=valid_examples,
                inputs=[input_image, motion_template, mode],
                label="🎯 Examples"
            )
        else:
            print("⚠️ No example images found, skipping examples section")

    return demo
if __name__ == "__main__":
    # HF Spaces optimization - no auto-download, to prevent a build timeout
    if os.getenv("SPACE_ID"):
        print("🚀 Running on HuggingFace Spaces")
        print("📦 Models will download on first use to prevent a build timeout")
    else:
        print("💻 Running locally")

    # Launch Gradio
    demo = gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )