MogensR commited on
Commit
7cb91e0
Β·
verified Β·
1 Parent(s): eee126e

Create video_pipeline.py

Browse files
Files changed (1) hide show
  1. video_pipeline.py +469 -0
video_pipeline.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Video Processing Pipeline
4
+ Two-stage processing: SAM2+MatAnyone β†’ Transparent β†’ Composite
5
+ Includes temporal smoothing to eliminate jitter/shaking
6
+ """
7
+
8
+ import os
9
+ import time
10
+ import tempfile
11
+ import shutil
12
+ import gc
13
+ import logging
14
+ from pathlib import Path
15
+ import cv2
16
+ import numpy as np
17
+ from collections import deque
18
+ import torch
19
+ import streamlit as st
20
+
21
+ from models import (
22
+ load_sam2_predictor,
23
+ load_matanyone_processor,
24
+ torch_memory_manager,
25
+ get_memory_usage,
26
+ clear_model_cache
27
+ )
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # Persistent temp dir
32
+ TMP_DIR = Path("tmp")
33
+ TMP_DIR.mkdir(parents=True, exist_ok=True)
34
+
35
+ # ============================================================================
36
+ # SAM2 Mask Generation
37
+ # ============================================================================
38
+
39
def generate_mask_from_video_first_frame(video_path, sam2_predictor):
    """
    Generate a segmentation mask for the first frame of *video_path* using SAM2.

    The mask seeds MatAnyone's temporal propagation, so it is returned at the
    ORIGINAL video resolution even when SAM2 inference runs on a downscaled
    copy of the frame.

    Args:
        video_path: Path to the input video file (any format OpenCV can read).
        sam2_predictor: Loaded SAM2 image predictor exposing ``set_image``
            and ``predict``.

    Returns:
        A uint8 numpy array (values 0/255) holding the best-scoring mask,
        or None if the frame could not be read or prediction failed.
    """
    try:
        with torch_memory_manager():
            cap = cv2.VideoCapture(video_path)
            ret, frame = cap.read()
            cap.release()

            if not ret:
                logger.error("Failed to read video frame")
                return None

            # Remember the native size so the mask can be restored to it.
            orig_h, orig_w = frame.shape[:2]

            # Downscale large frames before inference to save memory.
            max_size = 1080
            if max(orig_h, orig_w) > max_size:
                scale = max_size / max(orig_h, orig_w)
                new_w, new_h = int(orig_w * scale), int(orig_h * scale)
                frame = cv2.resize(frame, (new_w, new_h))
                logger.info(f"Resized frame from {orig_w}x{orig_h} to {new_w}x{new_h}")

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # FIX: autocast was hard-coded to "cuda", which fails on CPU-only
            # hosts; select the device type that is actually available
            # (bfloat16 autocast is supported on CPU as well).
            device_type = "cuda" if torch.cuda.is_available() else "cpu"
            with torch.inference_mode(), torch.autocast(device_type, dtype=torch.bfloat16):
                sam2_predictor.set_image(frame_rgb)

                # Default prompt: a single positive click at the frame center.
                h, w = frame_rgb.shape[:2]
                center_point = np.array([[w // 2, h // 2]], dtype=np.float32)
                center_label = np.array([1], dtype=np.int32)

                masks, scores, logits = sam2_predictor.predict(
                    point_coords=center_point,
                    point_labels=center_label,
                    multimask_output=True
                )

            # Keep the highest-scoring candidate mask.
            best_mask = masks[np.argmax(scores)].astype(np.uint8) * 255

            # FIX: if inference ran on a downscaled frame, restore the mask to
            # the original resolution so it lines up with the full-size video
            # that MatAnyone will propagate over. Nearest-neighbour keeps the
            # mask binary.
            if best_mask.shape[:2] != (orig_h, orig_w):
                best_mask = cv2.resize(best_mask, (orig_w, orig_h),
                                       interpolation=cv2.INTER_NEAREST)
            return best_mask

    except Exception as e:
        logger.error(f"Failed to generate mask: {e}")
        return None
87
+
88
+ # ============================================================================
89
+ # TEMPORAL SMOOTHING - Fixes the shaking issue
90
+ # ============================================================================
91
+
92
def smooth_alpha_video(alpha_video_path, output_path, window_size=5):
    """
    Apply temporal smoothing to an alpha video to reduce jitter/shaking.

    Each output frame is the mean of the current frame and up to
    ``window_size - 1`` preceding frames (a trailing moving average), which
    suppresses the frame-to-frame instability that causes visible shaking.
    NOTE(review): a trailing window introduces a slight temporal lag of the
    matte relative to the subject — acceptable for small windows, but worth
    confirming if fast motion looks "late".

    Args:
        alpha_video_path: Path to MatAnyone's alpha output video.
        output_path: Path for the smoothed alpha video (written grayscale).
        window_size: Number of frames to average (default 5).
            - 3: Minimal smoothing, fastest
            - 5: Balanced (recommended)
            - 7: Maximum smoothing, may blur fast motion

    Returns:
        ``output_path`` on success; the ORIGINAL ``alpha_video_path`` if
        smoothing fails, so callers always get a usable video.
    """
    logger.info(f"🎬 Applying temporal smoothing to reduce jitter (window={window_size})")

    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(alpha_video_path)
        # FIX: bail out early instead of writing a 0x0 video when the input
        # cannot be opened.
        if not cap.isOpened():
            logger.error(f"Temporal smoothing: cannot open {alpha_video_path}")
            return alpha_video_path

        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height), isColor=False)

        # Rolling buffer for temporal averaging.
        frame_buffer = deque(maxlen=window_size)

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Alpha should be single-channel; collapse BGR reads to gray.
            if len(frame.shape) == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

            # Accumulate in float32 to avoid uint8 rounding in the mean.
            frame_buffer.append(frame.astype(np.float32))

            # Average all frames currently in the window.
            smoothed = np.mean(frame_buffer, axis=0).astype(np.uint8)

            out.write(smoothed)
            frame_count += 1

            # Periodic memory cleanup for long videos.
            if frame_count % 30 == 0:
                gc.collect()

        logger.info(f"✅ Temporal smoothing complete: {frame_count} frames processed")
        return output_path

    except Exception as e:
        logger.error(f"Temporal smoothing failed: {e}")
        # Return original path if smoothing fails.
        return alpha_video_path
    finally:
        # FIX: release capture/writer even when an exception fires mid-loop
        # (the originals leaked on the error path).
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
158
+
159
+ # ============================================================================
160
+ # Transparent Video Creation
161
+ # ============================================================================
162
+
163
def create_transparent_mov(foreground_path, alpha_path, temp_dir):
    """
    Create a .mov file with an alpha channel from foreground + alpha videos.

    Uses the PNG fourcc so the writer keeps the 4th (alpha) plane.
    NOTE(review): OpenCV's ability to actually encode BGRA with 'png ' depends
    on the local build/codec — verify the output really carries alpha.

    Args:
        foreground_path: Path to the BGR foreground video.
        alpha_path: Path to the (grayscale) alpha video.
        temp_dir: ``Path`` of a working directory for the output file.

    Returns:
        Path to the created ``transparent.mov``, or None on failure.
    """
    fg_cap = None
    alpha_cap = None
    out = None
    try:
        output_path = str(temp_dir / "transparent.mov")

        # Read videos.
        fg_cap = cv2.VideoCapture(foreground_path)
        alpha_cap = cv2.VideoCapture(alpha_path)

        # Get video properties from the foreground stream.
        fps = int(fg_cap.get(cv2.CAP_PROP_FPS)) or 30
        width = int(fg_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(fg_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Use PNG codec for alpha channel support.
        fourcc = cv2.VideoWriter_fourcc(*'png ')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height), True)

        frame_count = 0
        while True:
            ret_fg, fg_frame = fg_cap.read()
            ret_alpha, alpha_frame = alpha_cap.read()

            if not ret_fg or not ret_alpha:
                break

            # Collapse a 3-channel alpha read to a single plane.
            if len(alpha_frame.shape) == 3:
                alpha_frame = cv2.cvtColor(alpha_frame, cv2.COLOR_BGR2GRAY)

            # Guard against a resolution mismatch between the two streams.
            if alpha_frame.shape[:2] != fg_frame.shape[:2]:
                alpha_frame = cv2.resize(alpha_frame, (fg_frame.shape[1], fg_frame.shape[0]))

            # FIX: fg_frame is already BGR (OpenCV's native order), so stack
            # the alpha plane directly to get BGRA. The previous code placed
            # BGR data into slots labelled RGB and then ran RGBA2BGRA, which
            # swapped the red and blue channels in the output.
            bgra_frame = np.dstack((fg_frame, alpha_frame))
            out.write(bgra_frame)

            frame_count += 1
            if frame_count % 10 == 0:
                gc.collect()

        logger.info(f"Created transparent MOV: {frame_count} frames")
        return output_path if os.path.exists(output_path) else None

    except Exception as e:
        logger.error(f"Failed to create transparent MOV: {e}")
        return None
    finally:
        # FIX: release all handles even on the error path (originals leaked).
        if fg_cap is not None:
            fg_cap.release()
        if alpha_cap is not None:
            alpha_cap.release()
        if out is not None:
            out.release()
219
+
220
+ # ============================================================================
221
+ # STAGE 1: Create Transparent Video (with smoothing fix)
222
+ # ============================================================================
223
+
224
def stage1_create_transparent_video(input_file):
    """
    STAGE 1: Create a transparent video using SAM2 + MatAnyone.

    Pipeline:
        1. Generate a first-frame mask with SAM2.
        2. Process the video with MatAnyone (temporal propagation).
        3. Apply temporal smoothing to the alpha channel (fixes shaking).
        4. Create a transparent .mov file and copy it to a persistent path.

    Args:
        input_file: Uploaded video file object exposing ``getvalue()``
            (e.g. a Streamlit ``UploadedFile``).

    Returns:
        Path (str) to the persisted transparent .mov, or None on any failure.
        Errors are surfaced to the UI via ``st.error``/``st.warning``.
    """
    logger.info("🎬 Starting Stage 1: Create transparent video")

    # Pre-flight memory check — drop cached models if GPU memory is tight.
    memory_info = get_memory_usage()
    if memory_info.get('gpu_free', 0) < 2.0:
        st.warning("⚠️ Low GPU memory detected. Processing may be slower.")
        clear_model_cache()

    try:
        progress_bar = st.progress(0)
        status_text = st.empty()

        def update_progress(progress, message):
            # Clamp to [0, 1] — st.progress raises outside that range.
            progress = max(0, min(1, progress))
            progress_bar.progress(progress)
            gpu_mem = get_memory_usage().get('gpu_allocated', 0)
            status_text.text(f"Stage 1: {message} | GPU: {gpu_mem:.1f}GB")
            logger.info(f"Stage 1 [{progress:.0%}]: {message}")

        # Load models.
        update_progress(0.05, "Loading SAM2 model...")
        sam2_predictor = load_sam2_predictor()
        if sam2_predictor is None:
            st.error("❌ Failed to load SAM2 model")
            return None

        update_progress(0.1, "Loading MatAnyone model...")
        matanyone_processor = load_matanyone_processor()
        if matanyone_processor is None:
            st.error("❌ Failed to load MatAnyone model")
            return None

        # Process video inside a scratch directory that is auto-removed.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir = Path(temp_dir)
            input_path = str(temp_dir / "input.mp4")

            # Save the uploaded bytes to disk for OpenCV/MatAnyone.
            with open(input_path, "wb") as f:
                f.write(input_file.getvalue())

            update_progress(0.2, "Generating first-frame segmentation mask...")

            # Generate seed mask using SAM2.
            with torch_memory_manager():
                mask = generate_mask_from_video_first_frame(input_path, sam2_predictor)

            if mask is None:
                st.error("❌ Failed to generate mask")
                return None

            mask_path = str(temp_dir / "mask.png")
            cv2.imwrite(mask_path, mask)
            logger.info(f"First-frame mask saved: {mask_path}")

            update_progress(0.4, "Running MatAnyone temporal propagation...")

            # Process with MatAnyone.
            try:
                with torch_memory_manager():
                    foreground_path, alpha_path = matanyone_processor.process_video(
                        input_path=input_path,
                        mask_path=mask_path,
                        output_path=str(temp_dir),
                        max_size=720  # Limit resolution for memory efficiency
                    )

                logger.info(f"MatAnyone complete - Foreground: {foreground_path}, Alpha: {alpha_path}")

                # Apply temporal smoothing to the alpha channel (anti-jitter).
                update_progress(0.6, "Applying temporal smoothing to eliminate jitter...")

                smoothed_alpha_path = str(temp_dir / "alpha_smoothed.mp4")
                alpha_path = smooth_alpha_video(alpha_path, smoothed_alpha_path, window_size=5)

                logger.info("✅ Temporal smoothing applied - shaking should be eliminated")

                update_progress(0.8, "Creating transparent .mov file...")

                # Create transparent video.
                transparent_path = create_transparent_mov(foreground_path, alpha_path, temp_dir)

                if transparent_path and os.path.exists(transparent_path):
                    # Copy out of the TemporaryDirectory before it is deleted.
                    persist_path = TMP_DIR / "transparent_video.mov"
                    shutil.copyfile(transparent_path, persist_path)

                    update_progress(1.0, "✅ Transparent video created successfully!")
                    time.sleep(0.5)
                    return str(persist_path)
                else:
                    st.error("❌ Failed to create transparent video")
                    return None

            except Exception as e:
                logger.error(f"MatAnyone processing failed: {e}", exc_info=True)
                st.error(f"❌ MatAnyone processing failed: {e}")
                return None

    except Exception as e:
        logger.error(f"Stage 1 error: {e}", exc_info=True)
        st.error(f"❌ Stage 1 failed: {e}")

        # Show memory info for debugging; best-effort only.
        # FIX: narrowed the bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit) to `except Exception`.
        try:
            memory_info = get_memory_usage()
            st.info(f"Memory at failure - GPU: {memory_info.get('gpu_allocated', 0):.1f}GB, "
                    f"RAM: {memory_info.get('ram_used', 0):.1f}GB")
        except Exception:
            pass

        return None

    finally:
        logger.info("Stage 1 cleanup...")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
352
+
353
+ # ============================================================================
354
+ # STAGE 2: Composite with Background
355
+ # ============================================================================
356
+
357
def stage2_composite_background(transparent_video_path, background, bg_type):
    """
    STAGE 2: Composite the transparent video with a new background.

    Fast compositing that can be repeated with different backgrounds without
    re-running Stage 1.

    NOTE(review): OpenCV's VideoCapture commonly decodes to 3-channel BGR and
    drops alpha, in which case the full-opacity fallback below is taken —
    verify that 4-channel frames actually arrive from the .mov on this build.

    Args:
        transparent_video_path: Path to the Stage 1 transparent .mov.
        background: PIL image (when ``bg_type == "image"``) or None.
        bg_type: "image", "color" (hex in ``st.session_state.bg_color``),
            or anything else for the default green screen.

    Returns:
        Path (str) to the persisted composited mp4, or None on failure.
    """
    logger.info("🎬 Starting Stage 2: Composite with background")

    cap = None
    out = None
    try:
        progress_bar = st.progress(0)
        status_text = st.empty()

        def update_progress(progress, message):
            progress = max(0, min(1, progress))
            progress_bar.progress(progress)
            status_text.text(f"Stage 2: {message}")

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir = Path(temp_dir)

            update_progress(0.2, "Loading transparent video...")

            # Read transparent video.
            cap = cv2.VideoCapture(transparent_video_path)
            fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            update_progress(0.4, "Preparing background...")

            # Prepare a (height, width, 3) BGR background once, up front.
            if bg_type == "image" and background is not None:
                bg_array = np.array(background)
                # FIX: a grayscale background image previously reached the
                # 3-channel composite un-converted and broke broadcasting.
                if bg_array.ndim == 2:
                    bg_array = cv2.cvtColor(bg_array, cv2.COLOR_GRAY2BGR)
                elif bg_array.shape[2] == 3:
                    bg_array = cv2.cvtColor(bg_array, cv2.COLOR_RGB2BGR)
                elif bg_array.shape[2] == 4:
                    bg_array = cv2.cvtColor(bg_array, cv2.COLOR_RGBA2BGR)
                bg_resized = cv2.resize(bg_array, (width, height))
            elif bg_type == "color":
                # Parse hex color "#rrggbb" into a solid BGR fill.
                color_hex = st.session_state.bg_color.lstrip('#')
                r = int(color_hex[0:2], 16)
                g = int(color_hex[2:4], 16)
                b = int(color_hex[4:6], 16)
                bg_resized = np.full((height, width, 3), (b, g, r), dtype=np.uint8)
            else:
                # Default green screen.
                bg_resized = np.full((height, width, 3), (0, 255, 0), dtype=np.uint8)

            # Create output video.
            output_path = str(temp_dir / "final_output.mp4")
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

            update_progress(0.6, "Compositing frames...")

            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Extract the alpha channel (BGRA frames).
                # FIX: `frame.shape[2]` raised IndexError on 2-D grayscale
                # frames; check ndim first.
                if frame.ndim == 3 and frame.shape[2] == 4:
                    bgr_frame = frame[:, :, :3]
                    alpha_channel = frame[:, :, 3]
                else:
                    # Fallback: assume full opacity (also promotes grayscale
                    # frames to 3 channels so the blend below broadcasts).
                    if frame.ndim == 2:
                        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
                    bgr_frame = frame
                    alpha_channel = np.full((height, width), 255, dtype=np.uint8)

                # Normalize alpha to 0-1 and give it a channel axis.
                alpha_norm = alpha_channel.astype(np.float32) / 255.0
                alpha_norm = np.expand_dims(alpha_norm, axis=2)

                # Composite: result = foreground * alpha + background * (1 - alpha)
                fg_float = bgr_frame.astype(np.float32)
                bg_float = bg_resized.astype(np.float32)

                result = fg_float * alpha_norm + bg_float * (1 - alpha_norm)
                result = result.astype(np.uint8)

                out.write(result)
                frame_count += 1

                # Update progress every few frames.
                if total_frames > 0 and frame_count % 5 == 0:
                    progress = 0.6 + 0.3 * (frame_count / total_frames)
                    update_progress(progress, f"Compositing frame {frame_count}/{total_frames}")

                if frame_count % 10 == 0:
                    gc.collect()

            cap.release()
            out.release()

            logger.info(f"Compositing complete: {frame_count} frames")

            if os.path.exists(output_path):
                # Copy out of the TemporaryDirectory before it is deleted.
                persist_path = TMP_DIR / "final_video.mp4"
                shutil.copyfile(output_path, persist_path)

                update_progress(1.0, "✅ Compositing complete!")
                time.sleep(0.5)
                return str(persist_path)
            else:
                return None

    except Exception as e:
        logger.error(f"Stage 2 error: {e}", exc_info=True)
        st.error(f"❌ Stage 2 failed: {e}")
        return None
    finally:
        # FIX: release handles even when an exception fires mid-loop
        # (release() on an already-released capture is a harmless no-op).
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()