import os
import logging
from typing import Any, Optional

import cv2
import numpy as np
import torch
from PIL import Image

import spaces
from diffusers import StableVideoDiffusionPipeline
from transformers import AutoModel, AutoProcessor

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face model id for the MoCha character-replacement components.
MODEL_ID = "Orange-3DV-Team/MoCha"


@spaces.GPU(duration=1800)  # 30 minutes for model loading
def load_model() -> Any:
    """
    Load the MoCha model stack for video character replacement.

    Returns:
        Dict with keys:
            'pipe':            base StableVideoDiffusionPipeline (fallback generator)
            'processor':       MoCha AutoProcessor, or None if unavailable
            'character_model': MoCha AutoModel, or None if unavailable
            'device':          execution device string ('cuda')

    Raises:
        RuntimeError: if the base pipeline cannot be loaded.
    """
    try:
        logger.info(f"Loading MoCha model: {MODEL_ID}")

        # Base Stable Video Diffusion model, used as the fallback generator
        # when the MoCha-specific components are unavailable.
        pipe = StableVideoDiffusionPipeline.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid-xt",
            torch_dtype=torch.float16,
            variant="fp16",
        )

        # MoCha-specific components are optional: frame processing falls
        # back to plain SVD when they cannot be fetched.
        try:
            processor = AutoProcessor.from_pretrained(MODEL_ID)
            character_model = AutoModel.from_pretrained(MODEL_ID)
            character_model.to("cuda")
            logger.info("MoCha character model loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load MoCha-specific components: {e}")
            processor = None
            character_model = None

        # Enable memory efficient attention if available.
        if hasattr(pipe, 'enable_attention_slicing'):
            pipe.enable_attention_slicing()

        # FIX: the original called pipe.to("cuda") and then
        # enable_model_cpu_offload(). The two conflict — offloading manages
        # device placement itself, and recent diffusers versions raise when
        # a pipeline is moved to CUDA first. Prefer offloading (lower peak
        # VRAM) and pin to CUDA only when offloading is unavailable.
        if hasattr(pipe, 'enable_model_cpu_offload'):
            pipe.enable_model_cpu_offload()
        else:
            pipe.to("cuda")

        logger.info("Model loading completed successfully")
        return {
            'pipe': pipe,
            'processor': processor,
            'character_model': character_model,
            'device': 'cuda'
        }

    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise RuntimeError(f"Failed to load MoCha model: {e}")


def _replace_character_in_frame(
    pil_frame: Image.Image,
    reference_image: Image.Image,
    pipe: Any,
    processor: Any,
    character_model: Any,
    device: str,
) -> Image.Image:
    """
    Produce one character-replaced frame.

    Uses the MoCha character model when available; otherwise falls back to
    the base SVD pipeline conditioned on the reference character image.
    Raises whatever the underlying model raises — the caller handles
    per-frame failures.
    """
    if character_model is not None and processor is not None:
        # MoCha path: condition on both the reference character and the frame.
        inputs = processor(
            images=[reference_image, pil_frame],
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            output = character_model.generate(
                **inputs,
                num_frames=1,
                guidance_scale=7.5,
                num_inference_steps=20
            )
        return processor.post_process(output, output_type="pil")[0]

    # Fallback: SVD conditioned on the reference character.
    # FIX: the original passed reference_image both positionally AND as the
    # `image=` keyword, which raises TypeError ("got multiple values for
    # argument 'image'") the first time this path runs. Pass it exactly once.
    # FIX: the original also rebound pil_frame to a 1024x576 resize that was
    # never passed to the pipeline, which corrupted the caller's
    # exception-fallback frame size; that dead statement is removed.
    video_frames = pipe(
        image=reference_image,  # use reference for character guidance
        decode_chunk_size=8,
        num_frames=14,
        guidance_scale=3.0,
        num_inference_steps=25,
        motion_bucket_id=127,
        noise_aug_strength=0.02,
    ).frames[0]
    return video_frames[0]


def _write_frames(frames: list, output_path: str, fps: int, size: tuple) -> None:
    """Encode a list of RGB PIL frames to an mp4 at `output_path` via OpenCV."""
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, size)
    try:
        for frame in frames:
            # PIL (RGB) -> OpenCV (BGR)
            out.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    finally:
        # Always flush/close the writer, even on a partial failure.
        out.release()


@spaces.GPU(duration=600)  # 10 minutes per video processing
def process_video_character_replacement(
    model_dict: dict,
    reference_image: Image.Image,
    video_path: str,
    output_dir: str
) -> Optional[str]:
    """
    Process video with character replacement using the MoCha model.

    Args:
        model_dict: Dictionary returned by load_model().
        reference_image: PIL Image of the target character.
        video_path: Path to the source video.
        output_dir: Directory to save the processed video (created if missing).

    Returns:
        Path to the processed video, or None on failure.
    """
    try:
        pipe = model_dict['pipe']
        processor = model_dict['processor']
        character_model = model_dict['character_model']
        device = model_dict['device']

        logger.info("Starting video character replacement process")

        # Read video metadata up front.
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Could not open video: {video_path}")
            return None
        # FIX: some containers report 0 FPS; a 0-FPS VideoWriter produces an
        # unplayable file, so fall back to a sane default.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        logger.info(f"Video info: {width}x{height}, {fps} FPS, {total_frames} frames")

        # Normalize the reference image to the model's expected input.
        if reference_image.mode != 'RGB':
            reference_image = reference_image.convert('RGB')
        reference_image = reference_image.resize((224, 224))

        processed_frames = []
        frame_count = 0

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # OpenCV decodes BGR; models expect RGB.
                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

                try:
                    processed_frame = _replace_character_in_frame(
                        pil_frame, reference_image, pipe, processor,
                        character_model, device,
                    )
                    # Resize back to source dimensions for the writer.
                    if processed_frame.size != (width, height):
                        processed_frame = processed_frame.resize((width, height))
                    processed_frames.append(processed_frame)
                except Exception as e:
                    logger.warning(f"Error processing frame {frame_count}: {e}")
                    # Keep the original frame (already at source resolution)
                    # if per-frame processing fails.
                    processed_frames.append(pil_frame)

                frame_count += 1
                if frame_count % 10 == 0:
                    logger.info(f"Processed {frame_count}/{total_frames} frames")
        finally:
            # Release the capture even if processing raises.
            cap.release()

        if not processed_frames:
            logger.error("No frames were processed successfully")
            return None

        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, "character_replaced_video.mp4")
        _write_frames(processed_frames, output_path, fps, (width, height))

        logger.info(f"Video processing completed. Output saved to: {output_path}")
        return output_path

    except Exception as e:
        logger.error(f"Error in video processing: {e}")
        return None