Spaces: Build error

```python
import os
import logging
from typing import Optional

import cv2
import numpy as np
import spaces
import torch
from diffusers import StableVideoDiffusionPipeline
from PIL import Image
from transformers import AutoProcessor, AutoModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_ID = "Orange-3DV-Team/MoCha"

# Assumed ZeroGPU decorator (matches `import spaces` above);
# 30 minutes for model loading, duration given in seconds
@spaces.GPU(duration=1800)
def load_model() -> dict:
    """
    Load the MoCha model for video character replacement.

    Returns:
        Dict with the SVD pipeline, optional MoCha processor/model, and device.
    """
    try:
        logger.info(f"Loading MoCha model: {MODEL_ID}")

        # Load the base Stable Video Diffusion model; device placement is
        # handled together with the memory optimizations below
        pipe = StableVideoDiffusionPipeline.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid-xt",
            torch_dtype=torch.float16,
            variant="fp16",
        )

        # Load additional components specific to MoCha
        try:
            processor = AutoProcessor.from_pretrained(MODEL_ID)
            character_model = AutoModel.from_pretrained(MODEL_ID)
            character_model.to("cuda")
            logger.info("MoCha character model loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load MoCha-specific components: {e}")
            processor = None
            character_model = None

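        # NOTE (assumption): AutoProcessor/AutoModel only resolve if the
        # Orange-3DV-Team/MoCha repo ships transformers-compatible config and
        # processor files; a custom architecture may additionally need
        # trust_remote_code=True. If neither holds, the except branch above
        # silently leaves us on the plain SVD fallback path.
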
        # Device placement and memory optimizations. Moving the whole pipeline
        # to CUDA *and* enabling CPU offload is contradictory in diffusers, so
        # prefer offload when available and fall back to a plain .to("cuda")
        if hasattr(pipe, "enable_model_cpu_offload"):
            pipe.enable_model_cpu_offload()
        else:
            pipe.to("cuda")
        if hasattr(pipe, "enable_attention_slicing"):
            pipe.enable_attention_slicing()

| logger.info("Model loading completed successfully") | |
| return { | |
| 'pipe': pipe, | |
| 'processor': processor, | |
| 'character_model': character_model, | |
| 'device': 'cuda' | |
| } | |
| except Exception as e: | |
| logger.error(f"Error loading model: {e}") | |
| raise RuntimeError(f"Failed to load MoCha model: {e}") | |

# Assumed ZeroGPU decorator; 10 minutes per video processing
@spaces.GPU(duration=600)
def process_video_character_replacement(
    model_dict: dict,
    reference_image: Image.Image,
    video_path: str,
    output_dir: str,
) -> Optional[str]:
    """
    Process a video with character replacement using the MoCha model.

    Args:
        model_dict: Dictionary of loaded models, as returned by load_model()
        reference_image: PIL image of the target character
        video_path: Path to the source video
        output_dir: Directory in which to save the processed video

    Returns:
        Path to the processed video, or None if processing failed.
    """
    try:
        pipe = model_dict["pipe"]
        processor = model_dict["processor"]
        character_model = model_dict["character_model"]
        device = model_dict["device"]

        logger.info("Starting video character replacement process")

        # Read video metadata; guard against unreadable files and against
        # missing FPS metadata (cv2 reports 0 for some containers)
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Could not open video: {video_path}")
            return None
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        logger.info(f"Video info: {width}x{height}, {fps} FPS, {total_frames} frames")

        # Prepare the reference image
        if reference_image.mode != "RGB":
            reference_image = reference_image.convert("RGB")
        # Resize the reference image to the expected input size
        reference_image = reference_image.resize((224, 224))

        processed_frames = []
        frame_count = 0

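        # NOTE: each loop iteration below runs a full diffusion call
        # (20-25 denoising steps), so expect on the order of seconds to
        # minutes per frame on a single GPU.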
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Convert BGR (OpenCV) to RGB (PIL)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_frame = Image.fromarray(frame_rgb)

            # Process the frame with character replacement
            try:
                if character_model is not None and processor is not None:
                    # Process with MoCha's character model. These calls
                    # (generate with num_frames/guidance_scale, and
                    # processor.post_process) assume a MoCha-specific API;
                    # they are not part of the generic transformers interface.
                    inputs = processor(
                        images=[reference_image, pil_frame],
                        return_tensors="pt",
                    ).to(device)
                    with torch.no_grad():
                        # Generate the character-replaced frame
                        output = character_model.generate(
                            **inputs,
                            num_frames=1,
                            guidance_scale=7.5,
                            num_inference_steps=20,
                        )
                    processed_frame = processor.post_process(
                        output, output_type="pil"
                    )[0]
                else:
                    # Fallback: Stable Video Diffusion. SVD is image-to-video
                    # and conditions on a single image, so this animates the
                    # reference character rather than editing the source
                    # frame. Note that SVD's __call__ takes min/max
                    # guidance-scale bounds, not a single guidance_scale.
                    svd_input = reference_image.resize((1024, 576))
                    video_frames = pipe(
                        svd_input,
                        decode_chunk_size=8,
                        num_frames=14,
                        num_inference_steps=25,
                        min_guidance_scale=1.0,
                        max_guidance_scale=3.0,
                        motion_bucket_id=127,
                        noise_aug_strength=0.02,
                    ).frames[0]
                    processed_frame = video_frames[0]

                # Resize back to the original dimensions
                if processed_frame.size != (width, height):
                    processed_frame = processed_frame.resize((width, height))
                processed_frames.append(processed_frame)
            except Exception as e:
                logger.warning(f"Error processing frame {frame_count}: {e}")
                # Keep the original frame if processing fails
                processed_frames.append(pil_frame)

            frame_count += 1
            if frame_count % 10 == 0:
                logger.info(f"Processed {frame_count}/{total_frames} frames")

        cap.release()

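        # NOTE: all processed frames are held in memory until the video is
        # written; for long or high-resolution videos, writing each frame to
        # the VideoWriter as it is produced would bound memory use.
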
        # Save the processed video
        if processed_frames:
            output_path = os.path.join(output_dir, "character_replaced_video.mp4")

            # Write the video using OpenCV
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
            for frame in processed_frames:
                # Convert PIL (RGB) back to OpenCV (BGR)
                frame_cv = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
                out.write(frame_cv)
            out.release()

            logger.info(f"Video processing completed. Output saved to: {output_path}")
            return output_path
        else:
            logger.error("No frames were processed successfully")
            return None
    except Exception as e:
        logger.error(f"Error in video processing: {e}")
        return None
```
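
The module above defines only the model-loading and processing functions; a Space also needs an entry point that exposes them. Below is a minimal sketch of that wiring, assuming the Space uses Gradio: the `replace_character` helper, the module-level cache, and all labels are hypothetical, while `gr.Interface`, `gr.Image`, and `gr.Video` are standard Gradio API.

```python
# Hypothetical wiring -- assumes the Space uses Gradio and imports
# load_model / process_video_character_replacement from the module above.
import tempfile

import gradio as gr
from PIL import Image

_model_dict = None  # loaded lazily on the first request


def replace_character(reference_image: Image.Image, video_path: str) -> str:
    # NOTE (assumption): caching the returned dict suits a persistent-GPU
    # Space; on ZeroGPU, weights are usually loaded at import time instead.
    global _model_dict
    if _model_dict is None:
        _model_dict = load_model()
    output_dir = tempfile.mkdtemp()
    result = process_video_character_replacement(
        _model_dict, reference_image, video_path, output_dir
    )
    if result is None:
        raise gr.Error("Video processing failed; see the Space logs.")
    return result


demo = gr.Interface(
    fn=replace_character,
    inputs=[
        gr.Image(type="pil", label="Reference character"),
        gr.Video(label="Source video"),  # Gradio passes a filepath string
    ],
    outputs=gr.Video(label="Character-replaced video"),
    title="MoCha video character replacement",
)

if __name__ == "__main__":
    demo.launch()
```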