# my-demo-694 / models.py
# Deployed as a Gradio app with multiple files (commit a34d96e, verified)
import os
import torch
import spaces
from diffusers import StableVideoDiffusionPipeline
from transformers import AutoProcessor, AutoModel
import cv2
import numpy as np
from PIL import Image
from typing import Tuple, Optional, Any
import logging
# Configure logging: module-level logger so messages carry this module's name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Hugging Face hub id of the MoCha character-replacement checkpoint
# (loaded best-effort in load_model; the app falls back to SVD without it).
MODEL_ID = "Orange-3DV-Team/MoCha"
@spaces.GPU(duration=1800)  # 30 minutes for model loading
def load_model() -> Any:
    """
    Load the model components for video character replacement.

    Loads the Stable Video Diffusion base pipeline in fp16, then attempts a
    best-effort load of the MoCha-specific processor and character model.
    Missing MoCha components are tolerated: the returned dict carries None
    for them and callers fall back to SVD-only processing.

    Returns:
        dict with keys 'pipe' (SVD pipeline), 'processor' (AutoProcessor or
        None), 'character_model' (AutoModel or None), 'device' ('cuda').

    Raises:
        RuntimeError: if the base pipeline itself cannot be loaded.
    """
    try:
        logger.info("Loading MoCha model: %s", MODEL_ID)
        # Load the base Stable Video Diffusion model
        pipe = StableVideoDiffusionPipeline.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid-xt",
            torch_dtype=torch.float16,
            variant="fp16",
        )

        # Load additional components specific to MoCha (best-effort).
        try:
            processor = AutoProcessor.from_pretrained(MODEL_ID)
            character_model = AutoModel.from_pretrained(MODEL_ID)
            character_model.to("cuda")
            logger.info("MoCha character model loaded successfully")
        except Exception as e:
            logger.warning("Could not load MoCha-specific components: %s", e)
            processor = None
            character_model = None

        # BUG FIX: the original called pipe.to("cuda") unconditionally and
        # then enable_model_cpu_offload(). Offload manages device placement
        # itself; moving the whole pipeline to GPU first defeats it (and
        # newer diffusers versions raise on the combination). Use offload
        # when available, plain .to("cuda") otherwise.
        if hasattr(pipe, 'enable_model_cpu_offload'):
            pipe.enable_model_cpu_offload()
        else:
            pipe.to("cuda")
        # Enable memory efficient attention if available
        if hasattr(pipe, 'enable_attention_slicing'):
            pipe.enable_attention_slicing()

        logger.info("Model loading completed successfully")
        return {
            'pipe': pipe,
            'processor': processor,
            'character_model': character_model,
            'device': 'cuda',
        }
    except Exception as e:
        logger.error("Error loading model: %s", e)
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Failed to load MoCha model: {e}") from e
@spaces.GPU(duration=600)  # 10 minutes per video processing
def process_video_character_replacement(
    model_dict: dict,
    reference_image: Image.Image,
    video_path: str,
    output_dir: str
) -> Optional[str]:
    """
    Process video with character replacement using the MoCha model.

    Decodes the source video frame by frame, runs each frame through the
    MoCha character model when available (otherwise falls back to the SVD
    pipeline driven by the reference image), and re-encodes the result.
    Frames that fail to process are passed through unchanged.

    Args:
        model_dict: Dictionary from load_model() with 'pipe', 'processor',
            'character_model' and 'device' entries.
        reference_image: PIL Image of the target character.
        video_path: Path to the source video.
        output_dir: Directory to save the processed video into.

    Returns:
        Path to the processed video, or None on failure.
    """
    try:
        pipe = model_dict['pipe']
        processor = model_dict['processor']
        character_model = model_dict['character_model']
        device = model_dict['device']
        logger.info("Starting video character replacement process")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            # BUG FIX: the original never checked isOpened(); an unreadable
            # path silently produced a zero-frame "success" loop.
            logger.error("Could not open video: %s", video_path)
            return None
        try:
            # BUG FIX: CAP_PROP_FPS can be 0 (or NaN) for some containers;
            # feeding that to VideoWriter yields a broken file. Fall back
            # to a sane default frame rate.
            raw_fps = cap.get(cv2.CAP_PROP_FPS)
            fps = int(raw_fps) if raw_fps and raw_fps > 0 else 30
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            logger.info(f"Video info: {width}x{height}, {fps} FPS, {total_frames} frames")

            # Prepare reference image: RGB, resized to the model input size.
            if reference_image.mode != 'RGB':
                reference_image = reference_image.convert('RGB')
            reference_image = reference_image.resize((224, 224))

            processed_frames = []
            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes BGR; PIL expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_frame = Image.fromarray(frame_rgb)
                try:
                    if character_model is not None and processor is not None:
                        # Process with MoCha's character model.
                        inputs = processor(
                            images=[reference_image, pil_frame],
                            return_tensors="pt"
                        ).to(device)
                        with torch.no_grad():
                            output = character_model.generate(
                                **inputs,
                                num_frames=1,
                                guidance_scale=7.5,
                                num_inference_steps=20
                            )
                        processed_frame = processor.post_process(
                            output, output_type="pil"
                        )[0]
                    else:
                        # Fallback: stable video diffusion driven by the
                        # reference image for character guidance.
                        # BUG FIX: the original passed the image both
                        # positionally AND as image=..., which raises
                        # TypeError on every frame, so the fallback path
                        # always degraded to pass-through frames. It also
                        # resized pil_frame to (1024, 576) beforehand, so
                        # the except branch then wrote a wrongly-sized
                        # frame to the video writer.
                        video_frames = pipe(
                            reference_image,
                            decode_chunk_size=8,
                            num_frames=14,
                            guidance_scale=3.0,
                            num_inference_steps=25,
                            motion_bucket_id=127,
                            noise_aug_strength=0.02
                        ).frames[0]
                        processed_frame = video_frames[0]
                    # Resize back to the original video dimensions.
                    if processed_frame.size != (width, height):
                        processed_frame = processed_frame.resize((width, height))
                    processed_frames.append(processed_frame)
                except Exception as e:
                    logger.warning(f"Error processing frame {frame_count}: {e}")
                    # Keep the original (correctly sized) frame on failure.
                    processed_frames.append(pil_frame)
                frame_count += 1
                if frame_count % 10 == 0:
                    logger.info(f"Processed {frame_count}/{total_frames} frames")
        finally:
            # BUG FIX: release the capture even when an exception escapes.
            cap.release()

        if not processed_frames:
            logger.error("No frames were processed successfully")
            return None

        # Save processed video with OpenCV.
        output_path = os.path.join(output_dir, "character_replaced_video.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        try:
            for frame in processed_frames:
                # Convert PIL (RGB) back to OpenCV (BGR) for encoding.
                out.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
        finally:
            # BUG FIX: release the writer even when a write fails, so the
            # container is finalized.
            out.release()
        logger.info(f"Video processing completed. Output saved to: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error in video processing: {e}")
        return None