File size: 3,654 Bytes
d33e75e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
VideoMaMa Inference Wrapper - Hugging Face Space Version
Handles video matting with mask conditioning
"""

import sys
import os
from pathlib import Path

# Make this file's directory and its parent directory importable. They are
# appended, so they are searched after the existing sys.path entries.
# Presumably this is what lets the local `pipeline_svd_mask` import resolve.
sys.path.extend(
    str(search_dir)
    for search_dir in (Path(__file__).parent, Path(__file__).parent.parent)
)

import torch
import numpy as np
from PIL import Image
from typing import List

from pipeline_svd_mask import VideoInferencePipeline


def videomama(pipeline, frames_np, mask_frames_np, *, target_size=(1024, 576),
              seed=42, mask_cond_mode="vae"):
    """
    Run VideoMaMa inference on video frames with mask conditioning.

    Frames and masks are resized to ``target_size`` before inference. The
    pipeline outputs are then resized back to the size of the first input frame.

    Args:
        pipeline: VideoInferencePipeline instance; only its ``run`` method is used
        frames_np: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
        mask_frames_np: List of numpy arrays, [(H,W)]*n, uint8 grayscale masks
            (boolean masks are also accepted and mapped to 0/255)
        target_size: (width, height) the frames and masks are resized to
            before inference
        seed: Seed forwarded to ``pipeline.run``
        mask_cond_mode: Mask conditioning mode forwarded to ``pipeline.run``

    Returns:
        output_frames: List of numpy arrays, one per image returned by
        ``pipeline.run``, each resized to the first input frame's size

    Raises:
        ValueError: If no frames are given, or if the number of masks does not
            match the number of frames.
    """
    # Validate before doing any work. An empty batch used to run the whole
    # pipeline and then crash on ``frames_pil[0]``. A count mismatch would feed
    # misaligned conditioning to the model. Use ``len`` rather than truthiness
    # so a stacked (n, H, W, 3) ndarray works as well as a list.
    num_frames = len(frames_np)
    if num_frames == 0:
        raise ValueError("videomama() requires at least one frame")
    if len(mask_frames_np) != num_frames:
        raise ValueError(
            f"Expected one mask per frame, got {num_frames} frames "
            f"and {len(mask_frames_np)} masks"
        )

    # Convert numpy arrays to PIL Images.
    # Frames: .convert("RGB") guarantees 3 channels even for RGBA or grayscale
    # input.
    # Masks: let Pillow infer the mode from the dtype, then convert to "L".
    # Forcing fromarray(..., mode="L") reinterprets the raw buffer, so a boolean
    # mask would come out as 0/1 instead of 0/255. The `mode` argument is also
    # deprecated in recent Pillow.
    frames_pil = [Image.fromarray(f).convert("RGB") for f in frames_np]
    mask_frames_pil = [Image.fromarray(m).convert("L") for m in mask_frames_np]

    # Resize to model input size (PIL expects (width, height))
    model_size = tuple(target_size)
    frames_resized = [f.resize(model_size, Image.Resampling.BILINEAR)
                      for f in frames_pil]
    masks_resized = [m.resize(model_size, Image.Resampling.BILINEAR)
                     for m in mask_frames_pil]

    # Run inference
    print(f"Running VideoMaMa inference on {len(frames_resized)} frames...")
    output_frames_pil = pipeline.run(
        cond_frames=frames_resized,
        mask_frames=masks_resized,
        seed=seed,
        mask_cond_mode=mask_cond_mode,
    )

    # Resize back to the original resolution. It is taken from the first
    # frame, so all frames are assumed to share the same size.
    original_size = frames_pil[0].size
    output_frames_resized = [f.resize(original_size, Image.Resampling.BILINEAR)
                             for f in output_frames_pil]

    # Convert back to numpy arrays
    return [np.array(f) for f in output_frames_resized]


def load_videomama_pipeline(base_model_path=None, unet_checkpoint_path=None, device="cuda",
                            *, weight_dtype=torch.float16):
    """
    Load the VideoMaMa pipeline with pretrained weights.

    Args:
        base_model_path: Path to the SVD base model. If None, uses
            "checkpoints/stable-video-diffusion-img2vid-xt", which is resolved
            relative to the current working directory.
        unet_checkpoint_path: Path to the VideoMaMa UNet checkpoint. If None,
            uses "checkpoints/videomama", which is resolved relative to the
            current working directory.
        device: Device to run on; forwarded to VideoInferencePipeline.
        weight_dtype: Weight dtype forwarded to VideoInferencePipeline.
            Defaults to torch.float16, the value previously hard-coded. Pass
            e.g. torch.float32 when half precision is not wanted for the
            chosen ``device`` (for example on CPU).

    Returns:
        VideoInferencePipeline instance

    Raises:
        FileNotFoundError: If the base model path or the UNet checkpoint path
            does not exist.
    """
    # Use provided paths or defaults
    if base_model_path is None:
        base_model_path = "checkpoints/stable-video-diffusion-img2vid-xt"
    if unet_checkpoint_path is None:
        unet_checkpoint_path = "checkpoints/videomama"

    # Fail fast with an actionable message before constructing the pipeline
    if not os.path.exists(base_model_path):
        raise FileNotFoundError(
            f"SVD base model not found at {base_model_path}. "
            f"Please ensure models are downloaded correctly."
        )
    if not os.path.exists(unet_checkpoint_path):
        raise FileNotFoundError(
            f"VideoMaMa checkpoint not found at {unet_checkpoint_path}. "
            f"Please upload your VideoMaMa model to Hugging Face Hub and update the download logic."
        )

    print("Loading VideoMaMa pipeline...")
    print(f"  Base model: {base_model_path}")
    print(f"  UNet checkpoint: {unet_checkpoint_path}")

    pipeline = VideoInferencePipeline(
        base_model_path=base_model_path,
        unet_checkpoint_path=unet_checkpoint_path,
        weight_dtype=weight_dtype,
        device=device,
    )

    print("VideoMaMa pipeline loaded successfully!")
    return pipeline