""" VideoMaMa Inference Wrapper Handles video matting with mask conditioning """ import os import torch # CRITICAL: Redirect cache to temporary storage os.environ['TORCH_HOME'] = '/tmp/torch_cache' os.environ['HUB_DIR'] = '/tmp/torch_hub' os.environ['TMPDIR'] = '/tmp' torch.hub.set_dir('/tmp/torch_hub') import os import torch import numpy as np from PIL import Image from pathlib import Path from typing import List import tqdm from pipeline_svd_mask import VideoInferencePipeline def videomama(pipeline, frames_np, mask_frames_np): """ Run VideoMaMa inference on video frames with mask conditioning Args: pipeline: VideoInferencePipeline instance frames_np: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames mask_frames_np: List of numpy arrays, [(H,W)]*n, uint8 grayscale masks Returns: output_frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB outputs """ # Convert numpy arrays to PIL Images frames_pil = [Image.fromarray(f) for f in frames_np] mask_frames_pil = [Image.fromarray(m, mode='L') for m in mask_frames_np] # Resize to model input size target_width, target_height = 1024, 576 frames_resized = [f.resize((target_width, target_height), Image.Resampling.BILINEAR) for f in frames_pil] masks_resized = [m.resize((target_width, target_height), Image.Resampling.BILINEAR) for m in mask_frames_pil] # Run inference print(f"Running VideoMaMa inference on {len(frames_resized)} frames...") output_frames_pil = pipeline.run( cond_frames=frames_resized, mask_frames=masks_resized, seed=42, mask_cond_mode="vae" ) # Resize back to original resolution original_size = frames_pil[0].size output_frames_resized = [f.resize(original_size, Image.Resampling.BILINEAR) for f in output_frames_pil] # Convert back to numpy arrays output_frames_np = [np.array(f) for f in output_frames_resized] return output_frames_np def load_videomama_pipeline(device="cuda"): """ Load VideoMaMa pipeline with pretrained weights Args: device: Device to run on Returns: VideoInferencePipeline instance """ # Use relative paths for Hugging Face Space # Checkpoints should be downloaded via download_checkpoints.sh base_model_path = os.path.join("checkpoints", "stable-video-diffusion-img2vid-xt") unet_checkpoint_path = os.path.join("checkpoints", "videomama") # Check if checkpoints exist if not os.path.exists(base_model_path): raise FileNotFoundError( f"SVD base model not found at {base_model_path}. " "Please run download_checkpoints.sh first." ) if not os.path.exists(unet_checkpoint_path): raise FileNotFoundError( f"VideoMaMa checkpoint not found at {unet_checkpoint_path}. " "Please run download_checkpoints.sh first." ) print(f"Loading VideoMaMa pipeline from {unet_checkpoint_path}...") pipeline = VideoInferencePipeline( base_model_path=base_model_path, unet_checkpoint_path=unet_checkpoint_path, weight_dtype=torch.float16, device=device ) print("VideoMaMa pipeline loaded successfully!") return pipeline