MogensR commited on
Commit
3d9fd7c
·
verified ·
1 Parent(s): 6b75cf3

Update pipeline/video_pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline/video_pipeline.py +187 -14
pipeline/video_pipeline.py CHANGED
@@ -1,8 +1,8 @@
1
  #!/usr/bin/env python3
2
  """
3
- Video Processing Pipeline
4
  Two-stage processing: SAM2+MatAnyone → Transparent → Composite
5
- Includes temporal smoothing to eliminate jitter/shaking
6
  """
7
 
8
  import os
@@ -11,6 +11,8 @@
11
  import shutil
12
  import gc
13
  import logging
 
 
14
  from pathlib import Path
15
  import cv2
16
  import numpy as np
@@ -28,13 +30,141 @@
28
 
29
  logger = logging.getLogger(__name__)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # Persistent temp dir
32
  TMP_DIR = Path("tmp")
33
  TMP_DIR.mkdir(parents=True, exist_ok=True)
34
 
35
- # ============================================================================
36
  # SAM2 Mask Generation
37
- # ============================================================================
38
 
39
  def generate_mask_from_video_first_frame(video_path, sam2_predictor):
40
  """
@@ -62,7 +192,7 @@ def generate_mask_from_video_first_frame(video_path, sam2_predictor):
62
 
63
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
64
 
65
- # Use SAM2 to generate mask
66
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
67
  sam2_predictor.set_image(frame_rgb)
68
 
@@ -85,9 +215,9 @@ def generate_mask_from_video_first_frame(video_path, sam2_predictor):
85
  logger.error(f"Failed to generate mask: {e}", exc_info=True)
86
  return None
87
 
88
- # ============================================================================
89
  # TEMPORAL SMOOTHING - Fixes the shaking issue
90
- # ============================================================================
91
 
92
  def smooth_alpha_video(alpha_video_path, output_path, window_size=5):
93
  """
@@ -156,9 +286,9 @@ def smooth_alpha_video(alpha_video_path, output_path, window_size=5):
156
  # Return original path if smoothing fails
157
  return alpha_video_path
158
 
159
- # ============================================================================
160
  # Transparent Video Creation
161
- # ============================================================================
162
 
163
  def create_transparent_mov(foreground_path, alpha_path, temp_dir):
164
  """
@@ -217,9 +347,9 @@ def create_transparent_mov(foreground_path, alpha_path, temp_dir):
217
  logger.error(f"Failed to create transparent MOV: {e}")
218
  return None
219
 
220
- # ============================================================================
221
- # STAGE 1: Create Transparent Video (with smoothing fix)
222
- # ============================================================================
223
 
224
  def stage1_create_transparent_video(input_file):
225
  """
@@ -230,9 +360,27 @@ def stage1_create_transparent_video(input_file):
230
  2. Process video with MatAnyone (temporal propagation)
231
  3. Apply temporal smoothing to alpha channel (FIXES SHAKING)
232
  4. Create transparent .mov file
 
 
 
 
 
 
233
  """
234
  logger.info("Starting Stage 1: Create transparent video")
235
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  # Check memory
237
  memory_info = get_memory_usage()
238
  if memory_info.get('gpu_free', 0) < 2.0:
@@ -263,8 +411,17 @@ def update_progress(progress, message):
263
 
264
  if sam2_predictor is None:
265
  st.error("Failed to load SAM2 model")
 
266
  return None
267
 
 
 
 
 
 
 
 
 
268
  update_progress(0.1, "Loading MatAnyone model...")
269
  matanyone_result = load_matanyone_processor()
270
 
@@ -277,8 +434,17 @@ def update_progress(progress, message):
277
 
278
  if matanyone_processor is None:
279
  st.error("Failed to load MatAnyone model")
 
280
  return None
281
 
 
 
 
 
 
 
 
 
282
  # Process video
283
  with tempfile.TemporaryDirectory() as temp_dir:
284
  temp_dir = Path(temp_dir)
@@ -296,6 +462,7 @@ def update_progress(progress, message):
296
 
297
  if mask is None:
298
  st.error("Failed to generate mask")
 
299
  return None
300
 
301
  mask_path = str(temp_dir / "mask.png")
@@ -336,14 +503,18 @@ def update_progress(progress, message):
336
 
337
  update_progress(1.0, "Transparent video created successfully")
338
  time.sleep(0.5)
 
 
339
  return str(persist_path)
340
  else:
341
  st.error("Failed to create transparent video")
 
342
  return None
343
 
344
  except Exception as e:
345
  logger.error(f"MatAnyone processing failed: {e}", exc_info=True)
346
  st.error(f"MatAnyone processing failed: {e}")
 
347
  return None
348
 
349
  except Exception as e:
@@ -358,17 +529,19 @@ def update_progress(progress, message):
358
  except:
359
  pass
360
 
 
361
  return None
362
 
363
  finally:
 
364
  logger.info("Stage 1 cleanup...")
365
  if torch.cuda.is_available():
366
  torch.cuda.empty_cache()
367
  gc.collect()
368
 
369
- # ============================================================================
370
  # STAGE 2: Composite with Background
371
- # ============================================================================
372
 
373
  def stage2_composite_background(transparent_video_path, background, bg_type):
374
  """
 
1
  #!/usr/bin/env python3
2
  """
3
+ Video Processing Pipeline - T4 Optimized
4
  Two-stage processing: SAM2+MatAnyone → Transparent → Composite
5
+ Includes temporal smoothing + T4 memory optimizations
6
  """
7
 
8
  import os
 
11
  import shutil
12
  import gc
13
  import logging
14
+ import subprocess
15
+ import threading
16
  from pathlib import Path
17
  import cv2
18
  import numpy as np
 
30
 
31
  logger = logging.getLogger(__name__)
32
 
33
+ # ==================================================================================
34
+ # T4 OPTIMIZATIONS - Environment Setup
35
+ # ==================================================================================
36
+
37
+ def setup_t4_environment():
38
+ """Configure environment for Tesla T4 GPU"""
39
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF",
40
+ "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7")
41
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
42
+ os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
43
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
44
+
45
+ torch.set_grad_enabled(False)
46
+ try:
47
+ torch.backends.cudnn.benchmark = True
48
+ torch.backends.cuda.matmul.allow_tf32 = True
49
+ torch.backends.cudnn.allow_tf32 = True
50
+ torch.set_float32_matmul_precision("high")
51
+ except Exception:
52
+ pass
53
+
54
+ if torch.cuda.is_available():
55
+ try:
56
+ frac = float(os.getenv("CUDA_MEMORY_FRACTION", "0.88"))
57
+ torch.cuda.set_per_process_memory_fraction(frac)
58
+ logger.info(f"CUDA memory fraction = {frac:.2f}")
59
+ except Exception as e:
60
+ logger.warning(f"Could not set CUDA memory fraction: {e}")
61
+
62
+ # Initialize T4 optimizations at module load
63
+ setup_t4_environment()
64
+
65
+ # ==================================================================================
66
+ # HEARTBEAT MONITOR - Prevents HuggingFace Space Timeout
67
+ # ==================================================================================
68
+
69
+ def heartbeat_monitor(running_flag: dict, interval: float = 8.0):
70
+ """Periodic heartbeat to prevent Space watchdog from killing process"""
71
+ while running_flag.get("running", False):
72
+ print(f"[HEARTBEAT] t={int(time.time())}", flush=True)
73
+ time.sleep(interval)
74
+
75
+ # ==================================================================================
76
+ # VRAM ADAPTIVE CONTROLLER - Dynamic Memory Management
77
+ # ==================================================================================
78
+
79
+ class VRAMAdaptiveController:
80
+ """Adjusts memory usage based on available VRAM"""
81
+ def __init__(self):
82
+ self.memory_window = int(os.getenv("SAM2_WINDOW", "96"))
83
+ self.cleanup_every = 20
84
+
85
+ def adapt(self):
86
+ """Adjust parameters based on current VRAM availability"""
87
+ if not torch.cuda.is_available():
88
+ return
89
+
90
+ free, total = torch.cuda.mem_get_info()
91
+ free_gb = free / (1024 ** 3)
92
+
93
+ # Tighten if low on memory
94
+ if free_gb < 1.6:
95
+ self.memory_window = max(48, self.memory_window - 8)
96
+ self.cleanup_every = max(12, self.cleanup_every - 2)
97
+ logger.warning(f"Low VRAM ({free_gb:.2f}GB) → window={self.memory_window}")
98
+ # Relax if plenty of memory
99
+ elif free_gb > 3.0:
100
+ self.memory_window = min(128, self.memory_window + 4)
101
+ self.cleanup_every = min(40, self.cleanup_every + 2)
102
+
103
+ def should_cleanup(self, frame_count: int) -> bool:
104
+ """Check if it's time for memory cleanup"""
105
+ return frame_count % self.cleanup_every == 0
106
+
107
+ # ==================================================================================
108
+ # MEMORY PRUNING - SAM2 State Management
109
+ # ==================================================================================
110
+
111
+ def prune_sam2_state(predictor, state, keep: int):
112
+ """Prune SAM2 temporal cache to bounded window"""
113
+ try:
114
+ if hasattr(predictor, "prune_state"):
115
+ predictor.prune_state(state, keep=keep)
116
+ elif hasattr(state, "prune") and callable(getattr(state, "prune")):
117
+ state.prune(keep=keep)
118
+ except Exception as e:
119
+ logger.debug(f"SAM2 prune warning: {e}")
120
+
121
+ # ==================================================================================
122
+ # FP16 OPTIMIZATION - Model Loading
123
+ # ==================================================================================
124
+
125
+ def optimize_model_for_t4(model, device):
126
+ """Apply FP16 and channels_last optimizations for T4"""
127
+ try:
128
+ if device.type == "cuda":
129
+ model = model.half().to(device)
130
+ model = model.to(memory_format=torch.channels_last)
131
+ logger.info("Applied FP16 + channels_last optimization")
132
+ return model
133
+ except Exception as e:
134
+ logger.warning(f"FP16 optimization warning: {e}")
135
+ return model
136
+
137
+ # ==================================================================================
138
+ # AUDIO MUXING - Safer FFmpeg Audio Restoration
139
+ # ==================================================================================
140
+
141
+ def mux_audio(video_no_audio: str, source_with_audio: str, output: str) -> bool:
142
+ """Restore audio from original video using FFmpeg"""
143
+ cmd = [
144
+ "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
145
+ "-i", video_no_audio,
146
+ "-i", source_with_audio,
147
+ "-map", "0:v:0", "-map", "1:a:0?",
148
+ "-c:v", "copy", "-c:a", "aac", "-shortest",
149
+ output
150
+ ]
151
+ try:
152
+ result = subprocess.run(cmd, capture_output=True, text=True, timeout=180)
153
+ if result.returncode != 0:
154
+ logger.warning(f"Audio mux failed: {result.stderr.strip()}")
155
+ return False
156
+ return True
157
+ except Exception as e:
158
+ logger.warning(f"Audio mux error: {e}")
159
+ return False
160
+
161
  # Persistent temp dir
162
  TMP_DIR = Path("tmp")
163
  TMP_DIR.mkdir(parents=True, exist_ok=True)
164
 
165
+ # ==================================================================================
166
  # SAM2 Mask Generation
167
+ # ==================================================================================
168
 
169
  def generate_mask_from_video_first_frame(video_path, sam2_predictor):
170
  """
 
192
 
193
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
194
 
195
+ # Use SAM2 to generate mask with FP16 optimization
196
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
197
  sam2_predictor.set_image(frame_rgb)
198
 
 
215
  logger.error(f"Failed to generate mask: {e}", exc_info=True)
216
  return None
217
 
218
+ # ==================================================================================
219
  # TEMPORAL SMOOTHING - Fixes the shaking issue
220
+ # ==================================================================================
221
 
222
  def smooth_alpha_video(alpha_video_path, output_path, window_size=5):
223
  """
 
286
  # Return original path if smoothing fails
287
  return alpha_video_path
288
 
289
+ # ==================================================================================
290
  # Transparent Video Creation
291
+ # ==================================================================================
292
 
293
  def create_transparent_mov(foreground_path, alpha_path, temp_dir):
294
  """
 
347
  logger.error(f"Failed to create transparent MOV: {e}")
348
  return None
349
 
350
+ # ==================================================================================
351
+ # STAGE 1: Create Transparent Video (T4 Optimized)
352
+ # ==================================================================================
353
 
354
  def stage1_create_transparent_video(input_file):
355
  """
 
360
  2. Process video with MatAnyone (temporal propagation)
361
  3. Apply temporal smoothing to alpha channel (FIXES SHAKING)
362
  4. Create transparent .mov file
363
+
364
+ T4 Optimizations:
365
+ - Heartbeat monitor prevents timeout
366
+ - VRAM adaptive controller manages memory
367
+ - FP16 optimization for models
368
+ - Memory pruning for SAM2 state
369
  """
370
  logger.info("Starting Stage 1: Create transparent video")
371
 
372
+ # Start heartbeat monitor
373
+ heartbeat_flag = {"running": True}
374
+ heartbeat_thread = threading.Thread(
375
+ target=heartbeat_monitor,
376
+ args=(heartbeat_flag, 8.0),
377
+ daemon=True
378
+ )
379
+ heartbeat_thread.start()
380
+
381
+ # Initialize VRAM controller
382
+ vram_ctrl = VRAMAdaptiveController()
383
+
384
  # Check memory
385
  memory_info = get_memory_usage()
386
  if memory_info.get('gpu_free', 0) < 2.0:
 
411
 
412
  if sam2_predictor is None:
413
  st.error("Failed to load SAM2 model")
414
+ heartbeat_flag["running"] = False
415
  return None
416
 
417
+ # Try to optimize SAM2 model for T4
418
+ if hasattr(sam2_predictor, 'model') and sam2_predictor.model is not None:
419
+ try:
420
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
421
+ sam2_predictor.model = optimize_model_for_t4(sam2_predictor.model, device)
422
+ except Exception as e:
423
+ logger.warning(f"Could not optimize SAM2: {e}")
424
+
425
  update_progress(0.1, "Loading MatAnyone model...")
426
  matanyone_result = load_matanyone_processor()
427
 
 
434
 
435
  if matanyone_processor is None:
436
  st.error("Failed to load MatAnyone model")
437
+ heartbeat_flag["running"] = False
438
  return None
439
 
440
+ # Try to optimize MatAnyone model for T4
441
+ if hasattr(matanyone_processor, 'model') and matanyone_processor.model is not None:
442
+ try:
443
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
444
+ matanyone_processor.model = optimize_model_for_t4(matanyone_processor.model, device)
445
+ except Exception as e:
446
+ logger.warning(f"Could not optimize MatAnyone: {e}")
447
+
448
  # Process video
449
  with tempfile.TemporaryDirectory() as temp_dir:
450
  temp_dir = Path(temp_dir)
 
462
 
463
  if mask is None:
464
  st.error("Failed to generate mask")
465
+ heartbeat_flag["running"] = False
466
  return None
467
 
468
  mask_path = str(temp_dir / "mask.png")
 
503
 
504
  update_progress(1.0, "Transparent video created successfully")
505
  time.sleep(0.5)
506
+
507
+ heartbeat_flag["running"] = False
508
  return str(persist_path)
509
  else:
510
  st.error("Failed to create transparent video")
511
+ heartbeat_flag["running"] = False
512
  return None
513
 
514
  except Exception as e:
515
  logger.error(f"MatAnyone processing failed: {e}", exc_info=True)
516
  st.error(f"MatAnyone processing failed: {e}")
517
+ heartbeat_flag["running"] = False
518
  return None
519
 
520
  except Exception as e:
 
529
  except:
530
  pass
531
 
532
+ heartbeat_flag["running"] = False
533
  return None
534
 
535
  finally:
536
+ heartbeat_flag["running"] = False
537
  logger.info("Stage 1 cleanup...")
538
  if torch.cuda.is_available():
539
  torch.cuda.empty_cache()
540
  gc.collect()
541
 
542
+ # ==================================================================================
543
  # STAGE 2: Composite with Background
544
+ # ==================================================================================
545
 
546
  def stage2_composite_background(transparent_video_path, background, bg_type):
547
  """