|
|
""" |
|
|
VideoMaMa Inference Wrapper |
|
|
Handles video matting with mask conditioning |
|
|
""" |
|
|
|
|
|
import os |
|
|
import torch |
|
|
|
|
|
|
|
|
# Redirect torch's download/cache locations to /tmp — presumably because the
# default home directory is not writable in the deployment environment
# (container / serverless runtime); TODO confirm.
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
# NOTE(review): 'HUB_DIR' is not a documented torch environment variable
# (torch.hub honors TORCH_HOME / TORCH_HUB); the explicit set_dir() call
# below is what actually takes effect — verify whether this line is needed.
os.environ['HUB_DIR'] = '/tmp/torch_hub'
# Point tempfile scratch space at /tmp; tempfile reads TMPDIR lazily on
# first use, so setting it here should still apply — TODO confirm no
# temp files were created before this point.
os.environ['TMPDIR'] = '/tmp'
# Explicitly direct torch.hub at the writable cache directory.
torch.hub.set_dir('/tmp/torch_hub')
|
|
|
|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from pathlib import Path |
|
|
from typing import List |
|
|
import tqdm |
|
|
|
|
|
from pipeline_svd_mask import VideoInferencePipeline |
|
|
|
|
|
|
|
|
def videomama(pipeline, frames_np, mask_frames_np, target_size=(1024, 576),
              seed=42, mask_cond_mode="vae"):
    """
    Run VideoMaMa inference on video frames with mask conditioning.

    Frames and masks are resized to the model's working resolution, run
    through the pipeline, and the outputs are resized back to the input
    resolution.

    Args:
        pipeline: VideoInferencePipeline instance (must expose ``run``)
        frames_np: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
        mask_frames_np: List of numpy arrays, [(H,W)]*n, uint8 grayscale masks
        target_size: (width, height) resolution the model runs at.
            Defaults to (1024, 576), SVD's native resolution.
        seed: Random seed forwarded to the pipeline for reproducibility.
        mask_cond_mode: Mask-conditioning mode forwarded to the pipeline.

    Returns:
        output_frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB outputs
            at the original input resolution. Empty list for empty input.
    """
    # Nothing to do for an empty clip; also avoids indexing frames_pil[0]
    # below, which would raise IndexError.
    if not frames_np:
        return []

    frames_pil = [Image.fromarray(f) for f in frames_np]
    mask_frames_pil = [Image.fromarray(m, mode='L') for m in mask_frames_np]

    # Resize to the model's working resolution.
    # NOTE(review): bilinear resampling of the masks produces soft gray
    # edges rather than a hard binary boundary — presumably intended for
    # the "vae" conditioning mode; confirm NEAREST isn't required.
    target_width, target_height = target_size
    frames_resized = [f.resize((target_width, target_height), Image.Resampling.BILINEAR)
                      for f in frames_pil]
    masks_resized = [m.resize((target_width, target_height), Image.Resampling.BILINEAR)
                     for m in mask_frames_pil]

    print(f"Running VideoMaMa inference on {len(frames_resized)} frames...")
    output_frames_pil = pipeline.run(
        cond_frames=frames_resized,
        mask_frames=masks_resized,
        seed=seed,
        mask_cond_mode=mask_cond_mode
    )

    # Map outputs back to the caller's resolution (PIL .size is (W, H)).
    original_size = frames_pil[0].size
    output_frames_resized = [f.resize(original_size, Image.Resampling.BILINEAR)
                             for f in output_frames_pil]

    output_frames_np = [np.array(f) for f in output_frames_resized]

    return output_frames_np
|
|
|
|
|
|
|
|
def load_videomama_pipeline(device="cuda", checkpoints_dir="checkpoints"):
    """
    Load VideoMaMa pipeline with pretrained weights.

    Args:
        device: Device to run on (e.g. "cuda" or "cpu")
        checkpoints_dir: Root directory containing the downloaded model
            checkpoints. Defaults to "checkpoints" (relative to the CWD).

    Returns:
        VideoInferencePipeline instance

    Raises:
        FileNotFoundError: If either the SVD base model or the VideoMaMa
            UNet checkpoint directory is missing.
    """
    base_model_path = os.path.join(checkpoints_dir, "stable-video-diffusion-img2vid-xt")
    unet_checkpoint_path = os.path.join(checkpoints_dir, "videomama")

    # Fail fast with an actionable message before attempting the (slow)
    # pipeline construction.
    if not os.path.exists(base_model_path):
        raise FileNotFoundError(
            f"SVD base model not found at {base_model_path}. "
            "Please run download_checkpoints.sh first."
        )

    if not os.path.exists(unet_checkpoint_path):
        raise FileNotFoundError(
            f"VideoMaMa checkpoint not found at {unet_checkpoint_path}. "
            "Please run download_checkpoints.sh first."
        )

    print(f"Loading VideoMaMa pipeline from {unet_checkpoint_path}...")

    # fp16 weights halve memory use; assumes the target device supports
    # half precision — TODO confirm for CPU inference.
    pipeline = VideoInferencePipeline(
        base_model_path=base_model_path,
        unet_checkpoint_path=unet_checkpoint_path,
        weight_dtype=torch.float16,
        device=device
    )

    print("VideoMaMa pipeline loaded successfully!")

    return pipeline
|
|
|