MogensR commited on
Commit
3723e02
Β·
1 Parent(s): 39f0b54
Files changed (1) hide show
  1. pipeline.py +415 -568
pipeline.py CHANGED
@@ -1,256 +1,245 @@
1
  #!/usr/bin/env python3
2
  """
3
- pipeline.py - Core video background replacement processing
4
- Direct SAM2 + MatAnyone implementation with zero abstraction layers
5
- Processes video frame-by-frame: person detection -> mask generation -> background replacement
 
 
 
 
 
 
 
 
 
 
6
  """
7
 
8
  import os
 
9
  import cv2
10
  import time
11
  import uuid
 
 
12
  import shutil
 
13
  import tempfile
14
  import subprocess
 
15
  import numpy as np
16
  from PIL import Image
17
- import logging
18
- import gc
19
  from pathlib import Path
20
  from typing import Optional, Tuple, Dict, Any, Callable
 
21
 
22
- logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
23
 
24
- # ================================================================================================
25
- # VERSION VALIDATION AND LAZY IMPORTS
26
- # ================================================================================================
 
 
 
 
 
 
 
 
27
 
28
- def validate_pytorch_environment():
29
- """Validate PyTorch installation and versions before loading heavy models"""
30
  try:
31
- import torch
32
- torch_version = torch.__version__
33
- logger.info(f"PyTorch version: {torch_version}")
34
-
35
- if not torch.cuda.is_available():
36
- raise RuntimeError("CUDA not available - GPU processing required")
37
-
38
- cuda_version = torch.version.cuda
39
- cudnn_version = torch.backends.cudnn.version()
40
- logger.info(f"CUDA version: {cuda_version}")
41
- logger.info(f"cuDNN version: {cudnn_version}")
42
-
43
- # Test basic CUDA operations
44
  try:
45
- device = torch.device('cuda')
46
- test_tensor = torch.randn(100, 100).to(device)
47
- result = torch.mm(test_tensor, test_tensor.t())
48
- logger.info("CUDA basic operations test: PASSED")
49
- except Exception as cuda_test_error:
50
- logger.error(f"CUDA operations test FAILED: {cuda_test_error}")
51
- raise RuntimeError(f"CUDA incompatibility detected: {cuda_test_error}")
52
-
53
- # Version compatibility warnings
54
- torch_major = int(torch_version.split('.')[0])
55
- torch_minor = int(torch_version.split('.')[1])
56
-
57
- if torch_major == 2 and torch_minor >= 8:
58
- logger.warning(f"PyTorch {torch_version} is very new - may have compatibility issues")
59
-
60
- if torch_major < 2 or (torch_major == 2 and torch_minor < 3):
61
- raise RuntimeError(f"PyTorch {torch_version} too old for SAM2. Need >= 2.3.0")
62
-
63
- # GPU memory and capabilities
64
- total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
65
- gpu_name = torch.cuda.get_device_name(0)
66
- compute_capability = torch.cuda.get_device_capability(0)
67
-
68
- logger.info(f"GPU: {gpu_name}")
69
- logger.info(f"Compute capability: {compute_capability}")
70
- logger.info(f"Total GPU memory: {total_memory:.1f}GB")
71
-
72
- if total_memory < 8.0:
73
- logger.warning(f"Low GPU memory: {total_memory:.1f}GB. May fail on large videos.")
74
-
75
- return device, torch_version
76
-
77
- except Exception as e:
78
- logger.error(f"PyTorch environment validation failed: {e}")
79
- raise
80
 
81
- def lazy_import_sam2():
82
- """Lazy import SAM2 with error handling"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  try:
84
- logger.info("Lazy importing SAM2...")
85
-
86
- import torch
87
- if torch.cuda.is_available():
88
- torch.cuda.empty_cache()
89
- gc.collect()
90
-
91
- from sam2.build_sam import build_sam2_video_predictor
92
- logger.info("SAM2 imported successfully")
93
- return build_sam2_video_predictor
94
-
95
- except ImportError as e:
96
- logger.error(f"SAM2 import failed: {e}")
97
- raise RuntimeError(f"SAM2 not available: {e}")
98
  except Exception as e:
99
- logger.error(f"Unexpected error importing SAM2: {e}")
100
  raise
101
 
102
- def lazy_import_matanyone():
103
- """Lazy import MatAnyone with graceful fallback"""
 
 
104
  try:
105
- logger.info("Attempting MatAnyone import...")
106
- from models.matanyone_loader import MatAnyoneLoader
107
- logger.info("MatAnyone imported successfully")
108
- return MatAnyoneLoader
 
 
 
 
 
 
 
 
 
 
109
  except Exception as e:
110
- logger.warning(f"MatAnyone not available: {e}")
111
  return None
112
 
113
- # ================================================================================================
114
- # MEMORY MANAGEMENT UTILITIES
115
- # ================================================================================================
116
-
117
- def clear_gpu_memory():
118
- """Clear GPU memory cache"""
 
 
119
  try:
120
- import torch
121
- if torch.cuda.is_available():
122
- torch.cuda.empty_cache()
123
- torch.cuda.synchronize()
124
- gc.collect()
 
 
125
  except Exception as e:
126
- logger.warning(f"GPU memory clear failed: {e}")
127
-
128
- def log_memory_usage(stage: str):
129
- """Log current memory usage"""
130
- try:
131
- import torch
132
- if torch.cuda.is_available():
133
- allocated = torch.cuda.memory_allocated() / (1024**3)
134
- reserved = torch.cuda.memory_reserved() / (1024**3)
135
- logger.info(f"{stage} - GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
136
- except Exception:
137
- pass
138
 
139
- # ================================================================================================
140
- # SAFE FILE OPERATIONS
141
- # ================================================================================================
 
 
 
 
 
142
 
143
- def safe_tmp_path(base_dir: str, extension: str) -> Path:
144
- """Generate safe temporary file path"""
145
- timestamp = int(time.time())
146
- random_id = uuid.uuid4().hex[:8]
147
- filename = f"tmp_{timestamp}_{random_id}{extension}"
148
- return Path(base_dir) / filename
 
 
 
 
 
 
 
 
 
149
 
150
- def safe_video_writer(output_path: Path, fourcc_str: str, fps: float, size: Tuple[int, int]):
151
- """Create video writer with error handling"""
 
 
 
 
 
 
 
 
 
 
152
  try:
153
- fourcc = cv2.VideoWriter_fourcc(*fourcc_str)
154
- writer = cv2.VideoWriter(str(output_path), fourcc, fps, size)
155
-
156
- if not writer.isOpened():
157
- raise RuntimeError(f"Failed to open video writer: {output_path}")
158
-
159
- return writer
160
- except Exception as e:
161
- logger.error(f"Video writer creation failed: {e}")
162
- raise
163
-
164
- # ================================================================================================
165
- # CHECKPOINT DOWNLOAD
166
- # ================================================================================================
167
-
168
- def download_sam2_checkpoint(checkpoint_path: str, work_dir: str = None, timeout_seconds: int = 600):
169
- """Download SAM2 checkpoint with timeout protection"""
170
- checkpoint_file = Path(checkpoint_path)
171
-
172
- if checkpoint_file.exists():
173
- logger.info(f"SAM2 checkpoint already exists: {checkpoint_file}")
174
  return True
175
-
176
- try:
177
- logger.info("SAM2 checkpoint not found, downloading...")
178
-
179
- checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
180
-
181
- import requests
182
-
183
- checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"
184
-
185
- logger.info(f"Downloading from: {checkpoint_url}")
186
- logger.info(f"Target: {checkpoint_file}")
187
-
188
- start_time = time.time()
189
- response = requests.get(checkpoint_url, stream=True, timeout=30)
190
- response.raise_for_status()
191
-
192
- total_size = int(response.headers.get('content-length', 0))
193
- logger.info(f"File size: {total_size / (1024**2):.1f}MB")
194
-
195
- # Download to temporary file first
196
- work_path = Path(work_dir) if work_dir else checkpoint_file.parent
197
- temp_download = safe_tmp_path(str(work_path), ".pt.download")
198
-
199
- downloaded = 0
200
- last_log_time = start_time
201
-
202
- try:
203
- with open(temp_download, 'wb') as f:
204
- for chunk in response.iter_content(chunk_size=1024*1024):
205
- if chunk:
206
- f.write(chunk)
207
- downloaded += len(chunk)
208
-
209
- current_time = time.time()
210
- elapsed = current_time - start_time
211
-
212
- # Timeout check
213
- if elapsed > timeout_seconds:
214
- raise TimeoutError(f"Download timeout after {elapsed:.1f}s")
215
-
216
- # Progress logging every 15 seconds
217
- if current_time - last_log_time > 15:
218
- progress = (downloaded / total_size * 100) if total_size > 0 else 0
219
- speed = downloaded / elapsed / (1024**2) # MB/s
220
- logger.info(f"Download: {progress:.1f}% ({speed:.1f}MB/s)")
221
- last_log_time = current_time
222
-
223
- # Verify download
224
- if total_size > 0 and downloaded != total_size:
225
- raise RuntimeError(f"Incomplete download: {downloaded}/{total_size} bytes")
226
-
227
- # Move to final location
228
- temp_download.replace(checkpoint_file)
229
-
230
- total_time = time.time() - start_time
231
- speed = downloaded / total_time / (1024**2)
232
- logger.info(f"Download complete: {downloaded / (1024**2):.1f}MB in {total_time:.1f}s ({speed:.1f}MB/s)")
233
-
234
- return True
235
-
236
- except Exception as download_error:
237
- if temp_download.exists():
238
- temp_download.unlink()
239
- raise download_error
240
-
241
  except Exception as e:
242
- logger.error(f"Download failed: {e}")
243
- if checkpoint_file.exists():
244
- try:
245
- checkpoint_file.unlink()
246
- except Exception:
247
- pass
248
  return False
249
 
250
- # ================================================================================================
251
- # MAIN PROCESSING FUNCTION
252
- # ================================================================================================
253
-
254
  def process(
255
  video_path: str,
256
  background_image: Optional[Image.Image] = None,
@@ -260,371 +249,229 @@ def process(
260
  progress_callback: Optional[Callable[[str, float], None]] = None
261
  ) -> str:
262
  """
263
- Process video with background replacement using SAM2 + MatAnyone
264
-
265
- Args:
266
- video_path: Path to input video
267
- background_image: PIL Image for background (if background_type is custom)
268
- background_type: Type of background ("custom", "gradient", "solid", etc.)
269
- background_prompt: Prompt for background generation
270
- job_directory: Directory for processing files
271
- progress_callback: Optional callback for progress updates
272
-
273
- Returns:
274
- Path to processed video file
275
  """
276
-
277
- def log_progress(step: str, progress: float = None):
278
- if progress is not None:
279
- logger.info(f"Progress {progress:.1%}: {step}")
 
 
 
 
 
 
 
280
  else:
281
- logger.info(f"Step: {step}")
282
  if progress_callback:
283
  try:
284
- progress_callback(step, progress)
285
  except Exception as e:
286
- logger.warning(f"Progress callback error: {e}")
287
-
288
- # Set up job directory
 
 
 
 
 
289
  if job_directory is None:
290
  job_directory = Path.cwd() / "tmp" / f"job_{uuid.uuid4().hex[:8]}"
291
-
292
  job_directory.mkdir(parents=True, exist_ok=True)
293
- logger.info(f"Processing in job directory: {job_directory}")
294
-
295
- start_time = time.time()
296
-
297
- try:
298
- # ============================================================================================
299
- # STAGE 1: ENVIRONMENT VALIDATION
300
- # ============================================================================================
301
-
302
- log_progress("Validating PyTorch environment", 0.02)
303
- device, torch_version = validate_pytorch_environment()
304
- log_memory_usage("Environment validated")
305
-
306
- # ============================================================================================
307
- # STAGE 2: VIDEO ANALYSIS
308
- # ============================================================================================
309
-
310
- log_progress("Analyzing input video", 0.05)
311
-
312
- video_file = Path(video_path)
313
- if not video_file.exists():
314
- raise FileNotFoundError(f"Video file not found: {video_path}")
315
-
316
- # Copy video to job directory for safe processing
317
- safe_video_path = job_directory / f"input{video_file.suffix}"
318
- if safe_video_path != video_file:
319
- logger.info(f"Copying video to job directory: {safe_video_path}")
320
- shutil.copy2(video_path, safe_video_path)
321
- video_path = str(safe_video_path)
322
-
323
- # Get video properties
324
- cap = cv2.VideoCapture(video_path)
325
- if not cap.isOpened():
326
- raise RuntimeError(f"Cannot open video: {video_path}")
327
-
328
- fps = cap.get(cv2.CAP_PROP_FPS)
329
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
330
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
331
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
332
- duration = frame_count / fps if fps > 0 else 0
333
- cap.release()
334
-
335
- logger.info(f"Video: {width}x{height} @ {fps:.1f}fps, {frame_count} frames ({duration:.1f}s)")
336
-
337
- # ============================================================================================
338
- # STAGE 3: BACKGROUND PREPARATION
339
- # ============================================================================================
340
-
341
- log_progress("Preparing background", 0.08)
342
-
343
- if background_image is None:
344
- raise ValueError("Background image is required")
345
-
346
- # Resize background to match video
347
- bg_image = background_image.resize((width, height), Image.LANCZOS)
348
- bg_array = np.array(bg_image)
349
- logger.info(f"Background prepared: {bg_image.size}")
350
-
351
- # ============================================================================================
352
- # STAGE 4: SAM2 MODEL LOADING
353
- # ============================================================================================
354
-
355
- log_progress("Loading SAM2 model", 0.1)
356
-
357
- # Download checkpoint
358
- sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
359
- if not download_sam2_checkpoint(sam2_checkpoint, str(job_directory)):
360
- raise RuntimeError("Failed to download SAM2 checkpoint")
361
-
362
- # Import and load SAM2
363
- build_sam2_video_predictor = lazy_import_sam2()
364
- clear_gpu_memory()
365
-
366
- model_cfg = "sam2_hiera_l.yaml"
367
- predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
368
- logger.info("SAM2 model loaded successfully")
369
- log_memory_usage("SAM2 loaded")
370
-
371
- # ============================================================================================
372
- # STAGE 5: VIDEO PROCESSING INITIALIZATION
373
- # ============================================================================================
374
-
375
- log_progress("Initializing video processing", 0.2)
376
-
377
- inference_state = predictor.init_state(video_path=video_path)
378
-
379
- # Add prompt for person detection (center of frame)
380
- ann_frame_idx = 0
381
- ann_obj_id = 1
382
- points = np.array([[width//2, height//2]], dtype=np.float32)
383
- labels = np.array([1], np.int32)
384
-
385
- _, out_obj_ids, out_mask_logits = predictor.add_new_points(
386
- inference_state=inference_state,
387
- frame_idx=ann_frame_idx,
388
  obj_id=ann_obj_id,
389
- points=points,
390
  labels=labels,
391
  )
392
-
393
- logger.info("Video processing initialized with person detection prompt")
394
-
395
- # ============================================================================================
396
- # STAGE 6: CHUNKED MASK GENERATION WITH FULL CACHE CLEARANCE
397
- # ============================================================================================
398
-
399
- log_progress("Generating masks with chunked SAM2 processing", 0.3)
400
-
401
- # Calculate optimal chunk size based on available memory
402
- available_memory_gb = 12.0 # Conservative for T4
403
- estimated_memory_per_frame = 0.05 # ~50MB per frame for 720p
404
- max_chunk_size = min(200, int(available_memory_gb / estimated_memory_per_frame))
405
- chunk_size = max(50, max_chunk_size) # Minimum 50 frames, maximum based on memory
406
-
407
- logger.info(f"Using chunk size: {chunk_size} frames for {frame_count} total frames")
408
-
409
- video_segments = {}
410
- frames_processed = 0
411
-
412
- # Process video in chunks to prevent memory overflow
413
- for chunk_start in range(0, frame_count, chunk_size):
414
- chunk_end = min(chunk_start + chunk_size, frame_count)
415
- chunk_frames = chunk_end - chunk_start
416
-
417
- logger.info(f"Processing chunk: frames {chunk_start}-{chunk_end} ({chunk_frames} frames)")
418
-
419
- # Clear all GPU memory before each chunk
420
- clear_gpu_memory()
421
- log_memory_usage(f"Before chunk {chunk_start//chunk_size + 1}")
422
-
423
- try:
424
- # Create fresh inference state for this chunk
425
- chunk_inference_state = predictor.init_state(video_path=video_path)
426
-
427
- # Add prompt for this chunk (re-add for each chunk)
428
- _, out_obj_ids, out_mask_logits = predictor.add_new_points(
429
- inference_state=chunk_inference_state,
430
- frame_idx=chunk_start, # Use chunk start as reference frame
431
- obj_id=ann_obj_id,
432
- points=points,
433
- labels=labels,
434
- )
435
-
436
- # Process only frames in this chunk
437
- chunk_segments = {}
438
- for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
439
- chunk_inference_state,
440
- start_frame_idx=chunk_start,
441
- max_frame_idx=chunk_end - 1
442
- ):
443
- if chunk_start <= out_frame_idx < chunk_end:
444
- # Immediately move masks to CPU and store
445
- frame_masks = {}
446
- for i, out_obj_id in enumerate(out_obj_ids):
447
- mask = (out_mask_logits[i] > 0.0).cpu().numpy()
448
- frame_masks[out_obj_id] = mask
449
-
450
- video_segments[out_frame_idx] = frame_masks
451
- chunk_segments[out_frame_idx] = frame_masks
452
- frames_processed += 1
453
-
454
- logger.info(f"Chunk {chunk_start//chunk_size + 1} complete: {len(chunk_segments)} masks generated")
455
-
456
- # Aggressive cleanup after each chunk
457
- del chunk_inference_state
458
- del chunk_segments
459
- clear_gpu_memory()
460
-
461
- # Progress update
462
- progress = 0.3 + (frames_processed / frame_count) * 0.4
463
- log_progress(f"Processed {frames_processed}/{frame_count} frames in chunks", progress)
464
-
465
- except Exception as e:
466
- logger.error(f"Chunk {chunk_start//chunk_size + 1} failed: {e}")
467
- # Try to continue with next chunk rather than failing completely
468
- clear_gpu_memory()
469
- continue
470
-
471
- logger.info(f"Chunked processing complete: {len(video_segments)} total masks generated")
472
- log_memory_usage("All chunks processed")
473
-
474
- # ============================================================================================
475
- # STAGE 7: COMPLETE SAM2 MODEL AND INFERENCE STATE CLEANUP
476
- # ============================================================================================
477
-
478
- log_progress("Complete SAM2 cleanup and memory reclaim", 0.72)
479
-
480
- try:
481
- # Delete all SAM2 references
482
- del predictor
483
- if 'inference_state' in locals():
484
- del inference_state
485
-
486
- # Remove SAM2 from Python modules
487
- import sys
488
- sam2_modules = [name for name in sys.modules.keys() if 'sam2' in name.lower()]
489
- logger.info(f"Removing {len(sam2_modules)} SAM2 modules from memory")
490
- for module_name in sam2_modules:
491
- try:
492
- del sys.modules[module_name]
493
- except Exception:
494
- pass
495
-
496
- # Force Python garbage collection
497
- import gc
498
- collected = gc.collect()
499
- logger.info(f"Garbage collected {collected} objects")
500
-
501
- # Final aggressive GPU cleanup
502
- import torch
503
- if torch.cuda.is_available():
504
- torch.cuda.empty_cache()
505
- torch.cuda.synchronize()
506
- # Reset memory stats
507
- torch.cuda.reset_peak_memory_stats()
508
-
509
- log_memory_usage("SAM2 completely removed")
510
-
511
- except Exception as e:
512
- logger.warning(f"SAM2 cleanup warning: {e}")
513
-
514
- # ============================================================================================
515
- # STAGE 8: MEMORY-EFFICIENT VIDEO COMPOSITION
516
- # ============================================================================================
517
-
518
- log_progress("Video composition with memory management", 0.8)
519
-
520
- output_path = job_directory / f"output_{int(time.time())}.mp4"
521
- out_writer = safe_video_writer(output_path, 'mp4v', fps, (width, height))
522
-
523
- cap = cv2.VideoCapture(video_path)
524
- frame_idx = 0
525
- composition_chunk_size = 50 # Smaller chunks for composition
526
-
527
- try:
528
- frames_batch = []
529
-
530
- while True:
531
- ret, frame = cap.read()
532
  if not ret:
533
  break
534
-
535
- # Process frame
536
- if frame_idx in video_segments and ann_obj_id in video_segments[frame_idx]:
537
- mask = video_segments[frame_idx][ann_obj_id]
538
- mask_3ch = np.stack([mask, mask, mask], axis=2)
539
-
540
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
541
- composite = frame_rgb * mask_3ch + bg_array * (1 - mask_3ch)
542
- composite_bgr = cv2.cvtColor(composite.astype(np.uint8), cv2.COLOR_RGB2BGR)
543
- out_writer.write(composite_bgr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  else:
545
- out_writer.write(frame)
546
-
547
- frame_idx += 1
548
-
549
- # Memory cleanup every batch
550
- if frame_idx % composition_chunk_size == 0:
551
- # Clear processed masks from memory to save RAM
552
- for i in range(max(0, frame_idx - composition_chunk_size), frame_idx):
553
- if i in video_segments:
554
- del video_segments[i]
555
-
556
- clear_gpu_memory()
557
- progress = 0.8 + (frame_idx / frame_count) * 0.15
558
- log_progress(f"Compositing {frame_idx}/{frame_count} (memory managed)", progress)
559
-
560
- finally:
561
- cap.release()
562
- out_writer.release()
563
-
564
- # Final cleanup of remaining masks
565
- video_segments.clear()
566
- clear_gpu_memory()
567
-
568
- # ============================================================================================
569
- # STAGE 9: AUDIO RESTORATION
570
- # ============================================================================================
571
-
572
- log_progress("Adding audio track", 0.95)
573
-
574
- final_output = job_directory / f"final_with_audio_{int(time.time())}.mp4"
575
-
576
- try:
577
- cmd = [
578
- 'ffmpeg', '-y', '-hide_banner', '-loglevel', 'error',
579
- '-i', str(output_path), # Video input
580
- '-i', video_path, # Audio source
581
- '-c:v', 'copy', # Copy video
582
- '-c:a', 'aac', # Encode audio
583
- '-shortest', # Match shortest stream
584
- str(final_output)
585
- ]
586
-
587
- result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
588
-
589
- if result.returncode == 0:
590
- logger.info("Audio successfully added")
591
- output_path.unlink() # Remove temp video
592
- final_path = str(final_output)
593
- else:
594
- logger.warning(f"Audio processing failed: {result.stderr}")
595
- final_path = str(output_path)
596
-
597
- except Exception as e:
598
- logger.warning(f"Audio processing error: {e}")
599
- final_path = str(output_path)
600
-
601
- # ============================================================================================
602
- # STAGE 10: COMPLETION
603
- # ============================================================================================
604
-
605
- total_time = time.time() - start_time
606
- log_memory_usage("Processing complete")
607
-
608
- try:
609
- import torch
610
- if torch.cuda.is_available():
611
- peak_memory = torch.cuda.max_memory_allocated() / (1024**3)
612
- logger.info(f"Peak GPU memory: {peak_memory:.2f}GB")
613
- except Exception:
614
- pass
615
-
616
- log_progress(f"Processing complete in {total_time:.1f}s", 1.0)
617
-
618
- logger.info(f"Output video: {final_path}")
619
- logger.info(f"Job directory: {job_directory}")
620
-
621
- return final_path
622
-
623
- except Exception as e:
624
- logger.error(f"Processing failed: {e}")
625
- logger.error(f"Job directory: {job_directory}")
626
- raise
627
-
628
- finally:
629
- # Final cleanup
630
- clear_gpu_memory()
 
1
  #!/usr/bin/env python3
2
  """
3
+ pipeline.py β€” Production SAM2 + MatAnyone (T4-optimized, single-pass streaming)
4
+
5
+ Key features
6
+ ------------
7
+ - One SAM2 inference state for the entire video (no per-chunk reinit).
8
+ - In-stream pipeline: Read β†’ SAM2 β†’ MatAnyone β†’ Compose β†’ Write (no big RAM dicts).
9
+ - Bounded memory everywhere (deque/window); optional CPU spill.
10
+ - fp16 + channels_last on SAM2; mixed precision blocks.
11
+ - VRAM-aware controller adjusts memory window/scale.
12
+ - Heartbeat logger to prevent HF watchdog restarts.
13
+ - Safer FFmpeg audio re-mux.
14
+
15
+ Compatible with Tesla T4 (β‰ˆ15–16 GB) and PyTorch 2.5.x + CUDA 12.4 wheels.
16
  """
17
 
18
  import os
19
+ import gc
20
  import cv2
21
  import time
22
  import uuid
23
+ import torch
24
+ import queue
25
  import shutil
26
+ import logging
27
  import tempfile
28
  import subprocess
29
+ import threading
30
  import numpy as np
31
  from PIL import Image
 
 
32
  from pathlib import Path
33
  from typing import Optional, Tuple, Dict, Any, Callable
34
+ from collections import deque
35
 
36
+ # ----------------------------------------------------------------------------------------------------------------------
37
+ # Logging
38
+ # ----------------------------------------------------------------------------------------------------------------------
39
+ logger = logging.getLogger("backgroundfx_pro")
40
+ if not logger.handlers:
41
+ h = logging.StreamHandler()
42
+ h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s: %(message)s"))
43
+ logger.addHandler(h)
44
+ logger.setLevel(logging.INFO)
45
 
46
+ # ----------------------------------------------------------------------------------------------------------------------
47
+ # Environment & Torch tuning for T4
48
+ # ----------------------------------------------------------------------------------------------------------------------
49
+ def setup_t4_environment():
50
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF",
51
+ "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7")
52
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
53
+ os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
54
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
55
+ os.environ.setdefault("OPENCV_OPENCL_RUNTIME", "disabled")
56
+ os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")
57
 
58
+ torch.set_grad_enabled(False)
 
59
  try:
60
+ torch.backends.cudnn.benchmark = True
61
+ torch.backends.cuda.matmul.allow_tf32 = True
62
+ torch.backends.cudnn.allow_tf32 = True
63
+ torch.set_float32_matmul_precision("high")
64
+ except Exception:
65
+ pass
66
+
67
+ if torch.cuda.is_available():
 
 
 
 
 
68
  try:
69
+ frac = float(os.getenv("CUDA_MEMORY_FRACTION", "0.88"))
70
+ torch.cuda.set_per_process_memory_fraction(frac)
71
+ logger.info(f"CUDA per-process memory fraction = {frac:.2f}")
72
+ except Exception as e:
73
+ logger.warning(f"Could not set CUDA memory fraction: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ def vram_gb() -> Tuple[float, float]:
76
+ if not torch.cuda.is_available():
77
+ return 0.0, 0.0
78
+ free, total = torch.cuda.mem_get_info()
79
+ return free / (1024 ** 3), total / (1024 ** 3)
80
+
81
+ # ----------------------------------------------------------------------------------------------------------------------
82
+ # Heartbeat (prevents Spaces watchdog killing the job)
83
+ # ----------------------------------------------------------------------------------------------------------------------
84
+ def heartbeat_monitor(running_flag: Dict[str, bool], interval: float = 8.0):
85
+ while running_flag.get("running", False):
86
+ print(f"[HB] t={int(time.time())}", flush=True)
87
+ time.sleep(interval)
88
+
89
+ # ----------------------------------------------------------------------------------------------------------------------
90
+ # Streaming video I/O
91
+ # ----------------------------------------------------------------------------------------------------------------------
92
+ class StreamingVideoIO:
93
+ def __init__(self, video_path: str, out_path: str, fps: float):
94
+ self.video_path = video_path
95
+ self.out_path = out_path
96
+ self.fps = fps
97
+ self.cap = None
98
+ self.writer = None
99
+ self.size = None
100
+
101
+ def __enter__(self):
102
+ self.cap = cv2.VideoCapture(self.video_path)
103
+ if not self.cap.isOpened():
104
+ raise RuntimeError(f"Cannot open video: {self.video_path}")
105
+ w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
106
+ h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
107
+ self.size = (w, h)
108
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
109
+ self.writer = cv2.VideoWriter(self.out_path, fourcc, self.fps, (w, h))
110
+ return self
111
+
112
+ def __exit__(self, exc_type, exc_val, exc_tb):
113
+ if self.cap:
114
+ self.cap.release()
115
+ if self.writer:
116
+ self.writer.release()
117
+
118
+ def read_frame(self):
119
+ if not self.cap:
120
+ return False, None
121
+ return self.cap.read()
122
+
123
+ def write_frame(self, frame_bgr: np.ndarray):
124
+ if not self.writer:
125
+ return
126
+ self.writer.write(frame_bgr)
127
+
128
+ # ----------------------------------------------------------------------------------------------------------------------
129
+ # Models: loaders and safe optimizations
130
+ # ----------------------------------------------------------------------------------------------------------------------
131
+ def load_sam2_predictor(device: torch.device):
132
+ """
133
+ Prefer your local wrapper to keep interfaces stable.
134
+ """
135
  try:
136
+ from models.sam2_loader import SAM2Predictor # your wrapper
137
+ predictor = SAM2Predictor(device=device)
138
+ # Optional: try to access underlying model to set fp16 + channels_last
139
+ try:
140
+ if hasattr(predictor, "model") and predictor.model is not None:
141
+ predictor.model = predictor.model.half().to(device)
142
+ predictor.model = predictor.model.to(memory_format=torch.channels_last)
143
+ logger.info("SAM2: fp16 + channels_last applied (wrapper model).")
144
+ except Exception as e:
145
+ logger.warning(f"SAM2 fp16 optimization warning: {e}")
146
+ return predictor
 
 
 
147
  except Exception as e:
148
+ logger.error(f"Failed to import SAM2Predictor: {e}")
149
  raise
150
 
151
+ def load_matany_session(device: torch.device):
152
+ """
153
+ Supports either MatAnyoneSession or MatAnyoneLoader (your code has varied).
154
+ """
155
  try:
156
+ try:
157
+ from models.matanyone_loader import MatAnyoneSession as _MatAny
158
+ except Exception:
159
+ from models.matanyone_loader import MatAnyoneLoader as _MatAny
160
+ session = _MatAny(device=device)
161
+ # Try fp16 eval where safe
162
+ if hasattr(session, "model") and session.model is not None:
163
+ session.model.eval()
164
+ try:
165
+ session.model = session.model.half().to(device)
166
+ logger.info("MatAnyone: fp16 + eval applied.")
167
+ except Exception:
168
+ logger.info("MatAnyone: using fp32 (fp16 not supported for some layers).")
169
+ return session
170
  except Exception as e:
171
+ logger.warning(f"MatAnyone not available ({e}). Proceeding without refinement.")
172
  return None
173
 
174
+ # ----------------------------------------------------------------------------------------------------------------------
175
+ # SAM2 state pruning (adapter): we call predictor.prune_state if present, else best-effort
176
+ # ----------------------------------------------------------------------------------------------------------------------
177
def prune_sam2_state(predictor, state: Any, keep: int):
    """Best-effort trim of SAM2 temporal caches down to ``keep`` frames.

    Preference order:
      1. ``predictor.prune_state(state, keep=keep)`` when the adapter has it,
      2. ``state.prune(keep=keep)`` when the state object supports pruning,
      3. otherwise a no-op — rely on model internals and GC.

    Any failure is logged at debug level only; pruning is a memory
    optimization, never a correctness requirement.
    """
    try:
        if hasattr(predictor, "prune_state"):
            predictor.prune_state(state, keep=keep)
        else:
            pruner = getattr(state, "prune", None)
            if callable(pruner):
                pruner(keep=keep)
            # else: nothing to do for this predictor/state combination
    except Exception as e:
        logger.debug(f"SAM2 prune_state warning: {e}")
 
 
 
 
 
 
 
 
 
 
 
192
 
193
+ # ----------------------------------------------------------------------------------------------------------------------
194
+ # VRAM-aware controller
195
+ # ----------------------------------------------------------------------------------------------------------------------
196
class VRAMAdaptiveController:
    """Adapts SAM2 memory window / propagation scale to available VRAM.

    Settings are tightened when free VRAM drops below ~1.6 GB and relaxed
    again above ~3 GB, always within fixed safe bounds. Initial values are
    tunable via environment variables:

      SAM2_WINDOW      - frames kept in model state (default 96)
      SAM2_PROP_SCALE  - propagation downscale factor (default 0.90)

    Malformed env values fall back to the defaults instead of raising
    ``ValueError`` at construction time.
    """

    def __init__(self):
        # Frames to keep in the SAM2 temporal state.
        self.memory_window = self._env_int("SAM2_WINDOW", 96)
        # Downscale factor applied during mask propagation.
        self.propagation_scale = self._env_float("SAM2_PROP_SCALE", 0.90)
        # Run maintenance (prune + CUDA cache cleanup) every N frames.
        self.cleanup_every = 20

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Read an int env var; silently fall back to default on bad input."""
        try:
            return int(os.getenv(name, str(default)))
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _env_float(name: str, default: float) -> float:
        """Read a float env var; silently fall back to default on bad input."""
        try:
            return float(os.getenv(name, str(default)))
        except (TypeError, ValueError):
            return default

    def adapt(self):
        """Re-tune window/scale/cleanup cadence from current free VRAM.

        No-op when VRAM stats are unavailable (``vram_gb`` reports 0 free).
        """
        free, total = vram_gb()
        if free == 0.0:
            return
        # Tighten if we dip under ~1.6 GB free.
        if free < 1.6:
            self.memory_window = max(48, self.memory_window - 8)
            self.propagation_scale = max(0.75, self.propagation_scale - 0.03)
            self.cleanup_every = max(12, self.cleanup_every - 2)
            logger.warning(
                f"Low VRAM ({free:.2f} GB free) -> window={self.memory_window}, "
                f"scale={self.propagation_scale:.2f}"
            )
        # Relax if plenty free.
        elif free > 3.0:
            self.memory_window = min(128, self.memory_window + 4)
            self.propagation_scale = min(1.0, self.propagation_scale + 0.01)
            self.cleanup_every = min(40, self.cleanup_every + 2)
217
 
218
+ # ----------------------------------------------------------------------------------------------------------------------
219
+ # Audio mux helper (safer stream mapping)
220
+ # ----------------------------------------------------------------------------------------------------------------------
221
def mux_audio(video_path_no_audio: str, source_with_audio: str, out_path: str) -> bool:
    """Remux the original audio track onto the silent processed video.

    Explicit stream mapping: video comes from the rendered (silent) file,
    the first audio stream comes from the original source. Video is
    stream-copied, audio is re-encoded to AAC, and the output is trimmed
    to the shorter stream.

    Returns:
        True on success; False on any ffmpeg failure, timeout, or launch
        error (the caller then falls back to the silent render).
    """
    cmd = [
        "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
        "-i", video_path_no_audio,
        "-i", source_with_audio,
        "-map", "0:v:0", "-map", "1:a:0",
        "-c:v", "copy", "-c:a", "aac", "-shortest",
        out_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=180)
        if result.returncode == 0:
            return True
        logger.warning(f"FFmpeg mux failed: {result.stderr.strip()}")
    except Exception as e:
        logger.warning(f"FFmpeg mux error: {e}")
    return False
239
 
240
+ # ----------------------------------------------------------------------------------------------------------------------
241
+ # Main processing
242
+ # ----------------------------------------------------------------------------------------------------------------------
 
243
  def process(
244
  video_path: str,
245
  background_image: Optional[Image.Image] = None,
 
249
  progress_callback: Optional[Callable[[str, float], None]] = None
250
  ) -> str:
251
  """
252
+ Production SAM2 + MatAnyone pipeline for T4.
253
+ - Single-pass streaming (no large mask dicts)
254
+ - Bounded memory windows
 
 
 
 
 
 
 
 
 
255
  """
256
+ setup_t4_environment()
257
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
258
+
259
+ # Heartbeat
260
+ hb_flag = {"running": True}
261
+ hb_thread = threading.Thread(target=heartbeat_monitor, args=(hb_flag, 8.0), daemon=True)
262
+ hb_thread.start()
263
+
264
+ def report(step: str, p: Optional[float] = None):
265
+ if p is None:
266
+ logger.info(step)
267
  else:
268
+ logger.info(f"{step} [{p:.1%}]")
269
  if progress_callback:
270
  try:
271
+ progress_callback(step, p)
272
  except Exception as e:
273
+ logger.debug(f"progress_callback error: {e}")
274
+
275
+ # Validate I/O
276
+ src = Path(video_path)
277
+ if not src.exists():
278
+ hb_flag["running"] = False
279
+ raise FileNotFoundError(f"Video not found: {video_path}")
280
+
281
  if job_directory is None:
282
  job_directory = Path.cwd() / "tmp" / f"job_{uuid.uuid4().hex[:8]}"
 
283
  job_directory.mkdir(parents=True, exist_ok=True)
284
+
285
+ # Probe video
286
+ cap_probe = cv2.VideoCapture(str(src))
287
+ if not cap_probe.isOpened():
288
+ hb_flag["running"] = False
289
+ raise RuntimeError(f"Cannot open video: {video_path}")
290
+ fps = cap_probe.get(cv2.CAP_PROP_FPS) or 25.0
291
+ width = int(cap_probe.get(cv2.CAP_PROP_FRAME_WIDTH))
292
+ height = int(cap_probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
293
+ frame_count = int(cap_probe.get(cv2.CAP_PROP_FRAME_COUNT))
294
+ duration = frame_count / fps if fps > 0 else 0.0
295
+ cap_probe.release()
296
+ logger.info(f"Video: {width}x{height} @ {fps:.2f} fps | {frame_count} frames ({duration:.1f}s)")
297
+
298
+ # Prepare background
299
+ if background_image is None:
300
+ hb_flag["running"] = False
301
+ raise ValueError("background_image is required")
302
+ bg = background_image.resize((width, height), Image.LANCZOS)
303
+ bg_np = np.array(bg).astype(np.float32)
304
+
305
+ # Load models
306
+ report("Loading SAM2 + MatAnyone", 0.05)
307
+ predictor = load_sam2_predictor(device)
308
+ matany = load_matany_session(device)
309
+
310
+ # Init SAM2 state (single)
311
+ report("Initializing SAM2 video state", 0.08)
312
+ state = predictor.init_state(video_path=str(src))
313
+
314
+ # Minimal prompt: single positive point at center (replace with your prompt UI if needed)
315
+ center_pt = np.array([[width // 2, height // 2]], dtype=np.float32)
316
+ labels = np.array([1], dtype=np.int32)
317
+ ann_obj_id = 1
318
+ with torch.inference_mode():
319
+ _ = predictor.add_new_points(
320
+ inference_state=state,
321
+ frame_idx=0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  obj_id=ann_obj_id,
323
+ points=center_pt,
324
  labels=labels,
325
  )
326
+
327
+ # Controller
328
+ ctrl = VRAMAdaptiveController()
329
+
330
+ # Output paths
331
+ out_raw = str(job_directory / f"composite_{int(time.time())}.mp4")
332
+ out_final = str(job_directory / f"final_{int(time.time())}.mp4")
333
+
334
+ # Windows/buffers (bounded)
335
+ # For completeness we keep a tiny deque for any auxiliary temporal ops (e.g., matting history)
336
+ aux_window = deque(maxlen=max(32, min(96, ctrl.memory_window // 2)))
337
+
338
+ # Stream processing
339
+ start = time.time()
340
+ frames_done = 0
341
+ next_cleanup_at = ctrl.cleanup_every
342
+
343
+ report("Streaming: SAM2 β†’ MatAnyone β†’ Compose β†’ Write", 0.12)
344
+ with StreamingVideoIO(str(src), out_raw, fps) as vio:
345
+ # iterate SAM2 propagation alongside reading frames
346
+ with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16 if device.type == "cuda" else None):
347
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(state, scale=ctrl.propagation_scale):
348
+ # Read the matching frame
349
+ ret, frame_bgr = vio.read_frame()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  if not ret:
351
  break
352
+
353
+ # Get mask for ann_obj_id; keep on GPU as long as possible
354
+ mask_t = None
355
+ try:
356
+ if isinstance(out_obj_ids, torch.Tensor):
357
+ # find index where id == ann_obj_id
358
+ idxs = (out_obj_ids == ann_obj_id).nonzero(as_tuple=False)
359
+ if idxs.numel() > 0:
360
+ i = idxs[0].item()
361
+ logits = out_mask_logits[i]
362
+ else:
363
+ logits = None
364
+ else:
365
+ # list/array fallback
366
+ ids_list = list(out_obj_ids)
367
+ i = ids_list.index(ann_obj_id) if ann_obj_id in ids_list else -1
368
+ logits = out_mask_logits[i] if i >= 0 else None
369
+
370
+ if logits is not None:
371
+ # logits β†’ prob β†’ binary mask (threshold 0)
372
+ mask_t = (logits > 0).float() # HxW on CUDA fp16 β†’ fp32 float
373
+ except Exception as e:
374
+ logger.debug(f"Mask extraction warning @frame {out_frame_idx}: {e}")
375
+ mask_t = None
376
+
377
+ # Optional: MatAnyone refinement
378
+ if mask_t is not None and matany is not None:
379
+ try:
380
+ # MatAnyone APIs vary β€” try common forms
381
+ # Convert RGB because many mattors expect RGB
382
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
383
+ # Move frame to GPU only if your matting backend supports it
384
+ refined = None
385
+ if hasattr(matany, "refine_mask"):
386
+ refined = matany.refine_mask(frame_rgb, mask_t) # allow handler to decide device
387
+ elif hasattr(matany, "process_frame"):
388
+ refined = matany.process_frame(frame_rgb, mask_t)
389
+ if refined is not None:
390
+ # ensure float mask 0..1 on CUDA or CPU
391
+ if isinstance(refined, torch.Tensor):
392
+ mask_t = refined.float()
393
+ else:
394
+ # numpy β†’ torch
395
+ mask_t = torch.from_numpy(refined.astype(np.float32))
396
+ if device.type == "cuda":
397
+ mask_t = mask_t.to(device)
398
+ except Exception as e:
399
+ logger.debug(f"MatAnyone refinement failed (frame {out_frame_idx}): {e}")
400
+
401
+ # Compose and write (convert once, keep math sane)
402
+ if mask_t is not None:
403
+ # bring mask to CPU for np composition; keep as float [0,1]
404
+ mask_np = mask_t.detach().clamp(0, 1).to("cpu", non_blocking=True).float().numpy()
405
+ m3 = mask_np[..., None] # HxWx1
406
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
407
+ comp = frame_rgb * m3 + bg_np * (1.0 - m3)
408
+ comp_bgr = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_RGB2BGR)
409
+ vio.write_frame(comp_bgr)
410
  else:
411
+ # No mask β€” write original frame
412
+ vio.write_frame(frame_bgr)
413
+
414
+ # Periodic maintenance
415
+ frames_done += 1
416
+ if frames_done >= next_cleanup_at:
417
+ ctrl.adapt()
418
+ prune_sam2_state(predictor, state, keep=ctrl.memory_window)
419
+ # Clear small aux buffers
420
+ aux_window.clear()
421
+ if device.type == "cuda":
422
+ torch.cuda.ipc_collect()
423
+ torch.cuda.empty_cache()
424
+ next_cleanup_at = frames_done + ctrl.cleanup_every
425
+
426
+ # Progress
427
+ if frames_done % 25 == 0 and frame_count > 0:
428
+ p = 0.12 + 0.75 * (frames_done / frame_count)
429
+ report(f"Processing frame {frames_done}/{frame_count} | win={ctrl.memory_window} scale={ctrl.propagation_scale:.2f}", p)
430
+
431
+ # Audio mux
432
+ report("Restoring audio", 0.93)
433
+ ok = mux_audio(out_raw, str(src), out_final)
434
+ final_path = out_final if ok else out_raw
435
+
436
+ # Cleanup models/state promptly
437
+ try:
438
+ del predictor
439
+ del state
440
+ if matany is not None:
441
+ del matany
442
+ except Exception:
443
+ pass
444
+
445
+ if device.type == "cuda":
446
+ torch.cuda.ipc_collect()
447
+ torch.cuda.empty_cache()
448
+ gc.collect()
449
+
450
+ hb_flag["running"] = False
451
+ elapsed = time.time() - start
452
+ try:
453
+ peak = torch.cuda.max_memory_allocated() / (1024 ** 3) if device.type == "cuda" else 0.0
454
+ logger.info(f"Peak GPU memory: {peak:.2f} GB")
455
+ except Exception:
456
+ pass
457
+ report(f"Done in {elapsed:.1f}s", 1.0)
458
+ logger.info(f"Output: {final_path}")
459
+ logger.info(f"Artifacts: {job_directory}")
460
+ return final_path
461
+
462
+
463
# -------------------------------------------------------------------------------------------------
# CLI entry (optional)
# -------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="BackgroundFX Pro pipeline")
    cli.add_argument("--video", required=True, help="Path to input video")
    cli.add_argument("--background", required=True, help="Path to background image")
    cli.add_argument("--outdir", default=None, help="Job directory (optional)")
    ns = cli.parse_args()

    # Load the replacement background and hand everything to the pipeline.
    background = Image.open(ns.background).convert("RGB")
    job_dir = Path(ns.outdir) if ns.outdir else None
    result = process(ns.video, background_image=background, job_directory=job_dir)
    print(result)