RAM2118 commited on
Commit
8d777b1
·
verified ·
1 Parent(s): 2255092

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ venv/
25
+ env/
26
+ ENV/
27
+ .venv
28
+
29
+ # IDE
30
+ .vscode/
31
+ .idea/
32
+ *.swp
33
+ *.swo
34
+ *~
35
+
36
+ # Gradio
37
+ flagged/
38
+
39
+ # Temporary files
40
+ *.tmp
41
+ temp/
42
+ temp_*/
43
+ *.log
44
+
45
+ # Model checkpoints (download separately)
46
+ checkpoints/*.pt
47
+ checkpoints/*.pth
48
+ checkpoints/*.safetensors
49
+ checkpoints/*.bin
50
+
51
+ # Videos
52
+ samples/*.mp4
53
+ samples/*.avi
54
+ samples/*.mov
55
+ *.mp4
56
+ *.avi
57
+ *.mov
58
+
59
+ # OS
60
+ .DS_Store
61
+ Thumbs.db
62
+ *.bak
63
+
64
+ # Jupyter
65
+ .ipynb_checkpoints/
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -1,12 +1,91 @@
1
  ---
2
- title: VideoMaMa Custom
3
- emoji: 💻
4
- colorFrom: red
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.6.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: VideoMaMa - Video Matting with Mask Guidance
3
+ emoji: 🎬
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.0.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
+ # 🎬 VideoMaMa: Video Matting with Mask Guidance
14
+
15
+ An interactive demo for high-quality video matting using sparse mask guidance. This demo combines SAM2 for automatic object tracking with our VideoMaMa model for generating alpha mattes.
16
+
17
+ ## 🌟 Features
18
+
19
+ - **Single-Click Object Selection**: Simply click on the object you want to extract in the first frame
20
+ - **Automatic Tracking**: SAM2 automatically tracks your selected object through all frames
21
+ - **High-Quality Matting**: VideoMaMa generates smooth, temporally-consistent alpha mattes
22
+ - **Flexible Input**: Upload your own video or try our provided samples
23
+ - **Customizable**: Adjust augmentation settings for different scenarios
24
+
25
+ ## 🚀 How to Use
26
+
27
+ 1. **Upload a video** or **select from samples**
28
+ 2. **Click on the object** you want to extract in the first frame (displayed in the interface)
29
+ 3. Optionally adjust **augmentation settings** in the advanced options
30
+ 4. Click **"Generate Matting"** and wait for processing
31
+ 5. View your results: output video, comparison images, and mask track
32
+
33
+
34
+ ## 🔧 Installation (Local Setup)
35
+
36
+ If you want to run this demo locally:
37
+
38
+ ```bash
39
+ # Install dependencies
40
+ pip install -r requirements.txt
41
+
42
+ # Add sample videos to samples/ directory (optional)
43
+
44
+ # Run the demo
45
+ python app.py
46
+ ```
47
+
48
+ ## 🎯 Tips for Best Results
49
+
50
+ - **Click Precisely**: Click on the center of the object you want to extract
51
+ - **Clear Objects**: Works best with distinct foreground objects
52
+ - **Video Length**: For faster processing, use shorter videos (< 5 seconds)
53
+ - **Augmentations**:
54
+ - Use "polygon" for cleaner geometric masks
55
+ - Enable temporal augmentation for challenging videos
56
+ - Try "bounding box" for very simple selections
57
+
58
+ ## 📚 Technical Details
59
+
60
+ ### Model Architecture
61
+ - **Base Model**: Stable Video Diffusion (SVD-XT)
62
+ - **Conditioning**: RGB frames + VAE-encoded masks
63
+ - **UNet**: Fine-tuned with additional mask conditioning channels
64
+ - **Processing**: Chunked inference (16 frames per chunk)
65
+
66
+ ### SAM2 Integration
67
+ - Uses SAM2 video predictor for mask tracking
68
+ - Propagates mask from single click point through entire video
69
+ - Generates temporally consistent segmentation masks
70
+
71
+ ## 🤝 Contributing
72
+
73
+ If you encounter issues or have suggestions:
74
+ 1. Check that all model checkpoints are correctly placed
75
+ 2. Ensure your GPU has sufficient VRAM
76
+ 3. Try reducing video length or resolution for testing
77
+
78
+
79
+ ## 🙏 Acknowledgments
80
+
81
+ - **SAM2**: Meta AI's Segment Anything 2
82
+ - **Stable Video Diffusion**: Stability AI's video generation model
83
+ - **Gradio**: For the amazing UI framework
84
+
85
+ ## 📧 Contact
86
+
87
+ For questions or issues, please open an issue on our GitHub repository.
88
+
89
+ ---
90
+
91
+ **Note**: This demo is for research purposes. Processing times may vary based on video length and available compute resources.
app.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VideoMaMa Gradio Demo
3
+ Interactive video matting with SAM2 mask tracking
4
+ """
5
+
6
+ import sys
7
+ sys.path.append("../")
8
+ sys.path.append("../../")
9
+
10
+ import os
11
+ import json
12
+ import time
13
+ import cv2
14
+ import torch
15
+ import numpy as np
16
+ import gradio as gr
17
+ from PIL import Image
18
+ from pathlib import Path
19
+
20
+ from sam2_wrapper import load_sam2_tracker
21
+ from videomama_wrapper import load_videomama_pipeline, videomama
22
+ from tools.painter import mask_painter, point_painter
23
+
24
+ import warnings
25
+ warnings.filterwarnings("ignore")
26
+
27
+ # Global models
28
+ sam2_tracker = None
29
+ videomama_pipeline = None
30
+
31
+ # Constants
32
+ MASK_COLOR = 3
33
+ MASK_ALPHA = 0.7
34
+ CONTOUR_COLOR = 1
35
+ CONTOUR_WIDTH = 5
36
+ POINT_COLOR_POS = 8 # Positive points - orange
37
+ POINT_COLOR_NEG = 1 # Negative points - red
38
+ POINT_ALPHA = 0.9
39
+ POINT_RADIUS = 15
40
+
41
+ def initialize_models():
42
+ """Initialize SAM2 and VideoMaMa models"""
43
+ global sam2_tracker, videomama_pipeline
44
+
45
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
+ print(f"Using device: {device}")
47
+
48
+ # Load SAM2
49
+ sam2_tracker = load_sam2_tracker(device=device)
50
+
51
+ # Load VideoMaMa
52
+ videomama_pipeline = load_videomama_pipeline(device=device)
53
+
54
+ print("All models initialized successfully!")
55
+
56
+
57
+ def extract_frames_from_video(video_path, max_frames=24):
58
+ """
59
+ Extract frames from video file
60
+
61
+ Args:
62
+ video_path: Path to video file
63
+ max_frames: Maximum number of frames to extract (default: 24)
64
+
65
+ Returns:
66
+ frames: List of numpy arrays (H,W,3), uint8 RGB
67
+ adjusted_fps: Adjusted FPS for output video to maintain normal playback speed
68
+ """
69
+ cap = cv2.VideoCapture(video_path)
70
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
71
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
72
+
73
+ # Read all frames first
74
+ all_frames = []
75
+ while cap.isOpened():
76
+ ret, frame = cap.read()
77
+ if not ret:
78
+ break
79
+ # Convert BGR to RGB
80
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
81
+ all_frames.append(frame_rgb)
82
+
83
+ cap.release()
84
+
85
+ # If video has more frames than max_frames, randomly sample
86
+ if len(all_frames) > max_frames:
87
+ print(f"Video has {len(all_frames)} frames, randomly sampling {max_frames} frames...")
88
+ # Sort indices to maintain temporal order
89
+ sampled_indices = sorted(np.random.choice(len(all_frames), max_frames, replace=False))
90
+ frames = [all_frames[i] for i in sampled_indices]
91
+ print(f"Sampled frame indices: {sampled_indices}")
92
+
93
+ # Adjust FPS to maintain normal playback speed
94
+ # If we sampled N frames from M total frames, adjust FPS proportionally
95
+ adjusted_fps = original_fps * (len(frames) / len(all_frames))
96
+ else:
97
+ frames = all_frames
98
+ adjusted_fps = original_fps
99
+ print(f"Video has {len(frames)} frames (≤ {max_frames}), using all frames")
100
+
101
+ print(f"Using {len(frames)} frames from video (Original FPS: {original_fps:.2f}, Adjusted FPS: {adjusted_fps:.2f})")
102
+
103
+ return frames, adjusted_fps
104
+
105
+
106
+ def get_prompt(click_state, click_input):
107
+ """
108
+ Convert click input to prompt format
109
+
110
+ Args:
111
+ click_state: [[points], [labels]]
112
+ click_input: JSON string "[[x, y, label]]"
113
+
114
+ Returns:
115
+ Updated click_state
116
+ """
117
+ inputs = json.loads(click_input)
118
+ points = click_state[0]
119
+ labels = click_state[1]
120
+
121
+ for input_item in inputs:
122
+ points.append(input_item[:2])
123
+ labels.append(input_item[2])
124
+
125
+ click_state[0] = points
126
+ click_state[1] = labels
127
+
128
+ return click_state
129
+
130
+
131
+ def load_video(video_input, video_state, num_frames):
132
+ """
133
+ Load video and extract first frame for mask generation
134
+ """
135
+ # Clean up old output files if they exist
136
+ if video_state is not None and "output_paths" in video_state:
137
+ cleanup_old_videos(video_state["output_paths"])
138
+
139
+ if video_input is None:
140
+ return video_state, None, \
141
+ gr.update(visible=False), gr.update(visible=False), \
142
+ gr.update(visible=False), gr.update(visible=False)
143
+
144
+ # Extract frames with user-specified number
145
+ frames, fps = extract_frames_from_video(video_input, max_frames=num_frames)
146
+
147
+ if len(frames) == 0:
148
+ return video_state, None, \
149
+ gr.update(visible=False), gr.update(visible=False), \
150
+ gr.update(visible=False), gr.update(visible=False)
151
+
152
+ # Initialize video state
153
+ video_state = {
154
+ "frames": frames,
155
+ "fps": fps,
156
+ "first_frame_mask": None,
157
+ "masks": None,
158
+ }
159
+
160
+ first_frame_pil = Image.fromarray(frames[0])
161
+
162
+ return video_state, first_frame_pil, \
163
+ gr.update(visible=True), gr.update(visible=True), \
164
+ gr.update(visible=True), gr.update(visible=False)
165
+
166
+
167
+ def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
168
+ """
169
+ Add click and update mask on first frame
170
+
171
+ Args:
172
+ video_state: Dictionary with video data
173
+ point_prompt: "Positive" or "Negative"
174
+ click_state: [[points], [labels]]
175
+ evt: Gradio SelectData event with click coordinates
176
+ """
177
+ if video_state is None or "frames" not in video_state:
178
+ return None, video_state, click_state
179
+
180
+ # Add new click
181
+ x, y = evt.index[0], evt.index[1]
182
+ label = 1 if point_prompt == "Positive" else 0
183
+
184
+ click_state[0].append([x, y])
185
+ click_state[1].append(label)
186
+
187
+ print(f"Added {point_prompt} click at ({x}, {y}). Total clicks: {len(click_state[0])}")
188
+
189
+ # Generate mask with SAM2
190
+ first_frame = video_state["frames"][0]
191
+ mask = sam2_tracker.get_first_frame_mask(
192
+ frame=first_frame,
193
+ points=click_state[0],
194
+ labels=click_state[1]
195
+ )
196
+
197
+ # Store mask in video state
198
+ video_state["first_frame_mask"] = mask
199
+
200
+ # Visualize mask and points
201
+ painted_image = mask_painter(
202
+ first_frame.copy(),
203
+ mask,
204
+ MASK_COLOR,
205
+ MASK_ALPHA,
206
+ CONTOUR_COLOR,
207
+ CONTOUR_WIDTH
208
+ )
209
+
210
+ # Paint positive points
211
+ positive_points = np.array([click_state[0][i] for i in range(len(click_state[0]))
212
+ if click_state[1][i] == 1])
213
+ if len(positive_points) > 0:
214
+ painted_image = point_painter(
215
+ painted_image,
216
+ positive_points,
217
+ POINT_COLOR_POS,
218
+ POINT_ALPHA,
219
+ POINT_RADIUS,
220
+ CONTOUR_COLOR,
221
+ CONTOUR_WIDTH
222
+ )
223
+
224
+ # Paint negative points
225
+ negative_points = np.array([click_state[0][i] for i in range(len(click_state[0]))
226
+ if click_state[1][i] == 0])
227
+ if len(negative_points) > 0:
228
+ painted_image = point_painter(
229
+ painted_image,
230
+ negative_points,
231
+ POINT_COLOR_NEG,
232
+ POINT_ALPHA,
233
+ POINT_RADIUS,
234
+ CONTOUR_COLOR,
235
+ CONTOUR_WIDTH
236
+ )
237
+
238
+ painted_pil = Image.fromarray(painted_image)
239
+
240
+ return painted_pil, video_state, click_state
241
+
242
+
243
+ def clear_clicks(video_state, click_state):
244
+ """Clear all clicks and reset to original first frame"""
245
+ click_state = [[], []]
246
+
247
+ if video_state is not None and "frames" in video_state:
248
+ first_frame = video_state["frames"][0]
249
+ video_state["first_frame_mask"] = None
250
+ return Image.fromarray(first_frame), video_state, click_state
251
+
252
+ return None, video_state, click_state
253
+
254
+
255
+ def propagate_masks(video_state, click_state):
256
+ """
257
+ Propagate first frame mask through entire video using SAM2
258
+ """
259
+ if video_state is None or "frames" not in video_state:
260
+ return video_state, "No video loaded", gr.update(visible=False)
261
+
262
+ if len(click_state[0]) == 0:
263
+ return video_state, "⚠️ Please add at least one point first", gr.update(visible=False)
264
+
265
+ frames = video_state["frames"]
266
+
267
+ # Track through video
268
+ print(f"Tracking object through {len(frames)} frames...")
269
+ masks = sam2_tracker.track_video(
270
+ frames=frames,
271
+ points=click_state[0],
272
+ labels=click_state[1]
273
+ )
274
+
275
+ video_state["masks"] = masks
276
+
277
+ status_msg = f"✓ Generated {len(masks)} masks. Ready to run VideoMaMa!"
278
+
279
+ return video_state, status_msg, gr.update(visible=True)
280
+
281
+
282
+ def run_videomama_with_sam2(video_state, click_state):
283
+ """
284
+ Run SAM2 propagation and VideoMaMa inference together
285
+ """
286
+ if video_state is None or "frames" not in video_state:
287
+ return video_state, None, None, None, "⚠️ No video loaded"
288
+
289
+ if len(click_state[0]) == 0:
290
+ return video_state, None, None, None, "⚠️ Please add at least one point first"
291
+
292
+ frames = video_state["frames"]
293
+
294
+ # Step 1: Track through video with SAM2
295
+ print(f"🎯 Tracking object through {len(frames)} frames with SAM2...")
296
+ masks = sam2_tracker.track_video(
297
+ frames=frames,
298
+ points=click_state[0],
299
+ labels=click_state[1]
300
+ )
301
+
302
+ video_state["masks"] = masks
303
+ print(f"✓ Generated {len(masks)} masks")
304
+
305
+ # Step 2: Run VideoMaMa
306
+ print(f"🎨 Running VideoMaMa on {len(frames)} frames...")
307
+ output_frames = videomama(videomama_pipeline, frames, masks)
308
+
309
+ # Save output videos
310
+ output_dir = Path("outputs")
311
+ output_dir.mkdir(exist_ok=True)
312
+
313
+ timestamp = int(time.time())
314
+ output_video_path = output_dir / f"output_{timestamp}.mp4"
315
+ mask_video_path = output_dir / f"masks_{timestamp}.mp4"
316
+ greenscreen_path = output_dir / f"greenscreen_{timestamp}.mp4"
317
+
318
+ # Save matting result
319
+ save_video(output_frames, output_video_path, video_state["fps"])
320
+
321
+ # Save mask video (for visualization)
322
+ mask_frames_rgb = [np.stack([m, m, m], axis=-1) for m in masks]
323
+ save_video(mask_frames_rgb, mask_video_path, video_state["fps"])
324
+
325
+ # Create greenscreen composite: RGB * VideoMaMa_alpha + green * (1 - VideoMaMa_alpha)
326
+ # VideoMaMa output_frames already contain the alpha matte result
327
+ greenscreen_frames = []
328
+ for orig_frame, output_frame in zip(frames, output_frames):
329
+ # Extract alpha matte from VideoMaMa output
330
+ # VideoMaMa outputs matted foreground, we use its intensity as alpha
331
+ gray = cv2.cvtColor(output_frame, cv2.COLOR_RGB2GRAY)
332
+ alpha = np.clip(gray.astype(np.float32) / 255.0, 0, 1)
333
+ alpha_3ch = np.stack([alpha, alpha, alpha], axis=-1)
334
+
335
+ # Create green background
336
+ green_bg = np.zeros_like(orig_frame)
337
+ green_bg[:, :] = [156, 251, 165] # Green screen color
338
+
339
+ # Composite: original_RGB * alpha + green * (1 - alpha)
340
+ composite = (orig_frame.astype(np.float32) * alpha_3ch +
341
+ green_bg.astype(np.float32) * (1 - alpha_3ch)).astype(np.uint8)
342
+ greenscreen_frames.append(composite)
343
+
344
+ save_video(greenscreen_frames, greenscreen_path, video_state["fps"])
345
+
346
+ status_msg = f"✓ Complete! Generated {len(output_frames)} frames."
347
+
348
+ # Store paths for cleanup later
349
+ video_state["output_paths"] = [str(output_video_path), str(mask_video_path), str(greenscreen_path)]
350
+
351
+ return video_state, str(output_video_path), str(mask_video_path), str(greenscreen_path), status_msg
352
+
353
+
354
+ def save_video(frames, output_path, fps):
355
+ """Save frames as video file"""
356
+ if len(frames) == 0:
357
+ return
358
+
359
+ height, width = frames[0].shape[:2]
360
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
361
+ out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
362
+
363
+ for frame in frames:
364
+ if len(frame.shape) == 2: # Grayscale
365
+ frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
366
+ else: # RGB
367
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
368
+ out.write(frame)
369
+
370
+ out.release()
371
+ print(f"Saved video to {output_path}")
372
+
373
+
374
+ def cleanup_old_videos(video_paths):
375
+ """Remove old output videos to save storage space"""
376
+ if video_paths is None:
377
+ return
378
+
379
+ for path in video_paths:
380
+ try:
381
+ if os.path.exists(path):
382
+ os.remove(path)
383
+ print(f"Cleaned up: {path}")
384
+ except Exception as e:
385
+ print(f"Failed to remove {path}: {e}")
386
+
387
+
388
+ def cleanup_old_outputs(max_age_minutes=30):
389
+ """
390
+ Remove output files older than max_age_minutes to prevent storage overflow
391
+ This runs periodically to clean up abandoned files
392
+ """
393
+ output_dir = Path("outputs")
394
+ if not output_dir.exists():
395
+ return
396
+
397
+ current_time = time.time()
398
+ max_age_seconds = max_age_minutes * 60
399
+
400
+ for file_path in output_dir.glob("*.mp4"):
401
+ try:
402
+ file_age = current_time - file_path.stat().st_mtime
403
+ if file_age > max_age_seconds:
404
+ file_path.unlink()
405
+ print(f"Cleaned up old file: {file_path} (age: {file_age/60:.1f} minutes)")
406
+ except Exception as e:
407
+ print(f"Failed to clean up {file_path}: {e}")
408
+
409
+
410
+ def restart():
411
+ """Reset all states"""
412
+ return None, [[], []], None, \
413
+ gr.update(visible=False), gr.update(visible=False), \
414
+ gr.update(visible=False), None, None, None, ""
415
+
416
+
417
+ # CSS styling
418
+ custom_css = """
419
+ .gradio-container {width: 90% !important; margin: 0 auto;}
420
+ .title-text {text-align: center; font-size: 48px; font-weight: bold;
421
+ background: linear-gradient(to right, #8b5cf6, #10b981);
422
+ -webkit-background-clip: text; -webkit-text-fill-color: transparent;}
423
+ .description-text {text-align: center; font-size: 18px; margin: 20px 0;}
424
+ button {border-radius: 8px !important;}
425
+ .green_button {background-color: #10b981 !important; color: white !important;}
426
+ .red_button {background-color: #ef4444 !important; color: white !important;}
427
+ .run_matting_button {
428
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 50%, #f093fb 100%) !important;
429
+ color: white !important;
430
+ font-weight: bold !important;
431
+ font-size: 18px !important;
432
+ padding: 20px !important;
433
+ box-shadow: 0 4px 15px 0 rgba(102, 126, 234, 0.75) !important;
434
+ border: none !important;
435
+ }
436
+ .run_matting_button:hover {
437
+ background: linear-gradient(135deg, #764ba2 0%, #667eea 50%, #f093fb 100%) !important;
438
+ box-shadow: 0 6px 20px 0 rgba(102, 126, 234, 0.9) !important;
439
+ transform: translateY(-2px) !important;
440
+ }
441
+ """
442
+
443
+ # Build Gradio interface
444
+ with gr.Blocks(css=custom_css, title="VideoMaMa Demo") as demo:
445
+ gr.HTML('<div class="title-text">VideoMaMa Interactive Demo</div>')
446
+ gr.Markdown(
447
+ '<div class="description-text">🎬 Upload a video → 🖱️ Click to mark object → ✅ Generate masks → 🎨 Run VideoMaMa</div>'
448
+ )
449
+ gr.Markdown(
450
+ '<div style="text-align: center; color: #6b7280; font-size: 14px; margin-top: -10px;">Note: VideoMaMa processes the selected number of frames (1-50). Longer videos will be randomly sampled.</div>'
451
+ )
452
+
453
+ # State variables
454
+ video_state = gr.State(None)
455
+ click_state = gr.State([[], []]) # [[points], [labels]]
456
+
457
+ with gr.Row():
458
+ with gr.Column(scale=1):
459
+ gr.Markdown("### Step 1: Upload Video")
460
+ video_input = gr.Video(label="Input Video")
461
+ num_frames_slider = gr.Slider(
462
+ minimum=1,
463
+ maximum=50,
464
+ value=24,
465
+ step=1,
466
+ label="Number of Frames to Process",
467
+ info="VideoMaMa will process only this many frames. More frames = better quality but slower."
468
+ )
469
+ load_button = gr.Button("📁 Load Video", variant="primary")
470
+
471
+ gr.Markdown("### Step 2: Mark Object")
472
+ point_prompt = gr.Radio(
473
+ choices=["Positive", "Negative"],
474
+ value="Positive",
475
+ label="Click Type",
476
+ info="Positive: object, Negative: background",
477
+ visible=False
478
+ )
479
+ clear_button = gr.Button("🗑️ Clear Clicks", visible=False)
480
+
481
+ with gr.Column(scale=1):
482
+ gr.Markdown("### First Frame (Click to Add Points)")
483
+ first_frame_display = gr.Image(
484
+ label="First Frame",
485
+ type="pil",
486
+ interactive=True
487
+ )
488
+ run_button = gr.Button("🚀 Run Matting", visible=False, elem_classes="run_matting_button", size="lg")
489
+
490
+ status_text = gr.Textbox(label="Status", value="", interactive=False, visible=False)
491
+
492
+ gr.Markdown("### Outputs")
493
+ with gr.Row():
494
+ with gr.Column():
495
+ output_video = gr.Video(label="Matting Result", autoplay=True)
496
+ with gr.Column():
497
+ greenscreen_video = gr.Video(label="Greenscreen Composite", autoplay=True)
498
+ with gr.Column():
499
+ mask_video = gr.Video(label="Mask Track", autoplay=True)
500
+
501
+ # Event handlers
502
+ load_button.click(
503
+ fn=load_video,
504
+ inputs=[video_input, video_state, num_frames_slider],
505
+ outputs=[video_state, first_frame_display,
506
+ point_prompt, clear_button, run_button, status_text]
507
+ )
508
+
509
+ first_frame_display.select(
510
+ fn=sam_refine,
511
+ inputs=[video_state, point_prompt, click_state],
512
+ outputs=[first_frame_display, video_state, click_state]
513
+ )
514
+
515
+ clear_button.click(
516
+ fn=clear_clicks,
517
+ inputs=[video_state, click_state],
518
+ outputs=[first_frame_display, video_state, click_state]
519
+ )
520
+
521
+ run_button.click(
522
+ fn=run_videomama_with_sam2,
523
+ inputs=[video_state, click_state],
524
+ outputs=[video_state, output_video, mask_video, greenscreen_video, status_text]
525
+ )
526
+
527
+ video_input.change(
528
+ fn=restart,
529
+ inputs=[],
530
+ outputs=[video_state, click_state, first_frame_display,
531
+ point_prompt, clear_button, run_button,
532
+ output_video, mask_video, greenscreen_video, status_text]
533
+ )
534
+
535
+ # Examples
536
+ gr.Markdown("---\n### 📦 Example Videos")
537
+ example_dir = Path("samples")
538
+ if example_dir.exists():
539
+ examples = [str(p) for p in sorted(example_dir.glob("*.mp4"))]
540
+ if examples:
541
+ gr.Examples(examples=examples, inputs=[video_input])
542
+
543
+
544
+ if __name__ == "__main__":
545
+ print("=" * 60)
546
+ print("VideoMaMa Interactive Demo")
547
+ print("=" * 60)
548
+
549
+ # Clean up old output files on startup
550
+ cleanup_old_outputs(max_age_minutes=30)
551
+
552
+ # Initialize models
553
+ initialize_models()
554
+
555
+ # Launch demo
556
+ demo.queue()
557
+ demo.launch(
558
+ server_name="127.0.0.1",
559
+ server_port=7860,
560
+ share=True
561
+ )
download_checkpoints.sh ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Download model checkpoints for VideoMaMa demo
3
+
4
+ set -e
5
+
6
+ echo "🔽 Downloading model checkpoints for VideoMaMa demo..."
7
+ echo ""
8
+
9
+ # Create checkpoints directory
10
+ echo "Creating checkpoints directory..."
11
+ mkdir -p checkpoints
12
+ echo "✓ Directory created"
13
+ echo ""
14
+
15
+ # Download SAM2 checkpoint
16
+ echo "Downloading SAM2 checkpoint..."
17
+ echo "URL: https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"
18
+ echo "This may take a few minutes (file size: ~900MB)..."
19
+
20
+ if command -v wget &> /dev/null; then
21
+ wget https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt \
22
+ -O checkpoints/sam2/sam2_hiera_large.pt
23
+ elif command -v curl &> /dev/null; then
24
+ curl -L https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt \
25
+ -o checkpoints/sam2/sam2_hiera_large.pt
26
+ else
27
+ echo "❌ Error: Neither wget nor curl is available. Please install one of them."
28
+ exit 1
29
+ fi
30
+
31
+ echo "✓ SAM2 checkpoint downloaded successfully"
32
+ echo ""
33
+
34
+ # Check if VideoMaMa checkpoint exists
35
+ echo "Checking VideoMaMa checkpoint..."
36
+ if [ -d "checkpoints/VideoMaMa" ]; then
37
+ if [ -f "checkpoints/VideoMaMa/config.json" ] && \
38
+ { [ -f "checkpoints/VideoMaMa/diffusion_pytorch_model.safetensors" ] || \
39
+ [ -f "checkpoints/VideoMaMa/diffusion_pytorch_model.bin" ]; }; then
40
+ echo "✓ VideoMaMa checkpoint already exists"
41
+ else
42
+ echo "⚠️ VideoMaMa checkpoint directory exists but is incomplete"
43
+ echo " Please add the following files to checkpoints/VideoMaMa/:"
44
+ echo " - config.json"
45
+ echo " - diffusion_pytorch_model.safetensors (or .bin)"
46
+ fi
47
+ else
48
+ echo "⚠️ VideoMaMa checkpoint not found"
49
+ echo ""
50
+ echo "📝 Manual step required:"
51
+ echo " 1. Create directory: checkpoints/VideoMaMa/"
52
+ echo " 2. Copy your trained VideoMaMa checkpoint files:"
53
+ echo " - config.json"
54
+ echo " - diffusion_pytorch_model.safetensors (or .bin)"
55
+ echo ""
56
+ echo " Example:"
57
+ echo " mkdir -p checkpoints/VideoMaMa"
58
+ echo " cp /path/to/your/checkpoint/* checkpoints/VideoMaMa/"
59
+ fi
60
+
61
+ echo ""
62
+ echo "="*70
63
+ echo "✨ Checkpoint download complete!"
64
+ echo "="*70
65
+ echo ""
66
+ echo "Next steps:"
67
+ echo "1. Verify checkpoints are in place:"
68
+ echo " python test_setup.py"
69
+ echo ""
70
+ echo "2. (Optional) Add sample videos:"
71
+ echo " mkdir -p samples"
72
+ echo " cp your_sample.mp4 samples/"
73
+ echo ""
74
+ echo "3. Test locally:"
75
+ echo " python app.py"
76
+ echo ""
77
+ echo "4. Deploy to Hugging Face Space"
78
+ echo ""
enhanced_ui.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ def create_enhanced_ui():
7
+ with gr.Blocks() as demo:
8
+ gr.Markdown("# VideoMaMa - Enhanced Segmentation")
9
+
10
+ with gr.Row():
11
+ with gr.Column():
12
+ video_input = gr.Video(label="Upload Video")
13
+
14
+ # Segmentation method selector
15
+ seg_method = gr.Radio(
16
+ ["Click Points", "Brush/Draw", "Text Prompt"],
17
+ label="Segmentation Method",
18
+ value="Click Points"
19
+ )
20
+
21
+ # Text prompt input (shown when Text Prompt selected)
22
+ text_prompt = gr.Textbox(
23
+ label="Text Prompt",
24
+ placeholder="e.g., 'person', 'piano', 'cat'",
25
+ visible=False
26
+ )
27
+
28
+ # Image editor with multiple tools
29
+ image_editor = gr.Image(
30
+ label="Select/Draw Object",
31
+ tool="sketch", # Brush tool
32
+ brush_radius=15,
33
+ brush_color="#FF0000"
34
+ )
35
+
36
+ process_btn = gr.Button("Process Video", variant="primary")
37
+
38
+ with gr.Column():
39
+ output_video = gr.Video(label="Result")
40
+ mask_preview = gr.Image(label="Mask Preview")
41
+
42
+ # Toggle text input visibility based on method
43
+ def update_visibility(method):
44
+ return gr.update(visible=(method == "Text Prompt"))
45
+
46
+ seg_method.change(
47
+ update_visibility,
48
+ inputs=[seg_method],
49
+ outputs=[text_prompt]
50
+ )
51
+
52
+ process_btn.click(
53
+ process_video_enhanced,
54
+ inputs=[video_input, seg_method, text_prompt, image_editor],
55
+ outputs=[output_video, mask_preview]
56
+ )
57
+
58
+ return demo
59
+
60
+ def process_video_enhanced(video, method, text_prompt, image_data):
61
+ if method == "Text Prompt":
62
+ # Use Grounding DINO + SAM2
63
+ points = text_to_points(text_prompt, video)
64
+ elif method == "Brush/Draw":
65
+ # Use drawn mask directly
66
+ mask = image_data_to_mask(image_data)
67
+ else:
68
+ # Use click points (original method)
69
+ points = extract_points_from_clicks(image_data)
70
+
71
+ # Process with VideoMaMa (existing pipeline)
72
+ return videomama_pipeline.process(video, points)
pipeline_svd_mask.py ADDED
@@ -0,0 +1,1038 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pipeline_svd_masked.py
2
+
3
+ import inspect
4
+ from dataclasses import dataclass
5
+ from typing import Callable, Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import PIL.Image
9
+ import torch
10
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
11
+
12
+ from diffusers.image_processor import PipelineImageInput
13
+ from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
14
+ from diffusers.schedulers import EulerDiscreteScheduler
15
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+ from diffusers.video_processor import VideoProcessor
18
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
19
+
20
+ # Import necessary helpers from the original SVD pipeline
21
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
22
+ _append_dims,
23
+ retrieve_timesteps,
24
+ _resize_with_antialiasing,
25
+ )
26
+ import torch.nn.functional as F
27
+ from einops import rearrange
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+ EXAMPLE_DOC_STRING = """
33
+ Examples:
34
+ ```py
35
+ >>> from pipeline_svd_masked import StableVideoDiffusionPipelineWithMask
36
+ >>> from diffusers.utils import load_image, export_to_video
37
+
38
+ >>> # Load your fine-tuned UNet, VAE, etc.
39
+ >>> pipe = StableVideoDiffusionPipelineWithMask.from_pretrained(
40
+ ... "path/to/your/finetuned_model", torch_dtype=torch.float16, variant="fp16"
41
+ ... )
42
+ >>> pipe.to("cuda")
43
+
44
+ >>> # Load the conditioning image and the mask
45
+ >>> image = load_image("path/to/your/conditioning_image.png").resize((1024, 576))
46
+ >>> mask = load_image("path/to/your/mask_image.png").resize((1024, 576))
47
+
48
+ >>> # Generate frames
49
+ >>> frames = pipe(
50
+ ... image=image,
51
+ ... mask_image=mask,
52
+ ... num_frames=25,
53
+ ... decode_chunk_size=8
54
+ ... ).frames[0]
55
+
56
+ >>> export_to_video(frames, "generated_video.mp4", fps=7)
57
+ ```
58
+ """
59
+
60
+
61
+ @dataclass
62
+ class StableVideoDiffusionPipelineOutput(BaseOutput):
63
+ r"""
64
+ Output class for the custom Stable Video Diffusion pipeline.
65
+ Args:
66
+ frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]):
67
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape
68
+ `(batch_size, num_frames, height, width, num_channels)`.
69
+ """
70
+ frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]
71
+
72
+
73
+ class StableVideoDiffusionPipelineWithMask(DiffusionPipeline):
74
+ r"""
75
+ A custom pipeline based on Stable Video Diffusion that accepts an additional mask for conditioning.
76
+ This pipeline is designed to work with a UNet fine-tuned to accept 12 input channels
77
+ (4 for noise, 4 for VAE-encoded condition image, 4 for VAE-encoded mask).
78
+ """
79
+
80
+ model_cpu_offload_seq = "image_encoder->unet->vae"
81
+ _callback_tensor_inputs = ["latents"]
82
+
83
+ def __init__(
84
+ self,
85
+ vae: AutoencoderKLTemporalDecoder,
86
+ image_encoder: CLIPVisionModelWithProjection,
87
+ unet: UNetSpatioTemporalConditionModel,
88
+ scheduler: EulerDiscreteScheduler,
89
+ feature_extractor: CLIPImageProcessor,
90
+ ):
91
+ super().__init__()
92
+
93
+ self.register_modules(
94
+ vae=vae,
95
+ image_encoder=image_encoder,
96
+ unet=unet,
97
+ scheduler=scheduler,
98
+ feature_extractor=feature_extractor,
99
+ )
100
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
101
+ self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)
102
+
103
+ def _encode_image(
104
+ self,
105
+ image: PipelineImageInput,
106
+ device: Union[str, torch.device],
107
+ num_videos_per_prompt: int,
108
+ ) -> torch.Tensor:
109
+ dtype = next(self.image_encoder.parameters()).dtype
110
+
111
+ if not isinstance(image, torch.Tensor):
112
+ image = self.video_processor.pil_to_numpy(image)
113
+ image = self.video_processor.numpy_to_pt(image)
114
+
115
+ image = image * 2.0 - 1.0
116
+ image = _resize_with_antialiasing(image, (224, 224))
117
+ image = (image + 1.0) / 2.0
118
+
119
+ image = self.feature_extractor(
120
+ images=image,
121
+ do_normalize=True,
122
+ do_center_crop=False,
123
+ do_resize=False,
124
+ do_rescale=False,
125
+ return_tensors="pt",
126
+ ).pixel_values
127
+
128
+ image = image.to(device=device, dtype=dtype)
129
+ image_embeddings = self.image_encoder(image).image_embeds
130
+ image_embeddings = image_embeddings.unsqueeze(1)
131
+
132
+ bs_embed, seq_len, _ = image_embeddings.shape
133
+ image_embeddings = torch.zeros_like(image_embeddings)
134
+
135
+ return image_embeddings
136
+
137
+ def _encode_vae_image(
138
+ self,
139
+ image: torch.Tensor,
140
+ device: Union[str, torch.device],
141
+ num_videos_per_prompt: int,
142
+ ):
143
+ image = image.to(device=device, dtype=torch.float16)
144
+ image_latents = self.vae.encode(image).latent_dist.sample()
145
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
146
+ return image_latents
147
+
148
+ def _get_add_time_ids(
149
+ self,
150
+ fps: int,
151
+ motion_bucket_id: int,
152
+ noise_aug_strength: float,
153
+ dtype: torch.dtype,
154
+ batch_size: int,
155
+ num_videos_per_prompt: int,
156
+ ):
157
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
158
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
159
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
160
+ if expected_add_embed_dim != passed_add_embed_dim:
161
+ raise ValueError(
162
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created."
163
+ )
164
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
165
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
166
+ return add_time_ids
167
+
168
+ def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
169
+ latents = latents.flatten(0, 1).to(dtype=torch.float16)
170
+ latents = 1 / self.vae.config.scaling_factor * latents
171
+ frames = []
172
+ for i in range(0, latents.shape[0], decode_chunk_size):
173
+ num_frames_in = latents[i: i + decode_chunk_size].shape[0]
174
+ frame = self.vae.decode(latents[i: i + decode_chunk_size], num_frames=num_frames_in).sample
175
+ frames.append(frame)
176
+ frames = torch.cat(frames, dim=0)
177
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
178
+ frames = frames.float()
179
+ return frames
180
+
181
+ def check_inputs(self, image, height, width):
182
+ if (
183
+ not isinstance(image, torch.Tensor)
184
+ and not isinstance(image, PIL.Image.Image)
185
+ and not isinstance(image, list)
186
+ ):
187
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
188
+ if height % 8 != 0 or width % 8 != 0:
189
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
190
+
191
+ def prepare_latents(
192
+ self,
193
+ batch_size: int,
194
+ num_frames: int,
195
+ height: int,
196
+ width: int,
197
+ dtype: torch.dtype,
198
+ device: Union[str, torch.device],
199
+ generator: torch.Generator,
200
+ latents: Optional[torch.Tensor] = None,
201
+ initial_latents: Optional[torch.Tensor] = None,
202
+ denoising_strength: float = 1.0,
203
+ timestep: Optional[torch.Tensor] = None,
204
+ ):
205
+ num_channels_latents = self.unet.config.out_channels
206
+ shape = (
207
+ batch_size,
208
+ num_frames,
209
+ num_channels_latents,
210
+ height // self.vae_scale_factor,
211
+ width // self.vae_scale_factor,
212
+ )
213
+
214
+ if initial_latents is not None:
215
+ # Noise is added to the initial latents
216
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
217
+ # Get the initial latents at the given timestep
218
+ latents = self.scheduler.add_noise(initial_latents, noise, timestep)
219
+ else:
220
+ # Standard pure noise generation
221
+ if latents is None:
222
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
223
+ else:
224
+ latents = latents.to(device)
225
+ # Scale the initial noise by the standard deviation required by the scheduler
226
+ latents = latents * self.scheduler.init_noise_sigma
227
+
228
+ return latents
229
+
230
+ def _encode_video_vae(
231
+ self,
232
+ video_frames: torch.Tensor, # Expects (B, F, C, H, W)
233
+ device: Union[str, torch.device],
234
+ ):
235
+ video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
236
+ batch_size, num_frames = video_frames.shape[:2]
237
+
238
+ # Reshape for VAE encoding
239
+ video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:]) # (B*F, C, H, W)
240
+ latents = self.vae.encode(video_frames_reshaped).latent_dist.sample() # (B*F, C_latent, H_latent, W_latent)
241
+
242
+ # Reshape back to video format
243
+ latents = latents.reshape(batch_size, num_frames, *latents.shape[1:]) # (B, F, C_latent, H_latent, W_latent)
244
+
245
+ return latents
246
+
247
+ @torch.no_grad()
248
+ def __call__(
249
+ self,
250
+ image: Union[List[PIL.Image.Image], torch.Tensor],
251
+ mask_image: Union[List[PIL.Image.Image], torch.Tensor],
252
+ alpha_matte_image: Optional[Union[List[PIL.Image.Image], torch.Tensor]] = None,
253
+ denoising_strength: float = 0.7,
254
+ height: int = 576,
255
+ width: int = 1024,
256
+ num_frames: Optional[int] = None,
257
+ num_inference_steps: int = 30,
258
+ sigmas: Optional[List[float]] = None,
259
+ fps: int = 7,
260
+ motion_bucket_id: int = 127,
261
+ noise_aug_strength: float = 0.02,
262
+ decode_chunk_size: Optional[int] = None,
263
+ num_videos_per_prompt: Optional[int] = 1,
264
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
265
+ latents: Optional[torch.Tensor] = None,
266
+ output_type: Optional[str] = "pil",
267
+ return_dict: bool = True,
268
+ mask_noise_strength: float = 0.0,
269
+ ):
270
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
271
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
272
+
273
+ if num_frames is None:
274
+ if isinstance(image, list):
275
+ num_frames = len(image)
276
+ else:
277
+ num_frames = self.unet.config.num_frames
278
+
279
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
280
+
281
+ self.check_inputs(image, height, width)
282
+ self.check_inputs(mask_image, height, width)
283
+ if alpha_matte_image:
284
+ self.check_inputs(alpha_matte_image, height, width)
285
+
286
+ batch_size = 1
287
+ device = self._execution_device
288
+ dtype = self.unet.dtype
289
+
290
+ image_for_clip = image[0] if isinstance(image, list) else image[0]
291
+ image_embeddings = self._encode_image(image_for_clip, device, num_videos_per_prompt)
292
+
293
+ fps = fps - 1
294
+
295
+ image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device).unsqueeze(0)
296
+ mask_tensor = self.video_processor.preprocess(mask_image, height=height, width=width).to(device).unsqueeze(0)
297
+
298
+ noise = randn_tensor(image_tensor.shape, generator=generator, device=device, dtype=dtype)
299
+ image_tensor = image_tensor + noise_aug_strength * noise
300
+
301
+ conditional_latents = self._encode_video_vae(image_tensor, device)
302
+ conditional_latents = conditional_latents / self.vae.config.scaling_factor
303
+
304
+ if self.unet.config.in_channels == 12:
305
+ mask_latents = self._encode_video_vae(mask_tensor, device)
306
+ mask_latents = mask_latents / self.vae.config.scaling_factor
307
+ elif self.unet.config.in_channels == 9:
308
+ mask_tensor_gray = mask_tensor.mean(dim=2, keepdim=True)
309
+ binarized_mask = (mask_tensor_gray > 0.0).to(dtype)
310
+ b, f, c, h, w = binarized_mask.shape
311
+ binarized_mask_reshaped = binarized_mask.reshape(b * f, c, h, w)
312
+ target_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)
313
+ interpolated_mask = F.interpolate(
314
+ binarized_mask_reshaped,
315
+ size=target_size,
316
+ mode='nearest',
317
+ )
318
+ mask_latents = interpolated_mask.reshape(b, f, *interpolated_mask.shape[1:])
319
+ else:
320
+ raise ValueError(f"Unsupported number of UNet input channels: {self.unet.config.in_channels}.")
321
+
322
+ if mask_noise_strength > 0.0:
323
+ mask_noise = randn_tensor(mask_latents.shape, generator=generator, device=device, dtype=dtype)
324
+ mask_latents = mask_latents + mask_noise_strength * mask_noise
325
+
326
+ added_time_ids = self._get_add_time_ids(
327
+ fps, motion_bucket_id, noise_aug_strength, image_embeddings.dtype, batch_size, num_videos_per_prompt
328
+ )
329
+ added_time_ids = added_time_ids.to(device)
330
+
331
+ # --- MODIFIED FOR ALPHA MATTE REFINEMENT ---
332
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)
333
+
334
+ # self.scheduler.set_timesteps(num_inference_steps, device=device)
335
+ # timesteps = self.scheduler.timesteps
336
+ initial_latents = None
337
+
338
+ if alpha_matte_image is not None:
339
+ alpha_matte_tensor = self.video_processor.preprocess(alpha_matte_image, height=height, width=width).to(
340
+ device).unsqueeze(0)
341
+ initial_latents = self._encode_video_vae(alpha_matte_tensor, device)
342
+ initial_latents = initial_latents / self.vae.config.scaling_factor
343
+
344
+ # Adjust the number of steps and the timesteps to start from
345
+ t_start = max(num_inference_steps - int(num_inference_steps * denoising_strength), 0)
346
+ timesteps = timesteps[t_start:]
347
+ # We need the first timestep to add the correct amount of noise
348
+ start_timestep = timesteps[0]
349
+ else:
350
+ start_timestep = timesteps[0] # Not used, but for clarity
351
+
352
+ latents = self.prepare_latents(
353
+ batch_size * num_videos_per_prompt,
354
+ num_frames,
355
+ height,
356
+ width,
357
+ dtype,
358
+ device,
359
+ generator,
360
+ latents,
361
+ initial_latents=initial_latents,
362
+ denoising_strength=denoising_strength,
363
+ timestep=start_timestep if initial_latents is not None else None,
364
+ )
365
+
366
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
367
+ self._num_timesteps = len(timesteps)
368
+
369
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
370
+ for i, t in enumerate(timesteps):
371
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
372
+ latent_model_input = torch.cat([latent_model_input, conditional_latents, mask_latents], dim=2)
373
+
374
+ noise_pred = self.unet(
375
+ latent_model_input, t, encoder_hidden_states=image_embeddings, added_time_ids=added_time_ids,
376
+ return_dict=False
377
+ )[0]
378
+
379
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
380
+
381
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
382
+ progress_bar.update()
383
+
384
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
385
+ frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
386
+
387
+ self.maybe_free_model_hooks()
388
+
389
+ if not return_dict:
390
+ return frames
391
+ return StableVideoDiffusionPipelineOutput(frames=frames)
392
+
393
+
394
+ class StableVideoDiffusionPipelineOnestepWithMask(DiffusionPipeline):
395
+ r"""
396
+ A custom pipeline based on Stable Video Diffusion that accepts an additional mask for conditioning.
397
+ This pipeline is designed to work with a UNet fine-tuned to accept 12 input channels
398
+ (4 for noise, 4 for VAE-encoded condition image, 4 for VAE-encoded mask).
399
+ """
400
+
401
+ model_cpu_offload_seq = "image_encoder->unet->vae"
402
+ _callback_tensor_inputs = ["latents"]
403
+
404
+ def __init__(
405
+ self,
406
+ vae: AutoencoderKLTemporalDecoder,
407
+ image_encoder: CLIPVisionModelWithProjection,
408
+ unet: UNetSpatioTemporalConditionModel,
409
+ scheduler: EulerDiscreteScheduler,
410
+ feature_extractor: CLIPImageProcessor,
411
+ ):
412
+ super().__init__()
413
+
414
+ self.register_modules(
415
+ vae=vae,
416
+ image_encoder=image_encoder,
417
+ unet=unet,
418
+ scheduler=scheduler,
419
+ feature_extractor=feature_extractor,
420
+ )
421
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
422
+ self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)
423
+
424
+ def _encode_image(
425
+ self,
426
+ image: PipelineImageInput,
427
+ device: Union[str, torch.device],
428
+ num_videos_per_prompt: int,
429
+ ) -> torch.Tensor:
430
+ dtype = next(self.image_encoder.parameters()).dtype
431
+
432
+ if not isinstance(image, torch.Tensor):
433
+ image = self.video_processor.pil_to_numpy(image)
434
+ image = self.video_processor.numpy_to_pt(image)
435
+
436
+ image = image * 2.0 - 1.0
437
+ image = _resize_with_antialiasing(image, (224, 224))
438
+ image = (image + 1.0) / 2.0
439
+
440
+ image = self.feature_extractor(
441
+ images=image,
442
+ do_normalize=True,
443
+ do_center_crop=False,
444
+ do_resize=False,
445
+ do_rescale=False,
446
+ return_tensors="pt",
447
+ ).pixel_values
448
+
449
+ image = image.to(device=device, dtype=dtype)
450
+ image_embeddings = self.image_encoder(image).image_embeds
451
+ image_embeddings = image_embeddings.unsqueeze(1)
452
+
453
+ bs_embed, seq_len, _ = image_embeddings.shape
454
+ image_embeddings = torch.zeros_like(image_embeddings)
455
+
456
+ return image_embeddings
457
+
458
+ def _encode_vae_image(
459
+ self,
460
+ image: torch.Tensor,
461
+ device: Union[str, torch.device],
462
+ num_videos_per_prompt: int,
463
+ ):
464
+ image = image.to(device=device, dtype=torch.float16)
465
+ image_latents = self.vae.encode(image).latent_dist.sample()
466
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
467
+ return image_latents
468
+
469
+ def _get_add_time_ids(
470
+ self,
471
+ fps: int,
472
+ motion_bucket_id: int,
473
+ noise_aug_strength: float,
474
+ dtype: torch.dtype,
475
+ batch_size: int,
476
+ num_videos_per_prompt: int,
477
+ ):
478
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
479
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
480
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
481
+ if expected_add_embed_dim != passed_add_embed_dim:
482
+ raise ValueError(
483
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created."
484
+ )
485
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
486
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
487
+ return add_time_ids
488
+
489
+ def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
490
+ latents = latents.flatten(0, 1).to(dtype=torch.float16)
491
+ latents = 1 / self.vae.config.scaling_factor * latents
492
+ frames = []
493
+ for i in range(0, latents.shape[0], decode_chunk_size):
494
+ num_frames_in = latents[i: i + decode_chunk_size].shape[0]
495
+ frame = self.vae.decode(latents[i: i + decode_chunk_size], num_frames=num_frames_in).sample
496
+ frames.append(frame)
497
+ frames = torch.cat(frames, dim=0)
498
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
499
+ frames = frames.float()
500
+ return frames
501
+
502
+ def check_inputs(self, image, height, width):
503
+ if (
504
+ not isinstance(image, torch.Tensor)
505
+ and not isinstance(image, PIL.Image.Image)
506
+ and not isinstance(image, list)
507
+ ):
508
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
509
+ if height % 8 != 0 or width % 8 != 0:
510
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
511
+
512
+ def prepare_latents(
513
+ self,
514
+ batch_size: int,
515
+ num_frames: int,
516
+ height: int,
517
+ width: int,
518
+ dtype: torch.dtype,
519
+ device: Union[str, torch.device],
520
+ generator: torch.Generator,
521
+ latents: Optional[torch.Tensor] = None,
522
+ ):
523
+ # The number of channels for the initial noise is based on the UNet's out_channels
524
+ num_channels_latents = self.unet.config.out_channels
525
+ shape = (
526
+ batch_size,
527
+ num_frames,
528
+ num_channels_latents,
529
+ height // self.vae_scale_factor,
530
+ width // self.vae_scale_factor,
531
+ )
532
+ if isinstance(generator, list) and len(generator) != batch_size:
533
+ raise ValueError(f"batch size {batch_size} must match the length of the generators {len(generator)}.")
534
+
535
+ if latents is None:
536
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
537
+ else:
538
+ latents = latents.to(device)
539
+
540
+ latents = latents * self.scheduler.init_noise_sigma
541
+ return latents
542
+
543
+ def _encode_video_vae(
544
+ self,
545
+ video_frames: torch.Tensor, # Expects (B, F, C, H, W)
546
+ device: Union[str, torch.device],
547
+ ):
548
+ video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
549
+ batch_size, num_frames = video_frames.shape[:2]
550
+
551
+ # Reshape for VAE encoding
552
+ video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:]) # (B*F, C, H, W)
553
+ latents = self.vae.encode(video_frames_reshaped).latent_dist.sample() # (B*F, C_latent, H_latent, W_latent)
554
+
555
+ # Reshape back to video format
556
+ latents = latents.reshape(batch_size, num_frames, *latents.shape[1:]) # (B, F, C_latent, H_latent, W_latent)
557
+
558
+ return latents
559
+
560
+ @torch.no_grad()
561
+ def __call__(
562
+ self,
563
+ image: Union[List[PIL.Image.Image], torch.Tensor],
564
+ mask_image: Union[List[PIL.Image.Image], torch.Tensor],
565
+ height: int = 576,
566
+ width: int = 1024,
567
+ num_frames: Optional[int] = None,
568
+ fps: int = 7,
569
+ motion_bucket_id: int = 127,
570
+ noise_aug_strength: float = 0.0,
571
+ decode_chunk_size: Optional[int] = None,
572
+ num_videos_per_prompt: Optional[int] = 1,
573
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
574
+ latents: Optional[torch.Tensor] = None,
575
+ output_type: Optional[str] = "pil",
576
+ return_dict: bool = True,
577
+ mask_noise_strength: float = 0.0,
578
+ ):
579
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
580
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
581
+
582
+ if num_frames is None:
583
+ if isinstance(image, list):
584
+ num_frames = len(image)
585
+ else:
586
+ num_frames = self.unet.config.num_frames
587
+
588
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
589
+
590
+ self.check_inputs(image, height, width)
591
+ self.check_inputs(mask_image, height, width)
592
+ if isinstance(image, list) and isinstance(mask_image, list):
593
+ if len(image) != len(mask_image):
594
+ raise ValueError("`image` and `mask_image` must have the same number of frames.")
595
+ if num_frames != len(image):
596
+ logger.warning(
597
+ f"Mismatch between `num_frames` ({num_frames}) and number of input images ({len(image)}). Using {len(image)}.")
598
+ num_frames = len(image)
599
+
600
+ batch_size = 1
601
+ device = self._execution_device
602
+ dtype = self.unet.dtype
603
+
604
+ image_for_clip = image[0] if isinstance(image, list) else image[0]
605
+ image_embeddings = self._encode_image(image_for_clip, device, num_videos_per_prompt)
606
+
607
+ fps = fps - 1
608
+
609
+ image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device).unsqueeze(0)
610
+ mask_tensor = self.video_processor.preprocess(mask_image, height=height, width=width).to(
611
+ device).unsqueeze(0)
612
+
613
+ noise = randn_tensor(image_tensor.shape, generator=generator, device=device, dtype=dtype)
614
+ image_tensor = image_tensor + noise_aug_strength * noise
615
+
616
+ conditional_latents = self._encode_video_vae(image_tensor, device)
617
+ conditional_latents = conditional_latents / self.vae.config.scaling_factor
618
+
619
+ if self.unet.config.in_channels == 12:
620
+ mask_latents = self._encode_video_vae(mask_tensor, device)
621
+ mask_latents = mask_latents / self.vae.config.scaling_factor
622
+ elif self.unet.config.in_channels == 9:
623
+ mask_tensor_gray = mask_tensor.mean(dim=2, keepdim=True)
624
+ binarized_mask = (mask_tensor_gray > 0.0).to(dtype)
625
+ b, f, c, h, w = binarized_mask.shape
626
+ binarized_mask_reshaped = binarized_mask.reshape(b * f, c, h, w)
627
+ target_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)
628
+ interpolated_mask = F.interpolate(
629
+ binarized_mask_reshaped,
630
+ size=target_size,
631
+ mode='nearest',
632
+ )
633
+ mask_latents = interpolated_mask.reshape(b, f, *interpolated_mask.shape[1:])
634
+ else:
635
+ raise ValueError(
636
+ f"Unsupported number of UNet input channels: {self.unet.config.in_channels}. "
637
+ "This pipeline only supports 9 (for interpolated mask) or 12 (for VAE mask)."
638
+ )
639
+
640
+ if mask_noise_strength > 0.0:
641
+ mask_noise = randn_tensor(mask_latents.shape, generator=generator, device=device, dtype=dtype)
642
+ mask_latents = mask_latents + mask_noise_strength * mask_noise
643
+
644
+ added_time_ids = self._get_add_time_ids(
645
+ fps, motion_bucket_id, noise_aug_strength, image_embeddings.dtype, batch_size, num_videos_per_prompt
646
+ )
647
+ added_time_ids = added_time_ids.to(device)
648
+
649
+ # **MODIFIED FOR SINGLE-STEP**: Prepare initial noise
650
+ num_channels_latents = self.unet.config.out_channels
651
+ shape = (
652
+ batch_size * num_videos_per_prompt,
653
+ num_frames,
654
+ num_channels_latents,
655
+ height // self.vae_scale_factor,
656
+ width // self.vae_scale_factor,
657
+ )
658
+ if latents is None:
659
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
660
+
661
+ # **MODIFIED FOR SINGLE-STEP**: Set a fixed high timestep
662
+ timestep = torch.tensor([1.0], dtype=dtype, device=device) # Use a high sigma value
663
+
664
+ # **MODIFIED FOR SINGLE-STEP**: Single forward pass
665
+ latent_model_input = torch.cat([latents, conditional_latents, mask_latents], dim=2)
666
+
667
+ noise_pred = self.unet(
668
+ latent_model_input, timestep, encoder_hidden_states=image_embeddings, added_time_ids=added_time_ids,
669
+ return_dict=False
670
+ )[0]
671
+
672
+ # The model's prediction is the final denoised latent
673
+ denoised_latents = noise_pred
674
+
675
+ frames = self.decode_latents(denoised_latents, num_frames, decode_chunk_size)
676
+ frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
677
+
678
+ self.maybe_free_model_hooks()
679
+
680
+ if not return_dict:
681
+ return frames
682
+ return StableVideoDiffusionPipelineOutput(frames=frames)
683
+
684
+
685
+ class StableVideoDiffusionPipelineWithCrossAtnnMask(DiffusionPipeline):
686
+ model_cpu_offload_seq = "image_encoder->unet->vae"
687
+ _callback_tensor_inputs = ["latents"]
688
+
689
+ def __init__(
690
+ self,
691
+ vae: AutoencoderKLTemporalDecoder,
692
+ unet: UNetSpatioTemporalConditionModel,
693
+ scheduler: EulerDiscreteScheduler,
694
+ mask_projector: torch.nn.Module,
695
+ # CLIP models are not strictly needed for inference if embeddings are not used
696
+ image_encoder: CLIPVisionModelWithProjection = None,
697
+ feature_extractor: CLIPImageProcessor = None,
698
+ ):
699
+ super().__init__()
700
+ self.register_modules(
701
+ vae=vae,
702
+ unet=unet,
703
+ scheduler=scheduler,
704
+ mask_projector=mask_projector,
705
+ image_encoder=image_encoder,
706
+ feature_extractor=feature_extractor,
707
+ )
708
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
709
+ self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
710
+
711
+ def _encode_image_vae(self, image: torch.Tensor, device: Union[str, torch.device]):
712
+ image = image.to(device=device, dtype=self.vae.dtype)
713
+ latent = self.vae.encode(image).latent_dist.sample()
714
+ return latent
715
+
716
+ def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int):
717
+ latents = latents.flatten(0, 1).to(dtype=torch.float16)
718
+ latents = 1 / self.vae.config.scaling_factor * latents
719
+ frames = []
720
+ for i in range(0, latents.shape[0], decode_chunk_size):
721
+ frame = self.vae.decode(latents[i: i + decode_chunk_size], num_frames=decode_chunk_size).sample
722
+ frames.append(frame)
723
+
724
+ frames = torch.cat(frames, dim=0)
725
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
726
+ frames = frames.float()
727
+ return frames
728
+
729
+ def _encode_video_vae(
730
+ self,
731
+ video_frames: torch.Tensor, # Expects (B, F, C, H, W)
732
+ device: Union[str, torch.device],
733
+ ):
734
+ video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
735
+ batch_size, num_frames = video_frames.shape[:2]
736
+
737
+ # Reshape for VAE encoding
738
+ video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:]) # (B*F, C, H, W)
739
+ latents = self.vae.encode(video_frames_reshaped).latent_dist.sample() # (B*F, C_latent, H_latent, W_latent)
740
+
741
+ # Reshape back to video format
742
+ latents = latents.reshape(batch_size, num_frames, *latents.shape[1:]) # (B, F, C_latent, H_latent, W_latent)
743
+
744
+ return latents
745
+
746
+ @torch.no_grad()
747
+ def __call__(
748
+ self,
749
+ image: Union[PIL.Image.Image, torch.Tensor], # Static image for appearance
750
+ mask_image: List[PIL.Image.Image], # Video mask for motion
751
+ height: int = 576,
752
+ width: int = 1024,
753
+ num_frames: Optional[int] = None,
754
+ num_inference_steps: int = 25,
755
+ fps: int = 7,
756
+ motion_bucket_id: int = 127,
757
+ noise_aug_strength: float = 0.0, # Noise is added to latents now
758
+ decode_chunk_size: Optional[int] = 8,
759
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
760
+ output_type: Optional[str] = "pil",
761
+ return_dict: bool = True,
762
+ ):
763
+ device = self._execution_device
764
+ dtype = self.unet.dtype
765
+ num_frames = num_frames if num_frames is not None else len(mask_image)
766
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
767
+
768
+ # 1. PREPARE STATIC IMAGE CONDITION
769
+ image_tensor = self.video_processor.preprocess(image, height, width).to(device).unsqueeze(0)
770
+ conditional_latents = self._encode_video_vae(image_tensor, device)
771
+ conditional_latents = conditional_latents / self.vae.config.scaling_factor
772
+
773
+ # 2. PREPARE MASK MOTION CONDITION
774
+ mask_tensor = self.video_processor.preprocess(mask_image, height, width)
775
+ if mask_tensor.shape[1] > 1:
776
+ mask_tensor = mask_tensor.mean(dim=1, keepdim=True)
777
+
778
+ # Reshape for projector: (T, C, H, W)
779
+ mask_for_projection = rearrange(mask_tensor, "f c h w -> f c h w").to(device, dtype)
780
+ encoder_hidden_states = self.mask_projector(mask_for_projection)
781
+ encoder_hidden_states = encoder_hidden_states.unsqueeze(1) # (T, 1, D)
782
+ # Add batch dimension for UNet
783
+ encoder_hidden_states = encoder_hidden_states.unsqueeze(0) # (1, T, 1, D)
784
+ # The UNet will handle flattening this to (B*T, 1, D) where B=1
785
+ # To be safe, we pass it pre-flattened.
786
+ encoder_hidden_states = rearrange(encoder_hidden_states, "b f s d -> (b f) s d")
787
+
788
+ # 3. PREPARE LATENTS
789
+ shape = (1, num_frames, self.unet.config.out_channels, height // self.vae_scale_factor,
790
+ width // self.vae_scale_factor)
791
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
792
+ if noise_aug_strength > 0:
793
+ latents += noise_aug_strength * randn_tensor(latents.shape, generator=generator, device=device,
794
+ dtype=dtype)
795
+ latents = latents * self.scheduler.init_noise_sigma
796
+
797
+ # 4. GET ADDED TIME IDS
798
+ # For pipeline, batch size is 1
799
+ added_time_ids = [fps - 1, motion_bucket_id, 0.0] # noise_aug_strength for add_time_ids is 0 for inference
800
+ added_time_ids = torch.tensor([added_time_ids], dtype=dtype, device=device)
801
+
802
+ # 5. DENOISING LOOP
803
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
804
+ timesteps = self.scheduler.timesteps
805
+
806
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
807
+ for t in timesteps:
808
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
809
+ unet_input = torch.cat([latent_model_input, conditional_latents], dim=2)
810
+
811
+ noise_pred = self.unet(
812
+ unet_input, t, encoder_hidden_states=encoder_hidden_states, added_time_ids=added_time_ids
813
+ ).sample
814
+
815
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
816
+ progress_bar.update()
817
+
818
+ # 6. DECODE
819
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
820
+ frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
821
+
822
+ if not return_dict:
823
+ return (frames,)
824
+ return StableVideoDiffusionPipelineOutput(frames=frames)
825
+
826
+
827
+ # pipeline.py
828
+
829
+ import torch
830
+ import torch.nn.functional as F
831
+ from PIL import Image
832
+ from einops import rearrange
833
+ from torchvision import transforms
834
+ from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
835
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
836
+
837
+
838
+ class VideoInferencePipeline:
839
+ """
840
+ A reusable pipeline for single-step video diffusion inference.
841
+
842
+ This class encapsulates the models and the core inference logic,
843
+ separating it from data loading and saving, which can vary between tasks.
844
+ """
845
+
846
+ def __init__(self, base_model_path: str, unet_checkpoint_path: str, device: str = "cuda",
847
+ weight_dtype: torch.dtype = torch.float16):
848
+ """
849
+ Loads all necessary models into memory.
850
+
851
+ Args:
852
+ base_model_path (str): Path to the base Stable Video Diffusion model.
853
+ unet_checkpoint_path (str): Path to the fine-tuned UNet checkpoint.
854
+ device (str): The device to run models on ('cuda' or 'cpu').
855
+ weight_dtype (torch.dtype): The precision for model weights (float16 or bfloat16).
856
+ """
857
+ print("--- Initializing Inference Pipeline and Loading Models ---")
858
+ self.device = torch.device(device if torch.cuda.is_available() else "cpu")
859
+ self.weight_dtype = weight_dtype
860
+
861
+ # Load models from pretrained paths
862
+ try:
863
+ self.feature_extractor = CLIPImageProcessor.from_pretrained(base_model_path, subfolder="feature_extractor")
864
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_model_path,
865
+ subfolder="image_encoder",
866
+ variant="fp16")
867
+ self.vae = AutoencoderKLTemporalDecoder.from_pretrained(base_model_path, subfolder="vae", variant="fp16")
868
+ self.unet = UNetSpatioTemporalConditionModel.from_pretrained(unet_checkpoint_path, subfolder="unet")
869
+ except Exception as e:
870
+ raise IOError(f"Fatal error loading models: {e}")
871
+
872
+ # Move models to the specified device and set to evaluation mode
873
+ self.image_encoder.to(self.device, dtype=self.weight_dtype).eval()
874
+ self.vae.to(self.device, dtype=self.weight_dtype).eval()
875
+ self.unet.to(self.device, dtype=self.weight_dtype).eval()
876
+
877
+ print(f"--- Models Loaded Successfully on {self.device} ---")
878
+
879
+ def run(self, cond_frames, mask_frames, seed=42, mask_cond_mode="vae", fps=7, motion_bucket_id=127,
880
+ noise_aug_strength=0.0):
881
+ """
882
+ Runs the core inference process on a sequence of conditioning and mask frames.
883
+
884
+ Args:
885
+ cond_frames (list[Image.Image]): List of PIL images for conditioning.
886
+ mask_frames (list[Image.Image]): List of PIL images for the masks.
887
+ seed (int): Random seed for generation.
888
+ mask_cond_mode (str): How the mask is conditioned ("vae" or "interpolate").
889
+ fps (int): Frames per second to condition the model with.
890
+ motion_bucket_id (int): Motion bucket ID for conditioning.
891
+ noise_aug_strength (float): Noise augmentation strength.
892
+
893
+ Returns:
894
+ list[Image.Image]: A list of the generated video frames as PIL Images.
895
+ """
896
+ # --- 1. Prepare Tensors ---
897
+ cond_video_tensor = self._pil_to_tensor(cond_frames).to(self.device)
898
+ mask_video_tensor = self._pil_to_tensor(mask_frames).to(self.device)
899
+
900
+ if mask_video_tensor.shape[2] != 3:
901
+ mask_video_tensor = mask_video_tensor.repeat(1, 1, 3, 1, 1)
902
+
903
+ with torch.no_grad():
904
+ # --- 2. Get CLIP Image Embeddings ---
905
+ first_frame_tensor = cond_video_tensor[:, 0, :, :, :]
906
+ pixel_values_for_clip = self._resize_with_antialiasing(first_frame_tensor, (224, 224))
907
+ pixel_values_for_clip = ((pixel_values_for_clip + 1.0) / 2.0).clamp(0, 1)
908
+ pixel_values = self.feature_extractor(images=pixel_values_for_clip, return_tensors="pt").pixel_values
909
+ image_embeddings = self.image_encoder(pixel_values.to(self.device, dtype=self.weight_dtype)).image_embeds
910
+ encoder_hidden_states = torch.zeros_like(image_embeddings).unsqueeze(1)
911
+
912
+ # --- 3. Prepare Latents ---
913
+ cond_latents = self._tensor_to_vae_latent(cond_video_tensor.to(self.weight_dtype))
914
+ cond_latents = cond_latents / self.vae.config.scaling_factor
915
+
916
+ if mask_cond_mode == "vae":
917
+ mask_latents = self._tensor_to_vae_latent(mask_video_tensor.to(self.weight_dtype))
918
+ mask_latents = mask_latents / self.vae.config.scaling_factor
919
+ elif mask_cond_mode == "interpolate":
920
+ target_shape = cond_latents.shape[-2:]
921
+ b, t, c, h, w = mask_video_tensor.shape
922
+ mask_video_reshaped = rearrange(mask_video_tensor, "b t c h w -> (b t) c h w")
923
+ interpolated_mask = F.interpolate(mask_video_reshaped, size=target_shape, mode='bilinear',
924
+ align_corners=False)
925
+ mask_latents = rearrange(interpolated_mask, "(b t) c h w -> b t c h w", b=b)
926
+ else:
927
+ raise ValueError(f"Unknown mask_cond_mode: {mask_cond_mode}")
928
+
929
+ # --- 4. Run UNet Single-Step Inference ---
930
+ generator = torch.Generator(device=self.device).manual_seed(seed)
931
+ noisy_latents = torch.randn(cond_latents.shape, generator=generator, device=self.device,
932
+ dtype=self.weight_dtype)
933
+ timesteps = torch.full((1,), 1.0, device=self.device, dtype=torch.long)
934
+ added_time_ids = self._get_add_time_ids(fps, motion_bucket_id, noise_aug_strength, batch_size=1)
935
+
936
+ unet_input = torch.cat([noisy_latents, cond_latents, mask_latents], dim=2)
937
+ pred_latents = self.unet(unet_input, timesteps, encoder_hidden_states, added_time_ids=added_time_ids).sample
938
+
939
+ # --- 5. Decode Latents to Video Frames ---
940
+ pred_latents = (1 / self.vae.config.scaling_factor) * pred_latents.squeeze(0)
941
+
942
+ frames = []
943
+ # Process in chunks to avoid VRAM issues, especially for long videos
944
+ for i in range(0, pred_latents.shape[0], 8):
945
+ chunk = pred_latents[i: i + 8]
946
+ decoded_chunk = self.vae.decode(chunk, num_frames=chunk.shape[0]).sample
947
+ frames.append(decoded_chunk)
948
+
949
+ video_tensor = torch.cat(frames, dim=0)
950
+ video_tensor = (video_tensor / 2.0 + 0.5).clamp(0, 1).mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)
951
+
952
+ # Return a list of PIL images
953
+ return [transforms.ToPILImage()(frame) for frame in video_tensor]
954
+
955
+ def _pil_to_tensor(self, frames: list[Image.Image]):
956
+ """Converts a list of PIL images to a normalized video tensor."""
957
+ video_tensor = torch.stack([transforms.ToTensor()(f) for f in frames]).unsqueeze(0)
958
+ return video_tensor * 2.0 - 1.0
959
+
960
+ def _tensor_to_vae_latent(self, t: torch.Tensor):
961
+ """Encodes a video tensor into the VAE's latent space."""
962
+ video_length = t.shape[1]
963
+ t = rearrange(t, "b f c h w -> (b f) c h w")
964
+ latents = self.vae.encode(t).latent_dist.sample()
965
+ latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
966
+ return latents * self.vae.config.scaling_factor
967
+
968
+ def _get_add_time_ids(self, fps, motion_bucket_id, noise_aug_strength, batch_size):
969
+ """Creates the additional time IDs for conditioning the UNet."""
970
+ add_time_ids_list = [fps, motion_bucket_id, noise_aug_strength]
971
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids_list)
972
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
973
+ if expected_add_embed_dim != passed_add_embed_dim:
974
+ raise ValueError(
975
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created.")
976
+ add_time_ids = torch.tensor([add_time_ids_list], dtype=self.weight_dtype, device=self.device)
977
+ return add_time_ids.repeat(batch_size, 1)
978
+
979
+ def _resize_with_antialiasing(self, input_tensor, size, interpolation="bicubic", align_corners=True):
980
+ """
981
+ Resizes a tensor with anti-aliasing for CLIP input, mirroring k-diffusion.
982
+ This is a direct copy of the helper function from your original scripts.
983
+ """
984
+ h, w = input_tensor.shape[-2:]
985
+ factors = (h / size[0], w / size[1])
986
+ sigmas = (max((factors[0] - 1.0) / 2.0, 0.001), max((factors[1] - 1.0) / 2.0, 0.001))
987
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
988
+ if (ks[0] % 2) == 0: ks = ks[0] + 1, ks[1]
989
+ if (ks[1] % 2) == 0: ks = ks[0], ks[1] + 1
990
+
991
+ def _compute_padding(kernel_size):
992
+ computed = [k - 1 for k in kernel_size]
993
+ out_padding = 2 * len(kernel_size) * [0]
994
+ for i in range(len(kernel_size)):
995
+ computed_tmp = computed[-(i + 1)]
996
+ pad_front = computed_tmp // 2
997
+ pad_rear = computed_tmp - pad_front
998
+ out_padding[2 * i + 0] = pad_front
999
+ out_padding[2 * i + 1] = pad_rear
1000
+ return out_padding
1001
+
1002
+ def _filter2d(input_tensor, kernel):
1003
+ b, c, h, w = input_tensor.shape
1004
+ tmp_kernel = kernel[:, None, ...].to(device=input_tensor.device, dtype=input_tensor.dtype)
1005
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
1006
+ height, width = tmp_kernel.shape[-2:]
1007
+ padding_shape = _compute_padding([height, width])
1008
+ input_tensor_padded = F.pad(input_tensor, padding_shape, mode="reflect")
1009
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
1010
+ input_tensor_padded = input_tensor_padded.view(-1, tmp_kernel.size(0), input_tensor_padded.size(-2),
1011
+ input_tensor_padded.size(-1))
1012
+ output = F.conv2d(input_tensor_padded, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
1013
+ return output.view(b, c, h, w)
1014
+
1015
+ def _gaussian(window_size, sigma):
1016
+ if isinstance(sigma, float):
1017
+ sigma = torch.tensor([[sigma]])
1018
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(
1019
+ sigma.shape[0], -1)
1020
+ if window_size % 2 == 0:
1021
+ x = x + 0.5
1022
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
1023
+ return gauss / gauss.sum(-1, keepdim=True)
1024
+
1025
+ def _gaussian_blur2d(input_tensor, kernel_size, sigma):
1026
+ if isinstance(sigma, tuple):
1027
+ sigma = torch.tensor([sigma], dtype=input_tensor.dtype)
1028
+ else:
1029
+ sigma = sigma.to(dtype=input_tensor.dtype)
1030
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
1031
+ bs = sigma.shape[0]
1032
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
1033
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
1034
+ out_x = _filter2d(input_tensor, kernel_x[..., None, :])
1035
+ return _filter2d(out_x, kernel_y[..., None])
1036
+
1037
+ blurred_input = _gaussian_blur2d(input_tensor, ks, sigmas)
1038
+ return F.interpolate(blurred_input, size=size, mode=interpolation, align_corners=align_corners)
requirements.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Space Requirements for VideoMaMa Demo
2
+
3
+ # Core frameworks
4
+ torch>=2.0.0
5
+ torchvision>=0.15.0
6
+ diffusers>=0.24.0
7
+ transformers>=4.30.0
8
+
9
+ # Gradio for UI
10
+ gradio==5.12.0
11
+
12
+ # Image and video processing
13
+ opencv-python>=4.8.0
14
+ opencv-contrib-python>=4.8.0
15
+ Pillow>=10.0.0
16
+ numpy>=1.24.0
17
+ scipy>=1.10.0
18
+
19
+ # SAM2 dependencies
20
+ git+https://github.com/facebookresearch/sam2.git
21
+
22
+ # Additional utilities
23
+ accelerate>=0.20.0
24
+ einops>=0.6.0
25
+ tqdm>=4.65.0
26
+ safetensors>=0.3.0
27
+
28
+ # For video export
29
+ imageio>=2.31.0
30
+ imageio-ffmpeg>=0.4.9
31
+ pydantic==2.10.6
sam2_hiera_l.yaml ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Configuration for SAM2
2
+ # This file should be placed alongside the SAM2 checkpoint
3
+
4
+ # SAM 2 Hiera Large Configuration
5
+ model:
6
+ _target_: sam2.modeling.sam2_base.SAM2Base
7
+ image_encoder:
8
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3]
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [32, 32]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [32, 32]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ directly_add_no_mem_embed: true
94
+ use_high_res_features_in_sam: true
95
+ multimask_output_in_sam: true
96
+ multimask_min_pt_num: 0
97
+ multimask_max_pt_num: 1
98
+ multimask_output_for_tracking: true
99
+ use_multimask_token_for_obj_ptr: true
100
+ iou_prediction_use_sigmoid: True
101
+ memory_temporal_stride_for_eval: 1
102
+ non_overlap_masks_for_mem_enc: true
103
+ use_obj_ptrs_in_encoder: true
104
+ max_obj_ptrs_in_encoder: 16
105
+ add_tpos_enc_to_obj_ptrs: false
106
+ proj_tpos_enc_in_obj_ptrs: false
107
+ use_signed_tpos_enc_to_obj_ptrs: false
108
+ only_obj_ptrs_in_the_past_for_eval: true
109
+ pred_obj_scores: true
110
+ pred_obj_scores_mlp: true
111
+ fixed_no_obj_ptr: true
112
+ soft_no_obj_ptr: false
113
+ use_mlp_for_obj_ptr_proj: true
114
+ no_obj_embed_spatial: true
115
+
116
+ sam_mask_decoder_extra_args:
117
+ dynamic_multimask_via_stability: true
118
+ dynamic_multimask_stability_delta: 0.05
119
+ dynamic_multimask_stability_thresh: 0.98
120
+ pred_obj_scores: true
121
+ pred_obj_scores_mlp: true
122
+ use_multimask_token_for_obj_ptr: true
123
+
124
+ compile_image_encoder: False
sam2_wrapper.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM2 Wrapper for Video Mask Tracking
3
+ Handles mask generation and propagation through video
4
+ """
5
+
6
+ import sys
7
+ sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2")
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+ from PIL import Image
13
+ from pathlib import Path
14
+ from typing import List, Tuple
15
+ import tempfile
16
+ import shutil
17
+
18
+ from sam2.build_sam import build_sam2_video_predictor
19
+
20
+
21
+ class SAM2VideoTracker:
22
+ def __init__(self, checkpoint_path, config_file, device="cuda"):
23
+ """
24
+ Initialize SAM2 video tracker
25
+
26
+ Args:
27
+ checkpoint_path: Path to SAM2 checkpoint
28
+ config_file: Path to SAM2 config file
29
+ device: Device to run on
30
+ """
31
+ self.device = device
32
+ self.predictor = build_sam2_video_predictor(
33
+ config_file=config_file,
34
+ ckpt_path=checkpoint_path,
35
+ device=device
36
+ )
37
+ print(f"SAM2 video tracker initialized on {device}")
38
+
39
+ def track_video(self, frames: List[np.ndarray], points: List[List[int]],
40
+ labels: List[int]) -> List[np.ndarray]:
41
+ """
42
+ Track object through video using SAM2
43
+
44
+ Args:
45
+ frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
46
+ points: List of [x, y] coordinates for prompts
47
+ labels: List of labels (1 for positive, 0 for negative)
48
+
49
+ Returns:
50
+ masks: List of numpy arrays, [(H,W)]*n, uint8 binary masks
51
+ """
52
+ # Create temporary directory for frames
53
+ temp_dir = Path(tempfile.mkdtemp())
54
+ frames_dir = temp_dir / "frames"
55
+ frames_dir.mkdir(exist_ok=True)
56
+
57
+ try:
58
+ # Save frames to temp directory
59
+ print(f"Saving {len(frames)} frames to temporary directory...")
60
+ for i, frame in enumerate(frames):
61
+ frame_path = frames_dir / f"{i:05d}.jpg"
62
+ Image.fromarray(frame).save(frame_path, quality=95)
63
+
64
+ # Initialize SAM2 video predictor
65
+ print("Initializing SAM2 inference state...")
66
+ inference_state = self.predictor.init_state(video_path=str(frames_dir))
67
+
68
+ # Add prompts on first frame
69
+ points_array = np.array(points, dtype=np.float32)
70
+ labels_array = np.array(labels, dtype=np.int32)
71
+
72
+ print(f"Adding {len(points)} point prompts on first frame...")
73
+ _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
74
+ inference_state=inference_state,
75
+ frame_idx=0,
76
+ obj_id=1,
77
+ points=points_array,
78
+ labels=labels_array,
79
+ )
80
+
81
+ # Propagate through video
82
+ print("Propagating masks through video...")
83
+ masks = []
84
+ for frame_idx, object_ids, mask_logits in self.predictor.propagate_in_video(inference_state):
85
+ # Get mask for object ID 1
86
+ # object_ids can be a tensor or a list
87
+ obj_ids_list = object_ids.tolist() if hasattr(object_ids, 'tolist') else object_ids
88
+
89
+ if 1 in obj_ids_list:
90
+ mask_idx = obj_ids_list.index(1)
91
+ mask = (mask_logits[mask_idx] > 0.0).cpu().numpy()
92
+ mask_uint8 = (mask.squeeze() * 255).astype(np.uint8)
93
+ masks.append(mask_uint8)
94
+ else:
95
+ # No mask for this frame, use empty mask
96
+ h, w = frames[0].shape[:2]
97
+ masks.append(np.zeros((h, w), dtype=np.uint8))
98
+
99
+ print(f"Generated {len(masks)} masks")
100
+ return masks
101
+
102
+ finally:
103
+ # Clean up temporary directory
104
+ shutil.rmtree(temp_dir, ignore_errors=True)
105
+
106
+ def get_first_frame_mask(self, frame: np.ndarray, points: List[List[int]],
107
+ labels: List[int]) -> np.ndarray:
108
+ """
109
+ Get mask for first frame only (for preview)
110
+
111
+ Args:
112
+ frame: np.ndarray, (H, W, 3), uint8 RGB frame
113
+ points: List of [x, y] coordinates
114
+ labels: List of labels (1 for positive, 0 for negative)
115
+
116
+ Returns:
117
+ mask: np.ndarray, (H, W), uint8 binary mask
118
+ """
119
+ # Create temporary directory
120
+ temp_dir = Path(tempfile.mkdtemp())
121
+ frames_dir = temp_dir / "frames"
122
+ frames_dir.mkdir(exist_ok=True)
123
+
124
+ try:
125
+ # Save single frame
126
+ frame_path = frames_dir / "00000.jpg"
127
+ Image.fromarray(frame).save(frame_path, quality=95)
128
+
129
+ # Initialize SAM2
130
+ inference_state = self.predictor.init_state(video_path=str(frames_dir))
131
+
132
+ # Add prompts
133
+ points_array = np.array(points, dtype=np.float32)
134
+ labels_array = np.array(labels, dtype=np.int32)
135
+
136
+ _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
137
+ inference_state=inference_state,
138
+ frame_idx=0,
139
+ obj_id=1,
140
+ points=points_array,
141
+ labels=labels_array,
142
+ )
143
+
144
+ # Get mask
145
+ if len(out_mask_logits) > 0:
146
+ mask = (out_mask_logits[0] > 0.0).cpu().numpy()
147
+ mask_uint8 = (mask.squeeze() * 255).astype(np.uint8)
148
+ return mask_uint8
149
+ else:
150
+ return np.zeros(frame.shape[:2], dtype=np.uint8)
151
+
152
+ finally:
153
+ shutil.rmtree(temp_dir, ignore_errors=True)
154
+
155
+
156
+ def load_sam2_tracker(device="cuda"):
157
+ """
158
+ Load SAM2 video tracker with pretrained weights
159
+
160
+ Args:
161
+ device: Device to run on
162
+
163
+ Returns:
164
+ SAM2VideoTracker instance
165
+ """
166
+ checkpoint_path = "checkpoints/sam2/sam2.1_hiera_large.pt"
167
+ config_file = "configs/sam2.1/sam2.1_hiera_l.yaml"
168
+
169
+ print(f"Loading SAM2 from {checkpoint_path}...")
170
+ tracker = SAM2VideoTracker(checkpoint_path, config_file, device)
171
+
172
+ return tracker
sam2_wrapper_hf.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM2 Wrapper for Video Mask Tracking - Hugging Face Space Version
3
+ Handles mask generation and propagation through video
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ from pathlib import Path
9
+
10
+ # Add SAM2 to path if installed
11
+ try:
12
+ import sam2
13
+ except ImportError:
14
+ # Try to add from common locations
15
+ possible_paths = [
16
+ "/home/cvlab19/project/samuel/CVPR/sam2",
17
+ "./sam2"
18
+ ]
19
+ for path in possible_paths:
20
+ if os.path.exists(path):
21
+ sys.path.append(path)
22
+ break
23
+
24
+ import cv2
25
+ import numpy as np
26
+ import torch
27
+ from PIL import Image
28
+ from typing import List, Tuple
29
+ import tempfile
30
+ import shutil
31
+
32
+ from sam2.build_sam import build_sam2_video_predictor
33
+
34
+
35
+ class SAM2VideoTracker:
36
+ def __init__(self, checkpoint_path, config_file, device="cuda"):
37
+ """
38
+ Initialize SAM2 video tracker
39
+
40
+ Args:
41
+ checkpoint_path: Path to SAM2 checkpoint
42
+ config_file: Path to SAM2 config file
43
+ device: Device to run on
44
+ """
45
+ self.device = device
46
+ self.predictor = build_sam2_video_predictor(
47
+ config_file=config_file,
48
+ ckpt_path=checkpoint_path,
49
+ device=device
50
+ )
51
+ print(f"SAM2 video tracker initialized on {device}")
52
+
53
+ def track_video(self, frames: List[np.ndarray], points: List[List[int]],
54
+ labels: List[int]) -> List[np.ndarray]:
55
+ """
56
+ Track object through video using SAM2
57
+
58
+ Args:
59
+ frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
60
+ points: List of [x, y] coordinates for prompts
61
+ labels: List of labels (1 for positive, 0 for negative)
62
+
63
+ Returns:
64
+ masks: List of numpy arrays, [(H,W)]*n, uint8 binary masks
65
+ """
66
+ # Create temporary directory for frames
67
+ temp_dir = Path(tempfile.mkdtemp())
68
+ frames_dir = temp_dir / "frames"
69
+ frames_dir.mkdir(exist_ok=True)
70
+
71
+ try:
72
+ # Save frames to temp directory
73
+ print(f"Saving {len(frames)} frames to temporary directory...")
74
+ for i, frame in enumerate(frames):
75
+ frame_path = frames_dir / f"{i:05d}.jpg"
76
+ Image.fromarray(frame).save(frame_path, quality=95)
77
+
78
+ # Initialize SAM2 video predictor
79
+ print("Initializing SAM2 inference state...")
80
+ inference_state = self.predictor.init_state(video_path=str(frames_dir))
81
+
82
+ # Add prompts on first frame
83
+ points_array = np.array(points, dtype=np.float32)
84
+ labels_array = np.array(labels, dtype=np.int32)
85
+
86
+ print(f"Adding {len(points)} point prompts on first frame...")
87
+ _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
88
+ inference_state=inference_state,
89
+ frame_idx=0,
90
+ obj_id=1,
91
+ points=points_array,
92
+ labels=labels_array,
93
+ )
94
+
95
+ # Propagate through video
96
+ print("Propagating masks through video...")
97
+ masks = []
98
+ for frame_idx, object_ids, mask_logits in self.predictor.propagate_in_video(inference_state):
99
+ # Get mask for object ID 1
100
+ obj_ids_list = object_ids.tolist() if hasattr(object_ids, 'tolist') else object_ids
101
+
102
+ if 1 in obj_ids_list:
103
+ mask_idx = obj_ids_list.index(1)
104
+ mask = (mask_logits[mask_idx] > 0.0).cpu().numpy()
105
+ mask_uint8 = (mask.squeeze() * 255).astype(np.uint8)
106
+ masks.append(mask_uint8)
107
+ else:
108
+ # No mask for this frame, use empty mask
109
+ h, w = frames[0].shape[:2]
110
+ masks.append(np.zeros((h, w), dtype=np.uint8))
111
+
112
+ print(f"Generated {len(masks)} masks")
113
+ return masks
114
+
115
+ finally:
116
+ # Clean up temporary directory
117
+ shutil.rmtree(temp_dir, ignore_errors=True)
118
+
119
+ def get_first_frame_mask(self, frame: np.ndarray, points: List[List[int]],
120
+ labels: List[int]) -> np.ndarray:
121
+ """
122
+ Get mask for first frame only (for preview)
123
+
124
+ Args:
125
+ frame: np.ndarray, (H, W, 3), uint8 RGB frame
126
+ points: List of [x, y] coordinates
127
+ labels: List of labels (1 for positive, 0 for negative)
128
+
129
+ Returns:
130
+ mask: np.ndarray, (H, W), uint8 binary mask
131
+ """
132
+ # Create temporary directory
133
+ temp_dir = Path(tempfile.mkdtemp())
134
+ frames_dir = temp_dir / "frames"
135
+ frames_dir.mkdir(exist_ok=True)
136
+
137
+ try:
138
+ # Save single frame
139
+ frame_path = frames_dir / "00000.jpg"
140
+ Image.fromarray(frame).save(frame_path, quality=95)
141
+
142
+ # Initialize SAM2
143
+ inference_state = self.predictor.init_state(video_path=str(frames_dir))
144
+
145
+ # Add prompts
146
+ points_array = np.array(points, dtype=np.float32)
147
+ labels_array = np.array(labels, dtype=np.int32)
148
+
149
+ _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
150
+ inference_state=inference_state,
151
+ frame_idx=0,
152
+ obj_id=1,
153
+ points=points_array,
154
+ labels=labels_array,
155
+ )
156
+
157
+ # Get mask
158
+ if len(out_mask_logits) > 0:
159
+ mask = (out_mask_logits[0] > 0.0).cpu().numpy()
160
+ mask_uint8 = (mask.squeeze() * 255).astype(np.uint8)
161
+ return mask_uint8
162
+ else:
163
+ return np.zeros(frame.shape[:2], dtype=np.uint8)
164
+
165
+ finally:
166
+ shutil.rmtree(temp_dir, ignore_errors=True)
167
+
168
+
169
+ def load_sam2_tracker(checkpoint_path=None, device="cuda"):
170
+ """
171
+ Load SAM2 video tracker with pretrained weights
172
+
173
+ Args:
174
+ checkpoint_path: Path to SAM2 checkpoint (if None, uses default location)
175
+ device: Device to run on
176
+
177
+ Returns:
178
+ SAM2VideoTracker instance
179
+ """
180
+ # Use provided path or default
181
+ if checkpoint_path is None:
182
+ checkpoint_path = "checkpoints/sam2.1_hiera_large.pt"
183
+
184
+ # Config file should be in the SAM2 repo
185
+ config_file = "configs/sam2.1/sam2.1_hiera_l.yaml"
186
+
187
+ # Check if we need to use the local yaml file
188
+ if not os.path.exists(config_file):
189
+ config_file = "sam2_hiera_l.yaml"
190
+
191
+ print(f"Loading SAM2 from {checkpoint_path}...")
192
+ print(f"Using config: {config_file}")
193
+
194
+ tracker = SAM2VideoTracker(checkpoint_path, config_file, device)
195
+
196
+ return tracker
tools/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Tools module
tools/base_segmenter.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM2 Base Segmenter
3
+ Adapted from MatAnyone demo
4
+ """
5
+
6
+ import sys
7
+ sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2")
8
+
9
+ import torch
10
+ import numpy as np
11
+ from sam2.build_sam import build_sam2_video_predictor
12
+
13
+
14
+ class BaseSegmenter:
15
+ def __init__(self, SAM_checkpoint, model_type, device):
16
+ """
17
+ Initialize SAM2 segmenter
18
+
19
+ Args:
20
+ SAM_checkpoint: Path to SAM2 checkpoint
21
+ model_type: SAM2 model config file
22
+ device: Device to run on
23
+ """
24
+ self.device = device
25
+ self.model_type = model_type
26
+
27
+ # Build SAM2 video predictor
28
+ self.sam_predictor = build_sam2_video_predictor(
29
+ config_file=model_type,
30
+ ckpt_path=SAM_checkpoint,
31
+ device=device
32
+ )
33
+
34
+ self.orignal_image = None
35
+ self.inference_state = None
36
+
37
+ def set_image(self, image: np.ndarray):
38
+ """Set the current image for segmentation"""
39
+ self.orignal_image = image
40
+
41
+ def reset_image(self):
42
+ """Reset the current image"""
43
+ self.orignal_image = None
44
+ self.inference_state = None
45
+
46
+ def predict(self, prompts, prompt_type, multimask=True):
47
+ """
48
+ Predict mask from prompts
49
+
50
+ Args:
51
+ prompts: Dictionary with point_coords, point_labels, mask_input
52
+ prompt_type: 'point' or 'both'
53
+ multimask: Whether to return multiple masks
54
+
55
+ Returns:
56
+ masks, scores, logits
57
+ """
58
+ # For SAM2, we need to handle prompts differently
59
+ # This is simplified - actual implementation will use video predictor
60
+
61
+ # Placeholder - actual SAM2 prediction would go here
62
+ # For now, return dummy values
63
+ h, w = self.orignal_image.shape[:2]
64
+ dummy_mask = np.zeros((h, w), dtype=bool)
65
+ dummy_score = np.array([1.0])
66
+ dummy_logit = np.zeros((h, w), dtype=np.float32)
67
+
68
+ return np.array([dummy_mask]), dummy_score, np.array([dummy_logit])
tools/interact_tools.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM2 Interaction Tools
3
+ Handles SAM2 mask generation with user clicks
4
+ """
5
+
6
+ import sys
7
+ sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2")
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+ from .base_segmenter import BaseSegmenter
12
+ from .painter import mask_painter, point_painter
13
+
14
+
15
+ mask_color = 3
16
+ mask_alpha = 0.7
17
+ contour_color = 1
18
+ contour_width = 5
19
+ point_color_ne = 8 # positive points
20
+ point_color_ps = 50 # negative points
21
+ point_alpha = 0.9
22
+ point_radius = 15
23
+
24
+
25
+ class SamControler:
26
+ def __init__(self, SAM_checkpoint, model_type, device):
27
+ """
28
+ Initialize SAM controller
29
+
30
+ Args:
31
+ SAM_checkpoint: Path to SAM2 checkpoint
32
+ model_type: SAM2 model config file
33
+ device: Device to run on
34
+ """
35
+ self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
36
+ self.device = device
37
+
38
+ def first_frame_click(self, image: np.ndarray, points: np.ndarray,
39
+ labels: np.ndarray, multimask=True, mask_color=3):
40
+ """
41
+ Generate mask from clicks on first frame
42
+
43
+ Args:
44
+ image: np.ndarray, (H, W, 3), RGB image
45
+ points: np.ndarray, (N, 2), [x, y] coordinates
46
+ labels: np.ndarray, (N,), 1 for positive, 0 for negative
47
+ multimask: bool, whether to generate multiple masks
48
+ mask_color: int, color ID for mask overlay
49
+
50
+ Returns:
51
+ mask: np.ndarray, (H, W), binary mask
52
+ logit: np.ndarray, (H, W), mask logits
53
+ painted_image: PIL.Image, visualization with mask and points
54
+ """
55
+ # Check if we have positive clicks
56
+ neg_flag = labels[-1]
57
+
58
+ if neg_flag == 1: # Has positive click
59
+ # First pass with points only
60
+ prompts = {
61
+ 'point_coords': points,
62
+ 'point_labels': labels,
63
+ }
64
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
65
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
66
+
67
+ # Refine with mask input
68
+ prompts = {
69
+ 'point_coords': points,
70
+ 'point_labels': labels,
71
+ 'mask_input': logit[None, :, :]
72
+ }
73
+ masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
74
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
75
+ else: # Only positive clicks
76
+ prompts = {
77
+ 'point_coords': points,
78
+ 'point_labels': labels,
79
+ }
80
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
81
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
82
+
83
+ # Paint mask on image
84
+ painted_image = mask_painter(
85
+ image,
86
+ mask.astype('uint8'),
87
+ mask_color,
88
+ mask_alpha,
89
+ contour_color,
90
+ contour_width
91
+ )
92
+
93
+ # Paint positive points (label > 0)
94
+ positive_points = np.squeeze(points[np.argwhere(labels > 0)], axis=1)
95
+ if len(positive_points) > 0:
96
+ painted_image = point_painter(
97
+ painted_image,
98
+ positive_points,
99
+ point_color_ne,
100
+ point_alpha,
101
+ point_radius,
102
+ contour_color,
103
+ contour_width
104
+ )
105
+
106
+ # Paint negative points (label < 1)
107
+ negative_points = np.squeeze(points[np.argwhere(labels < 1)], axis=1)
108
+ if len(negative_points) > 0:
109
+ painted_image = point_painter(
110
+ painted_image,
111
+ negative_points,
112
+ point_color_ps,
113
+ point_alpha,
114
+ point_radius,
115
+ contour_color,
116
+ contour_width
117
+ )
118
+
119
+ painted_image = Image.fromarray(painted_image)
120
+
121
+ return mask, logit, painted_image
tools/painter.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mask and point painting utilities
3
+ Adapted from MatAnyone demo
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+
11
+ def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7,
12
+ contour_color=1, contour_width=5):
13
+ """
14
+ Paint mask on image with transparency
15
+
16
+ Args:
17
+ input_image: np.ndarray, (H, W, 3)
18
+ input_mask: np.ndarray, (H, W), binary mask
19
+ mask_color: int, color ID for mask
20
+ mask_alpha: float, transparency
21
+ contour_color: int, color ID for contour
22
+ contour_width: int, width of contour
23
+
24
+ Returns:
25
+ painted_image: np.ndarray, (H, W, 3)
26
+ """
27
+ assert input_image.shape[:2] == input_mask.shape, "Image and mask must have same dimensions"
28
+
29
+ # Color palette
30
+ palette = np.array([
31
+ [0, 0, 0], # 0: black
32
+ [255, 0, 0], # 1: red
33
+ [0, 255, 0], # 2: green
34
+ [0, 0, 255], # 3: blue
35
+ [255, 255, 0], # 4: yellow
36
+ [255, 0, 255], # 5: magenta
37
+ [0, 255, 255], # 6: cyan
38
+ [128, 128, 128], # 7: gray
39
+ [255, 165, 0], # 8: orange
40
+ [128, 0, 128], # 9: purple
41
+ ])
42
+
43
+ mask_color_rgb = palette[mask_color % len(palette)]
44
+ contour_color_rgb = palette[contour_color % len(palette)]
45
+
46
+ # Create colored mask
47
+ painted_image = input_image.copy()
48
+ colored_mask = np.zeros_like(input_image)
49
+ colored_mask[input_mask > 0] = mask_color_rgb
50
+
51
+ # Blend with alpha
52
+ mask_region = input_mask > 0
53
+ painted_image[mask_region] = (
54
+ painted_image[mask_region] * (1 - mask_alpha) +
55
+ colored_mask[mask_region] * mask_alpha
56
+ ).astype(np.uint8)
57
+
58
+ # Draw contour
59
+ if contour_width > 0:
60
+ contours, _ = cv2.findContours(
61
+ input_mask.astype(np.uint8),
62
+ cv2.RETR_EXTERNAL,
63
+ cv2.CHAIN_APPROX_SIMPLE
64
+ )
65
+ cv2.drawContours(
66
+ painted_image,
67
+ contours,
68
+ -1,
69
+ contour_color_rgb.tolist(),
70
+ contour_width
71
+ )
72
+
73
+ return painted_image
74
+
75
+
76
+ def point_painter(input_image, input_points, point_color=8, point_alpha=0.9,
77
+ point_radius=15, contour_color=2, contour_width=3):
78
+ """
79
+ Paint points on image
80
+
81
+ Args:
82
+ input_image: np.ndarray, (H, W, 3)
83
+ input_points: np.ndarray, (N, 2), [x, y] coordinates
84
+ point_color: int, color ID for points
85
+ point_alpha: float, transparency
86
+ point_radius: int, radius of point circles
87
+ contour_color: int, color ID for contour
88
+ contour_width: int, width of contour
89
+
90
+ Returns:
91
+ painted_image: np.ndarray, (H, W, 3)
92
+ """
93
+ if len(input_points) == 0:
94
+ return input_image
95
+
96
+ palette = np.array([
97
+ [0, 0, 0], # 0: black
98
+ [255, 0, 0], # 1: red
99
+ [0, 255, 0], # 2: green
100
+ [0, 0, 255], # 3: blue
101
+ [255, 255, 0], # 4: yellow
102
+ [255, 0, 255], # 5: magenta
103
+ [0, 255, 255], # 6: cyan
104
+ [128, 128, 128], # 7: gray
105
+ [255, 165, 0], # 8: orange
106
+ [128, 0, 128], # 9: purple
107
+ ])
108
+
109
+ point_color_rgb = palette[point_color % len(palette)]
110
+ contour_color_rgb = palette[contour_color % len(palette)]
111
+
112
+ painted_image = input_image.copy()
113
+
114
+ for point in input_points:
115
+ x, y = int(point[0]), int(point[1])
116
+
117
+ # Draw filled circle with alpha blending
118
+ overlay = painted_image.copy()
119
+ cv2.circle(overlay, (x, y), point_radius, point_color_rgb.tolist(), -1)
120
+ cv2.addWeighted(overlay, point_alpha, painted_image, 1 - point_alpha, 0, painted_image)
121
+
122
+ # Draw contour
123
+ if contour_width > 0:
124
+ cv2.circle(painted_image, (x, y), point_radius, contour_color_rgb.tolist(), contour_width)
125
+
126
+ return painted_image
videomama_wrapper.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VideoMaMa Inference Wrapper
3
+ Handles video matting with mask conditioning
4
+ """
5
+
6
+ import sys
7
+ sys.path.append("../")
8
+ sys.path.append("../../")
9
+
10
+ import torch
11
+ import numpy as np
12
+ from PIL import Image
13
+ from pathlib import Path
14
+ from typing import List
15
+ import tqdm
16
+
17
+ from pipeline_svd_mask import VideoInferencePipeline
18
+
19
+
20
+ def videomama(pipeline, frames_np, mask_frames_np):
21
+ """
22
+ Run VideoMaMa inference on video frames with mask conditioning
23
+
24
+ Args:
25
+ pipeline: VideoInferencePipeline instance
26
+ frames_np: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
27
+ mask_frames_np: List of numpy arrays, [(H,W)]*n, uint8 grayscale masks
28
+
29
+ Returns:
30
+ output_frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB outputs
31
+ """
32
+ # Convert numpy arrays to PIL Images
33
+ frames_pil = [Image.fromarray(f) for f in frames_np]
34
+ mask_frames_pil = [Image.fromarray(m, mode='L') for m in mask_frames_np]
35
+
36
+ # Resize to model input size
37
+ target_width, target_height = 1024, 576
38
+ frames_resized = [f.resize((target_width, target_height), Image.Resampling.BILINEAR)
39
+ for f in frames_pil]
40
+ masks_resized = [m.resize((target_width, target_height), Image.Resampling.BILINEAR)
41
+ for m in mask_frames_pil]
42
+
43
+ # Run inference
44
+ print(f"Running VideoMaMa inference on {len(frames_resized)} frames...")
45
+ output_frames_pil = pipeline.run(
46
+ cond_frames=frames_resized,
47
+ mask_frames=masks_resized,
48
+ seed=42,
49
+ mask_cond_mode="vae"
50
+ )
51
+
52
+ # Resize back to original resolution
53
+ original_size = frames_pil[0].size
54
+ output_frames_resized = [f.resize(original_size, Image.Resampling.BILINEAR)
55
+ for f in output_frames_pil]
56
+
57
+ # Convert back to numpy arrays
58
+ output_frames_np = [np.array(f) for f in output_frames_resized]
59
+
60
+ return output_frames_np
61
+
62
+
63
+ def load_videomama_pipeline(device="cuda"):
64
+ """
65
+ Load VideoMaMa pipeline with pretrained weights
66
+
67
+ Args:
68
+ device: Device to run on
69
+
70
+ Returns:
71
+ VideoInferencePipeline instance
72
+ """
73
+ # Local paths for testing
74
+ base_model_path = "checkpoints/stable-video-diffusion-img2vid-xt"
75
+ unet_checkpoint_path = "checkpoints/VideoMaMa"
76
+
77
+ print(f"Loading VideoMaMa pipeline from {unet_checkpoint_path}...")
78
+
79
+ pipeline = VideoInferencePipeline(
80
+ base_model_path=base_model_path,
81
+ unet_checkpoint_path=unet_checkpoint_path,
82
+ weight_dtype=torch.float16,
83
+ device=device
84
+ )
85
+
86
+ print("VideoMaMa pipeline loaded successfully!")
87
+
88
+ return pipeline
videomama_wrapper_hf.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VideoMaMa Inference Wrapper - Hugging Face Space Version
3
+ Handles video matting with mask conditioning
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ from pathlib import Path
9
+
10
+ # Add parent directories to path for imports
11
+ sys.path.append(str(Path(__file__).parent))
12
+ sys.path.append(str(Path(__file__).parent.parent))
13
+
14
+ import torch
15
+ import numpy as np
16
+ from PIL import Image
17
+ from typing import List
18
+
19
+ from pipeline_svd_mask import VideoInferencePipeline
20
+
21
+
22
+ def videomama(pipeline, frames_np, mask_frames_np):
23
+ """
24
+ Run VideoMaMa inference on video frames with mask conditioning
25
+
26
+ Args:
27
+ pipeline: VideoInferencePipeline instance
28
+ frames_np: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
29
+ mask_frames_np: List of numpy arrays, [(H,W)]*n, uint8 grayscale masks
30
+
31
+ Returns:
32
+ output_frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB outputs
33
+ """
34
+ # Convert numpy arrays to PIL Images
35
+ frames_pil = [Image.fromarray(f) for f in frames_np]
36
+ mask_frames_pil = [Image.fromarray(m, mode='L') for m in mask_frames_np]
37
+
38
+ # Resize to model input size
39
+ target_width, target_height = 1024, 576
40
+ frames_resized = [f.resize((target_width, target_height), Image.Resampling.BILINEAR)
41
+ for f in frames_pil]
42
+ masks_resized = [m.resize((target_width, target_height), Image.Resampling.BILINEAR)
43
+ for m in mask_frames_pil]
44
+
45
+ # Run inference
46
+ print(f"Running VideoMaMa inference on {len(frames_resized)} frames...")
47
+ output_frames_pil = pipeline.run(
48
+ cond_frames=frames_resized,
49
+ mask_frames=masks_resized,
50
+ seed=42,
51
+ mask_cond_mode="vae"
52
+ )
53
+
54
+ # Resize back to original resolution
55
+ original_size = frames_pil[0].size
56
+ output_frames_resized = [f.resize(original_size, Image.Resampling.BILINEAR)
57
+ for f in output_frames_pil]
58
+
59
+ # Convert back to numpy arrays
60
+ output_frames_np = [np.array(f) for f in output_frames_resized]
61
+
62
+ return output_frames_np
63
+
64
+
65
+ def load_videomama_pipeline(base_model_path=None, unet_checkpoint_path=None, device="cuda"):
66
+ """
67
+ Load VideoMaMa pipeline with pretrained weights
68
+
69
+ Args:
70
+ base_model_path: Path to SVD base model (if None, uses default)
71
+ unet_checkpoint_path: Path to VideoMaMa UNet checkpoint (if None, uses default)
72
+ device: Device to run on
73
+
74
+ Returns:
75
+ VideoInferencePipeline instance
76
+ """
77
+ # Use provided paths or defaults
78
+ if base_model_path is None:
79
+ base_model_path = "checkpoints/stable-video-diffusion-img2vid-xt"
80
+
81
+ if unet_checkpoint_path is None:
82
+ unet_checkpoint_path = "checkpoints/videomama"
83
+
84
+ # Check if paths exist
85
+ if not os.path.exists(base_model_path):
86
+ raise FileNotFoundError(
87
+ f"SVD base model not found at {base_model_path}. "
88
+ f"Please ensure models are downloaded correctly."
89
+ )
90
+
91
+ if not os.path.exists(unet_checkpoint_path):
92
+ raise FileNotFoundError(
93
+ f"VideoMaMa checkpoint not found at {unet_checkpoint_path}. "
94
+ f"Please upload your VideoMaMa model to Hugging Face Hub and update the download logic."
95
+ )
96
+
97
+ print(f"Loading VideoMaMa pipeline...")
98
+ print(f" Base model: {base_model_path}")
99
+ print(f" UNet checkpoint: {unet_checkpoint_path}")
100
+
101
+ pipeline = VideoInferencePipeline(
102
+ base_model_path=base_model_path,
103
+ unet_checkpoint_path=unet_checkpoint_path,
104
+ weight_dtype=torch.float16,
105
+ device=device
106
+ )
107
+
108
+ print("VideoMaMa pipeline loaded successfully!")
109
+
110
+ return pipeline