File size: 3,412 Bytes
d33e75e
 
 
 
 
a571565
 
 
 
 
 
 
 
 
0b67fec
d33e75e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b67fec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d33e75e
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
VideoMaMa Inference Wrapper
Handles video matting with mask conditioning
"""

import os
import torch

# CRITICAL: Redirect cache to temporary storage
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
os.environ['HUB_DIR'] = '/tmp/torch_hub'
os.environ['TMPDIR'] = '/tmp'
torch.hub.set_dir('/tmp/torch_hub')

import os
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from typing import List
import tqdm

from pipeline_svd_mask import VideoInferencePipeline


def videomama(pipeline, frames_np, mask_frames_np, *,
              target_size=(1024, 576), seed=42, mask_cond_mode="vae"):
    """
    Run VideoMaMa inference on video frames with mask conditioning.

    Args:
        pipeline: VideoInferencePipeline instance.
        frames_np: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames.
        mask_frames_np: List of numpy arrays, [(H,W)]*n, uint8 grayscale masks.
        target_size: (width, height) the model runs at. Frames and masks are
            resized to this size before inference and outputs are resized
            back afterwards. Default matches the SVD-XT resolution 1024x576.
        seed: Random seed forwarded to the pipeline for reproducibility.
        mask_cond_mode: Mask-conditioning mode forwarded to the pipeline.

    Returns:
        output_frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB outputs
        at the original input resolution. Empty input yields an empty list.
    """
    # Nothing to do for an empty clip; also avoids indexing frames_pil[0]
    # below, which would raise IndexError.
    if not frames_np:
        return []

    # Convert numpy arrays to PIL Images (masks as single-channel 'L').
    frames_pil = [Image.fromarray(f) for f in frames_np]
    mask_frames_pil = [Image.fromarray(m, mode='L') for m in mask_frames_np]

    # Resize to the model's working resolution. NOTE(review): masks are
    # bilinear-resampled too, producing intermediate gray values — kept as
    # in the original; confirm the pipeline expects soft mask edges.
    target_width, target_height = target_size
    frames_resized = [f.resize((target_width, target_height), Image.Resampling.BILINEAR)
                      for f in frames_pil]
    masks_resized = [m.resize((target_width, target_height), Image.Resampling.BILINEAR)
                     for m in mask_frames_pil]

    # Run inference
    print(f"Running VideoMaMa inference on {len(frames_resized)} frames...")
    output_frames_pil = pipeline.run(
        cond_frames=frames_resized,
        mask_frames=masks_resized,
        seed=seed,
        mask_cond_mode=mask_cond_mode
    )

    # Resize back to the original resolution (all inputs share frame 0's size).
    original_size = frames_pil[0].size
    output_frames_resized = [f.resize(original_size, Image.Resampling.BILINEAR)
                             for f in output_frames_pil]

    # Convert back to numpy arrays
    output_frames_np = [np.array(f) for f in output_frames_resized]

    return output_frames_np


def load_videomama_pipeline(device="cuda"):
    """
    Load the VideoMaMa pipeline with pretrained weights.

    Args:
        device: Device identifier the pipeline should run on (default "cuda").

    Returns:
        A ready-to-use VideoInferencePipeline instance (fp16 weights).

    Raises:
        FileNotFoundError: If either checkpoint directory is missing.
    """
    # Relative paths so the app works inside a Hugging Face Space; both
    # directories are populated beforehand by download_checkpoints.sh.
    base_model_path = os.path.join("checkpoints", "stable-video-diffusion-img2vid-xt")
    unet_checkpoint_path = os.path.join("checkpoints", "videomama")

    # Fail fast with an actionable message if any checkpoint is absent.
    required_checkpoints = (
        (base_model_path, "SVD base model"),
        (unet_checkpoint_path, "VideoMaMa checkpoint"),
    )
    for ckpt_path, ckpt_label in required_checkpoints:
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(
                f"{ckpt_label} not found at {ckpt_path}. "
                "Please run download_checkpoints.sh first."
            )

    print(f"Loading VideoMaMa pipeline from {unet_checkpoint_path}...")

    pipeline = VideoInferencePipeline(
        base_model_path=base_model_path,
        unet_checkpoint_path=unet_checkpoint_path,
        weight_dtype=torch.float16,
        device=device
    )

    print("VideoMaMa pipeline loaded successfully!")

    return pipeline