# VideoMaMa / videomama_wrapper.py
"""
VideoMaMa Inference Wrapper
Handles video matting with mask conditioning
"""
import os
import torch
# CRITICAL: Redirect cache to temporary storage
# (On a Hugging Face Space the working directory is read-only at runtime;
# only /tmp is writable, so all torch/hub caches are pointed there.)
# NOTE(review): TORCH_HOME is set *after* `import torch`; torch.hub appears to
# read it lazily so downloads still land in /tmp — confirm if caching misbehaves.
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
os.environ['HUB_DIR'] = '/tmp/torch_hub'
os.environ['TMPDIR'] = '/tmp'
# Explicitly pin the hub dir as well, in case HUB_DIR alone is not honored.
torch.hub.set_dir('/tmp/torch_hub')
# NOTE(review): `os` and `torch` are imported a second time below — harmless
# (imports are idempotent) but redundant; kept to avoid touching import order.
import os
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from typing import List
import tqdm
# Project-local pipeline implementing the SVD-based matting model.
from pipeline_svd_mask import VideoInferencePipeline
def videomama(pipeline, frames_np, mask_frames_np):
    """
    Run VideoMaMa inference on video frames with mask conditioning.

    Args:
        pipeline: VideoInferencePipeline instance
        frames_np: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
        mask_frames_np: List of numpy arrays, [(H,W)]*n, uint8 grayscale masks

    Returns:
        output_frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB outputs
    """
    # Fixed resolution expected by the SVD-based model (width, height).
    model_size = (1024, 576)

    def _as_resized_pil(arrays, mode=None):
        # numpy -> PIL, then bilinear resize to the model resolution.
        return [
            Image.fromarray(a, mode=mode).resize(model_size, Image.Resampling.BILINEAR)
            for a in arrays
        ]

    frames_resized = _as_resized_pil(frames_np)
    masks_resized = _as_resized_pil(mask_frames_np, mode='L')

    print(f"Running VideoMaMa inference on {len(frames_resized)} frames...")
    predictions = pipeline.run(
        cond_frames=frames_resized,
        mask_frames=masks_resized,
        seed=42,
        mask_cond_mode="vae",
    )

    # Restore the caller's original resolution; PIL sizes are (width, height).
    src_h, src_w = frames_np[0].shape[:2]
    native_size = (src_w, src_h)
    return [
        np.array(p.resize(native_size, Image.Resampling.BILINEAR))
        for p in predictions
    ]
def load_videomama_pipeline(device="cuda", checkpoint_dir="checkpoints"):
    """
    Load VideoMaMa pipeline with pretrained weights.

    Args:
        device: Device to run on (e.g. "cuda" or "cpu")
        checkpoint_dir: Root directory holding the downloaded checkpoints.
            Defaults to the relative "checkpoints" dir used on the HF Space,
            so existing callers are unaffected.

    Returns:
        VideoInferencePipeline instance

    Raises:
        FileNotFoundError: If the SVD base model or the VideoMaMa checkpoint
            is missing under ``checkpoint_dir``.
    """
    # Checkpoints should be downloaded via download_checkpoints.sh.
    base_model_path = os.path.join(checkpoint_dir, "stable-video-diffusion-img2vid-xt")
    unet_checkpoint_path = os.path.join(checkpoint_dir, "videomama")

    # Fail fast with an actionable message instead of a cryptic load error.
    for label, path in (
        ("SVD base model", base_model_path),
        ("VideoMaMa checkpoint", unet_checkpoint_path),
    ):
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"{label} not found at {path}. "
                "Please run download_checkpoints.sh first."
            )

    print(f"Loading VideoMaMa pipeline from {unet_checkpoint_path}...")
    pipeline = VideoInferencePipeline(
        base_model_path=base_model_path,
        unet_checkpoint_path=unet_checkpoint_path,
        weight_dtype=torch.float16,  # half precision to fit the Space's GPU memory
        device=device
    )
    print("VideoMaMa pipeline loaded successfully!")
    return pipeline