MogensR commited on
Commit
4a91cab
Β·
1 Parent(s): b21ba08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +544 -767
app.py CHANGED
@@ -1,11 +1,10 @@
1
  #!/usr/bin/env python3
2
  """
3
- BackgroundFX - Professional Video Background Replacement
4
- Combined Pipeline: SAM2 (segmentation) + MatAnyone (matting refinement)
5
- Optimized for HuggingFace Spaces T4 GPU (16GB VRAM)
6
  """
7
 
8
- import streamlit as st
9
  import cv2
10
  import numpy as np
11
  import tempfile
@@ -18,882 +17,660 @@
18
  import torch
19
  import time
20
  from pathlib import Path
21
- from dataclasses import dataclass
22
- from typing import Optional, Dict, Tuple
23
 
24
  # Configure logging
25
- logging.basicConfig(level=logging.INFO)
26
  logger = logging.getLogger(__name__)
27
 
28
- # ============================================
29
- # GPU SETUP AND INITIALIZATION
30
- # ============================================
 
31
 
32
- def setup_gpu_environment():
33
- """Setup GPU environment optimized for T4"""
34
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
35
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- try:
38
- if torch.cuda.is_available():
39
- gpu_name = torch.cuda.get_device_name(0)
40
- gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
41
-
42
- logger.info(f"πŸš€ GPU Detected: {gpu_name} ({gpu_memory:.1f}GB)")
43
-
44
- # Initialize CUDA
45
- torch.cuda.init()
46
- torch.cuda.set_device(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- # T4 optimizations
49
- torch.backends.cudnn.benchmark = True
50
- torch.backends.cudnn.deterministic = False
51
 
52
- # T4 doesn't support TF32
53
- if 'T4' in gpu_name:
54
- torch.backends.cuda.matmul.allow_tf32 = False
55
- torch.backends.cudnn.allow_tf32 = False
56
- else:
57
- torch.backends.cuda.matmul.allow_tf32 = True
58
- torch.backends.cudnn.allow_tf32 = True
 
59
 
60
- # Warm up
61
- dummy = torch.randn(256, 256, device='cuda')
62
- del dummy
63
- torch.cuda.empty_cache()
64
 
65
- return True, gpu_name, gpu_memory
66
- else:
67
- logger.warning("⚠️ CUDA not available - running in CPU mode")
68
- return False, None, 0
69
- except Exception as e:
70
- logger.error(f"GPU setup failed: {e}")
71
- return False, None, 0
72
-
73
- # Initialize GPU
74
- CUDA_AVAILABLE, GPU_NAME, GPU_MEMORY = setup_gpu_environment()
75
- DEVICE = 'cuda' if CUDA_AVAILABLE else 'cpu'
76
-
77
- # ============================================
78
- # DATA STRUCTURES
79
- # ============================================
80
-
81
- @dataclass
82
- class ProcessingResult:
83
- """Container for processing results"""
84
- alpha: np.ndarray # Final alpha matte
85
- sam2_mask: Optional[np.ndarray] = None # SAM2 coarse mask
86
- trimap: Optional[np.ndarray] = None # Generated trimap
87
- method: str = "unknown"
88
- processing_time: float = 0.0
89
-
90
- # ============================================
91
- # COMBINED SAM2 + MATANYONE PROCESSOR
92
- # ============================================
93
-
94
- class CombinedProcessor:
95
- """
96
- Combines SAM2 and MatAnyone for ultimate quality
97
- SAM2: Initial segmentation (find the person)
98
- MatAnyone: Alpha matting refinement (perfect edges)
99
- """
100
 
101
- def __init__(self):
102
- self.sam2_predictor = None
103
- self.matanyone_model = None
104
- self.sam2_loaded = False
105
- self.matanyone_loaded = False
106
- self.device = DEVICE
107
-
108
- # Temporal consistency
109
- self.previous_result = None
110
- self.frame_count = 0
111
-
112
- @st.cache_resource
113
- def load_sam2(_self):
114
- """Load SAM2 model for segmentation"""
115
  try:
116
- from sam2.build_sam import build_sam2
117
- from sam2.sam2_image_predictor import SAM2ImagePredictor
118
-
119
- # Model selection based on available VRAM
120
- if GPU_MEMORY >= 15:
121
- model_config = {
122
- 'name': 'base_plus',
123
- 'config': 'sam2_hiera_b+.yaml',
124
- 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt',
125
- 'size': 323
126
- }
127
- elif GPU_MEMORY >= 8:
128
- model_config = {
129
- 'name': 'small',
130
- 'config': 'sam2_hiera_s.yaml',
131
- 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt',
132
- 'size': 155
133
- }
134
- else:
135
- model_config = {
136
- 'name': 'tiny',
137
- 'config': 'sam2_hiera_t.yaml',
138
- 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt',
139
- 'size': 77
140
- }
141
-
142
- # Download model if needed
143
- cache_dir = Path("/tmp/sam2_models")
144
- cache_dir.mkdir(exist_ok=True)
145
- model_path = cache_dir / f"sam2_{model_config['name']}.pt"
146
-
147
- if not model_path.exists():
148
- with st.spinner(f"Downloading SAM2 {model_config['name']} ({model_config['size']}MB)..."):
149
- response = requests.get(model_config['url'], stream=True)
150
- total_size = int(response.headers.get('content-length', 0))
151
-
152
- progress_bar = st.progress(0)
153
- with open(model_path, 'wb') as f:
154
- downloaded = 0
155
- for chunk in response.iter_content(chunk_size=8192):
156
- f.write(chunk)
157
- downloaded += len(chunk)
158
- if total_size > 0:
159
- progress_bar.progress(downloaded / total_size)
160
- progress_bar.empty()
161
 
162
  # Build model
163
- sam2_model = build_sam2(
164
- config_file=model_config['config'],
165
- ckpt_path=str(model_path),
166
- device=_self.device
167
- )
168
 
169
- # Use half precision on T4
170
- if CUDA_AVAILABLE and 'T4' in GPU_NAME:
171
  sam2_model = sam2_model.half()
 
172
 
173
- predictor = SAM2ImagePredictor(sam2_model)
 
174
 
175
- logger.info(f"βœ… SAM2 {model_config['name']} loaded successfully")
176
- return predictor, True
 
 
 
177
 
178
  except Exception as e:
179
- logger.error(f"❌ SAM2 loading failed: {e}")
180
- return None, False
 
 
 
 
 
 
 
 
181
 
182
- @st.cache_resource
183
- def load_matanyone(_self):
184
- """Load MatAnyone model for edge refinement"""
 
185
  try:
186
- from matanyone import MatAnyoneModel, MatAnyonePredictor
187
-
188
- # Download model if needed
189
- cache_dir = Path("/tmp/matanyone_models")
190
- cache_dir.mkdir(exist_ok=True)
191
- model_path = cache_dir / "matanyone_video.pth"
192
-
193
- if not model_path.exists():
194
- model_url = "https://huggingface.co/matanyone/matanyone-video/resolve/main/model.pth"
195
-
196
- with st.spinner("Downloading MatAnyone model..."):
197
- response = requests.get(model_url, stream=True)
198
- total_size = int(response.headers.get('content-length', 0))
199
-
200
- progress_bar = st.progress(0)
201
- with open(model_path, 'wb') as f:
202
- downloaded = 0
203
- for chunk in response.iter_content(chunk_size=8192):
204
- f.write(chunk)
205
- downloaded += len(chunk)
206
- if total_size > 0:
207
- progress_bar.progress(downloaded / total_size)
208
- progress_bar.empty()
209
-
210
- # Load model
211
- model = MatAnyoneModel.from_pretrained(
212
- str(model_path),
213
- device=_self.device,
214
- fp16=(CUDA_AVAILABLE) # Use FP16 on GPU
215
  )
216
 
217
- # Create predictor
218
- predictor = MatAnyonePredictor(
219
- model,
220
- enable_temporal=True,
221
- enable_refinement=True,
222
- alpha_quality='high'
223
- )
 
 
224
 
225
- logger.info("βœ… MatAnyone loaded successfully")
226
- return predictor, True
227
 
228
  except Exception as e:
229
- logger.warning(f"⚠️ MatAnyone not available: {e}")
230
- return None, False
 
 
 
 
 
 
 
 
 
231
 
232
- def initialize(self):
233
- """Initialize both models"""
234
- if not self.sam2_loaded:
235
- self.sam2_predictor, self.sam2_loaded = self.load_sam2()
236
-
237
- if not self.matanyone_loaded:
238
- self.matanyone_model, self.matanyone_loaded = self.load_matanyone()
239
-
240
- return self.sam2_loaded # At minimum need SAM2
241
 
242
- def process_frame(self, frame: np.ndarray, use_temporal: bool = True) -> ProcessingResult:
243
- """
244
- Process single frame using SAM2 + MatAnyone combined
245
-
246
- Pipeline:
247
- 1. SAM2 generates initial segmentation
248
- 2. Create trimap from SAM2 mask
249
- 3. MatAnyone refines using trimap
250
- 4. Return high-quality alpha matte
251
- """
252
-
253
- start_time = time.time()
254
-
255
- if not self.initialize():
256
- return None
257
-
258
- h, w = frame.shape[:2]
259
-
260
- # ============================================
261
- # STEP 1: SAM2 SEGMENTATION
262
- # ============================================
263
-
264
- # Set image for SAM2
265
- self.sam2_predictor.set_image(frame)
266
-
267
- # Generate point prompts with temporal consistency
268
- if use_temporal and self.previous_result and self.previous_result.sam2_mask is not None:
269
- # Use previous mask center
270
- prev_mask = self.previous_result.sam2_mask
271
- y_coords, x_coords = np.where(prev_mask > 0.5)
272
-
273
- if len(y_coords) > 0:
274
- center_y = int(np.mean(y_coords))
275
- center_x = int(np.mean(x_coords))
276
- # Focused points around previous center
277
- point_coords = np.array([
278
- [center_x, center_y],
279
- [center_x - w//40, center_y],
280
- [center_x + w//40, center_y],
281
- [center_x, center_y - h//40],
282
- [center_x, center_y + h//40]
283
- ])
284
- else:
285
- point_coords = self._get_default_points(w, h)
286
- else:
287
- point_coords = self._get_default_points(w, h)
288
-
289
- point_labels = np.ones(len(point_coords))
290
-
291
- # Get SAM2 predictions
292
- masks, scores, logits = self.sam2_predictor.predict(
293
- point_coords=point_coords,
294
- point_labels=point_labels,
295
- multimask_output=True,
296
- return_logits=True
297
- )
298
-
299
- # Select best mask
300
- best_idx = np.argmax(scores)
301
- sam2_mask = masks[best_idx].astype(np.float32)
302
-
303
- # Apply temporal smoothing to SAM2 mask
304
- if use_temporal and self.previous_result and self.previous_result.sam2_mask is not None:
305
- sam2_mask = 0.7 * sam2_mask + 0.3 * self.previous_result.sam2_mask
306
- sam2_mask = np.clip(sam2_mask, 0, 1)
307
-
308
- # ============================================
309
- # STEP 2: CREATE TRIMAP FROM SAM2 MASK
310
- # ============================================
311
-
312
- trimap = self._create_trimap_from_mask(sam2_mask)
313
-
314
- # ============================================
315
- # STEP 3: MATANYONE REFINEMENT (if available)
316
- # ============================================
317
-
318
- if self.matanyone_loaded and self.matanyone_model:
319
- try:
320
- # Use MatAnyone for refinement
321
- refined_alpha = self.matanyone_model.predict(
322
- image=frame,
323
- trimap=trimap,
324
- previous_alpha=self.previous_result.alpha if use_temporal and self.previous_result else None,
325
- temporal_weight=0.3 if use_temporal else 0.0
326
- )
327
-
328
- # Additional refinement with guided filter
329
- refined_alpha = cv2.ximgproc.guidedFilter(
330
- guide=frame,
331
- src=refined_alpha,
332
- radius=3,
333
- eps=1e-4
334
- )
335
-
336
- method = "SAM2+MatAnyone"
337
-
338
- except Exception as e:
339
- logger.warning(f"MatAnyone refinement failed, using SAM2 only: {e}")
340
- refined_alpha = sam2_mask
341
- method = "SAM2"
342
- else:
343
- # Use SAM2 mask with basic refinement
344
- refined_alpha = sam2_mask
345
-
346
- # Basic morphological refinement
347
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
348
- refined_alpha = cv2.morphologyEx(refined_alpha, cv2.MORPH_CLOSE, kernel)
349
- refined_alpha = cv2.GaussianBlur(refined_alpha, (5, 5), 0)
350
-
351
- method = "SAM2"
352
-
353
- # ============================================
354
- # STEP 4: FINAL POST-PROCESSING
355
- # ============================================
356
-
357
- # Ensure valid range
358
- refined_alpha = np.clip(refined_alpha, 0, 1)
359
-
360
- # Create result
361
- result = ProcessingResult(
362
- alpha=refined_alpha,
363
- sam2_mask=sam2_mask,
364
- trimap=trimap,
365
- method=method,
366
- processing_time=time.time() - start_time
367
- )
368
-
369
- # Store for temporal consistency
370
- self.previous_result = result
371
- self.frame_count += 1
372
 
373
- return result
374
-
375
- def _get_default_points(self, w: int, h: int) -> np.ndarray:
376
- """Get default point prompts for initial detection"""
377
- return np.array([
378
- [w//2, h//2], # Center
379
- [w//2, h//3], # Head area
380
- [w//2, 2*h//3], # Body area
381
- [w//3, h//2], # Left
382
- [2*w//3, h//2], # Right
383
- [w//2, h//4], # Upper
384
- [w//2, 3*h//4] # Lower
385
- ])
386
-
387
- def _create_trimap_from_mask(self, mask: np.ndarray, unknown_width: int = 20) -> np.ndarray:
388
- """
389
- Convert SAM2 mask to trimap for MatAnyone
390
- 0: Background, 128: Unknown, 255: Foreground
391
- """
392
- trimap = np.zeros_like(mask, dtype=np.uint8)
393
 
394
- # Threshold mask
395
- binary_mask = (mask > 0.5).astype(np.uint8)
396
 
397
- # Erode for definite foreground
398
- kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
399
- foreground = cv2.erode(binary_mask, kernel_small, iterations=2)
400
 
401
- # Dilate for potential foreground
402
- kernel_large = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (unknown_width, unknown_width))
403
- potential_fg = cv2.dilate(binary_mask, kernel_large, iterations=2)
404
 
405
- # Create trimap
406
- trimap[potential_fg == 0] = 0 # Background
407
- trimap[foreground == 1] = 255 # Foreground
408
- trimap[(potential_fg == 1) & (foreground == 0)] = 128 # Unknown
409
 
410
- return trimap
411
-
412
- def reset(self):
413
- """Reset temporal state for new video"""
414
- self.previous_result = None
415
- self.frame_count = 0
416
- logger.info("Processor reset for new video")
417
-
418
- # ============================================
419
- # FALLBACK: REMBG PROCESSOR
420
- # ============================================
421
-
422
- REMBG_AVAILABLE = False
423
- rembg_session = None
424
-
425
- try:
426
- from rembg import remove, new_session
427
-
428
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if CUDA_AVAILABLE else ['CPUExecutionProvider']
429
- rembg_session = new_session('u2net_human_seg', providers=providers)
430
-
431
- # Warm up
432
- dummy_img = Image.new('RGB', (128, 128), color='white')
433
- _ = remove(dummy_img, session=rembg_session)
434
-
435
- REMBG_AVAILABLE = True
436
- logger.info("βœ… Rembg initialized as fallback")
437
-
438
- except Exception as e:
439
- logger.warning(f"⚠️ Rembg not available: {e}")
440
-
441
- def segment_with_rembg(frame):
442
- """Fallback segmentation using Rembg"""
443
- if not REMBG_AVAILABLE:
444
- return None
445
 
446
- try:
447
- pil_image = Image.fromarray(frame)
448
- output = remove(pil_image, session=rembg_session)
449
 
450
- output_array = np.array(output)
451
- if output_array.shape[2] == 4:
452
- return output_array[:, :, 3].astype(np.float32) / 255.0
453
- return None
454
  except Exception as e:
455
- logger.error(f"Rembg failed: {e}")
456
- return None
457
-
458
- # ============================================
459
- # BACKGROUND UTILITIES
460
- # ============================================
461
 
 
462
  def create_gradient_background(width=1280, height=720, color1=(70, 130, 180), color2=(255, 140, 90)):
463
- """Create gradient background"""
464
  background = np.zeros((height, width, 3), dtype=np.uint8)
465
-
466
  for y in range(height):
467
  ratio = y / height
 
468
  r = int(color1[0] * (1 - ratio) + color2[0] * ratio)
469
- g = int(color1[1] * (1 - ratio) + color2[1] * ratio)
470
  b = int(color1[2] * (1 - ratio) + color2[2] * ratio)
471
  background[y, :] = [r, g, b]
472
-
473
  return background
474
 
475
- def load_background_image(background_option):
476
- """Load or create background based on option"""
477
- if background_option.startswith("gradient:"):
478
- gradient_type = background_option.split(":")[1]
479
- if gradient_type == "blue":
480
- return create_gradient_background(color1=(70, 130, 180), color2=(135, 206, 235))
481
- elif gradient_type == "sunset":
482
- return create_gradient_background(color1=(255, 94, 77), color2=(255, 154, 0))
483
- else: # ocean
484
- return create_gradient_background(color1=(0, 119, 190), color2=(0, 180, 216))
485
- elif background_option.startswith("color:"):
486
- color_name = background_option.split(":")[1]
487
- colors = {"green": [0, 255, 0], "blue": [0, 0, 255], "white": [255, 255, 255]}
488
- background = np.full((720, 1280, 3), colors.get(color_name, [255, 255, 255]), dtype=np.uint8)
489
- return background
490
- else:
491
- try:
492
- response = requests.get(background_option, timeout=10)
493
- response.raise_for_status()
494
- image = Image.open(BytesIO(response.content))
495
- return np.array(image.convert('RGB'))
496
- except:
497
- return create_gradient_background()
498
-
499
- def get_background_options():
500
- """Background options for quick selection"""
501
  return {
502
- "πŸŒ… Blue Gradient": "gradient:blue",
503
- "πŸŒ‡ Sunset Gradient": "gradient:sunset",
504
- "🌊 Ocean Gradient": "gradient:ocean",
505
- "πŸ’š Green Screen": "color:green",
506
- "πŸ’™ Blue Screen": "color:blue",
507
- "βšͺ White Background": "color:white",
508
- "🏒 Office": "https://images.unsplash.com/photo-1497366216548-37526070297c?w=1280&h=720&fit=crop",
509
- "πŸŒ† City": "https://images.unsplash.com/photo-1449824913935-59a10b8d2000?w=1280&h=720&fit=crop",
510
- "πŸ–οΈ Beach": "https://images.unsplash.com/photo-1507525428034-b723cf961d3e?w=1280&h=720&fit=crop",
511
- "🌲 Nature": "https://images.unsplash.com/photo-1441974231531-c6227db76b6e?w=1280&h=720&fit=crop"
512
  }
513
 
514
- # ============================================
515
- # VIDEO PROCESSING PIPELINE
516
- # ============================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
 
518
- # Initialize processor globally
519
- processor = CombinedProcessor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
 
521
- def process_video(video_path, background_option, speed_mode='balanced', progress_callback=None):
522
- """
523
- Process video with SAM2 + MatAnyone combined pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
525
- Args:
526
- video_path: Input video path
527
- background_option: Background type/URL
528
- speed_mode: 'ultra_fast', 'fast', 'balanced', 'quality'
529
- progress_callback: Progress update function
530
- """
531
  try:
532
- # Load background
533
- background_image = load_background_image(background_option)
 
534
 
535
- # Open video
536
- cap = cv2.VideoCapture(video_path)
537
  fps = int(cap.get(cv2.CAP_PROP_FPS))
538
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
539
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
540
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
541
 
542
- logger.info(f"Processing video: {width}x{height}, {total_frames} frames, {fps} FPS")
543
-
544
- # Determine frame skip based on speed mode
545
- if speed_mode == 'ultra_fast':
546
- frame_skip = 3 # Process every 3rd frame
547
- interpolate = True
548
- elif speed_mode == 'fast':
549
- frame_skip = 2 # Process every 2nd frame
550
- interpolate = True
551
- elif speed_mode == 'balanced':
552
- frame_skip = 1 # Process all frames
553
- interpolate = False
554
- else: # quality
555
- frame_skip = 1
556
- interpolate = False
557
-
558
- # Create output
559
  output_path = tempfile.mktemp(suffix='.mp4')
560
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
561
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
562
 
563
- # Resize background once
564
- background_resized = cv2.resize(background_image, (width, height))
565
-
566
- # Reset processor for new video
567
- processor.reset()
568
 
 
569
  frame_count = 0
570
- processed_count = 0
571
- processing_times = []
572
- last_alpha = None
573
 
 
 
 
 
 
 
 
574
  while True:
575
  ret, frame = cap.read()
576
  if not ret:
577
  break
578
 
579
- # Convert BGR to RGB
580
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
581
 
582
- # Process frame or use interpolation
583
- if frame_count % frame_skip == 0:
584
- start_time = time.time()
585
-
586
- # Process with combined pipeline
587
- result = processor.process_frame(frame_rgb, use_temporal=(processed_count > 0))
588
-
589
- if result:
590
- alpha = result.alpha
591
- last_alpha = alpha
592
- method_used = result.method
593
- processing_times.append(result.processing_time)
594
  else:
595
- # Fallback to rembg
596
- alpha = segment_with_rembg(frame_rgb)
597
- if alpha is not None:
598
- last_alpha = alpha
599
- method_used = "Rembg"
600
- else:
601
- alpha = last_alpha if last_alpha is not None else np.ones((height, width), dtype=np.float32)
602
- method_used = "Previous/Fallback"
603
-
604
- processed_count += 1
605
-
 
 
 
 
 
606
  else:
607
- # Use last alpha for skipped frames
608
- alpha = last_alpha if last_alpha is not None else np.ones((height, width), dtype=np.float32)
609
- method_used = "Interpolated"
610
 
611
- # Apply alpha and composite
612
- if alpha.ndim == 2:
613
- alpha = np.expand_dims(alpha, axis=2)
614
 
615
- # High-quality compositing
616
  foreground = frame_rgb.astype(np.float32)
617
- background = background_resized.astype(np.float32)
618
 
 
619
  composite = foreground * alpha + background * (1 - alpha)
620
  composite = np.clip(composite, 0, 255).astype(np.uint8)
621
 
622
- # Convert back to BGR
623
  composite_bgr = cv2.cvtColor(composite, cv2.COLOR_RGB2BGR)
624
  out.write(composite_bgr)
625
 
626
  frame_count += 1
627
 
628
- # Progress update
629
- if progress_callback:
630
- progress = frame_count / total_frames
631
- if processing_times:
632
- avg_time = np.mean(processing_times[-10:])
633
- eta = avg_time * ((total_frames - frame_count) / frame_skip)
634
- else:
635
- eta = 0
636
- progress_callback(
637
- progress,
638
- f"{method_used} | Frame {frame_count}/{total_frames} | ETA: {eta:.1f}s"
639
- )
640
-
641
- # Memory cleanup
642
  if frame_count % 30 == 0 and CUDA_AVAILABLE:
643
  torch.cuda.empty_cache()
644
 
645
- # Release resources
 
 
646
  cap.release()
647
  out.release()
648
 
 
 
 
649
  if CUDA_AVAILABLE:
650
  torch.cuda.empty_cache()
651
  gc.collect()
652
 
653
- # Log statistics
654
- if processing_times:
655
- logger.info(f"βœ… Processing complete: {output_path}")
656
- logger.info(f"Average processing time: {np.mean(processing_times):.3f}s per frame")
657
- logger.info(f"Total processed frames: {processed_count}/{total_frames}")
658
 
659
- return output_path
 
 
660
 
661
  except Exception as e:
662
- logger.error(f"Video processing failed: {e}")
663
- return None
664
-
665
- # ============================================
666
- # STREAMLIT UI
667
- # ============================================
 
 
 
 
 
 
 
 
 
 
668
 
669
- def main():
670
- st.set_page_config(
671
- page_title="BackgroundFX - Lightning Fast",
672
- page_icon="πŸš€",
673
- layout="wide",
674
- initial_sidebar_state="expanded"
675
- )
676
-
677
- # Header
678
- st.title("πŸš€ BackgroundFX - Lightning-Fast Video Background Replacement")
679
- st.markdown("**Professional quality in seconds, not minutes! Powered by SAM2 + MatAnyone**")
680
-
681
- # System Status
682
- col1, col2, col3, col4 = st.columns(4)
683
-
684
- with col1:
685
- if CUDA_AVAILABLE:
686
- st.success(f"πŸš€ GPU: {GPU_NAME}")
687
- st.caption(f"VRAM: {GPU_MEMORY:.1f}GB")
688
- else:
689
- st.warning("πŸ’» CPU Mode")
690
-
691
- with col2:
692
- methods = []
693
- if processor.sam2_loaded:
694
- methods.append("SAM2")
695
- if processor.matanyone_loaded:
696
- methods.append("MatAnyone")
697
- if REMBG_AVAILABLE:
698
- methods.append("Rembg")
699
-
700
- if methods:
701
- st.info(f"βœ… Ready: {', '.join(methods)}")
702
- else:
703
- st.warning("⏳ Loading models...")
704
-
705
- with col3:
706
- if CUDA_AVAILABLE:
707
- allocated = torch.cuda.memory_allocated() / 1024**3
708
- st.metric("GPU Usage", f"{allocated:.1f}GB")
709
- else:
710
- st.metric("Mode", "CPU")
711
 
712
- with col4:
713
- # Speed indicator
714
- st.metric("Status", "Ready" if processor.sam2_loaded else "Loading")
 
715
 
716
- # Sidebar
717
- with st.sidebar:
718
- st.markdown("### ⚑ Speed Settings")
719
-
720
- # Speed mode selection
721
- speed_mode = st.select_slider(
722
- "Processing Speed",
723
- options=['ultra_fast', 'fast', 'balanced', 'quality'],
724
- value='balanced',
725
- format_func=lambda x: {
726
- 'ultra_fast': '⚑⚑⚑ Ultra Fast (3x)',
727
- 'fast': '⚑⚑ Fast (2x)',
728
- 'balanced': '⚑ Balanced',
729
- 'quality': '🎨 Quality'
730
- }[x]
731
- )
732
-
733
- # Speed mode info
734
- speed_info = {
735
- 'ultra_fast': "Process every 3rd frame\n~5 sec for 10 sec video",
736
- 'fast': "Process every 2nd frame\n~10 sec for 10 sec video",
737
- 'balanced': "Process all frames\n~15 sec for 10 sec video",
738
- 'quality': "Full processing\n~20 sec for 10 sec video"
739
  }
740
- st.info(speed_info[speed_mode])
741
-
742
- st.markdown("---")
743
-
744
- # Processing info
745
- st.markdown("### 🎯 Pipeline")
746
-
747
- if processor.sam2_loaded and processor.matanyone_loaded:
748
- st.success("SAM2 + MatAnyone Combined")
749
- st.caption("Best quality mode active")
750
- elif processor.sam2_loaded:
751
- st.info("SAM2 Only")
752
- st.caption("Good quality, fast processing")
753
- else:
754
- st.warning("Initializing...")
755
-
756
- st.markdown("---")
757
-
758
- # System info
759
- st.markdown("### πŸ“Š System")
760
 
761
- if CUDA_AVAILABLE:
762
- allocated = torch.cuda.memory_allocated() / 1024**3
763
- reserved = torch.cuda.memory_reserved() / 1024**3
764
-
765
- st.metric("Memory", f"{allocated:.1f}/{GPU_MEMORY:.0f} GB")
766
-
767
- usage_percent = (allocated / GPU_MEMORY) * 100 if GPU_MEMORY else 0
768
- st.progress(min(usage_percent / 100, 1.0))
769
-
770
- # GPU details
771
- with st.expander("GPU Details"):
772
- st.code(f"""
773
- Device: {GPU_NAME}
774
- VRAM: {GPU_MEMORY:.1f} GB
775
- Used: {allocated:.2f} GB
776
- Reserved: {reserved:.2f} GB
777
- PyTorch: {torch.__version__}
778
- CUDA: {torch.version.cuda if CUDA_AVAILABLE else 'N/A'}
779
- """)
780
-
781
- # Main content
782
- col1, col2 = st.columns(2)
783
-
784
- with col1:
785
- st.markdown("### πŸ“Ή Video Input")
786
 
787
- uploaded_video = st.file_uploader(
788
- "Upload your video",
789
- type=['mp4', 'avi', 'mov', 'mkv'],
790
- help="Recommended: 10-30 seconds for best performance"
791
- )
792
 
793
- if uploaded_video:
794
- # Save video
795
- with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
796
- tmp_file.write(uploaded_video.read())
797
- video_path = tmp_file.name
798
-
799
- st.video(uploaded_video)
800
-
801
- # Get video info
802
- cap = cv2.VideoCapture(video_path)
803
- fps = int(cap.get(cv2.CAP_PROP_FPS))
804
- frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
805
- duration = frames / fps if fps > 0 else 0
806
- cap.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807
 
808
- st.success(f"βœ… Ready: {duration:.1f}s @ {fps} FPS")
809
- else:
810
- video_path = None
811
-
812
- with col2:
813
- st.markdown("### 🎨 Background")
814
-
815
- # Quick background selection
816
- backgrounds = get_background_options()
817
- selected_bg = st.selectbox(
818
- "Choose background",
819
- options=list(backgrounds.keys()),
820
- index=0
821
- )
822
-
823
- background_option = backgrounds[selected_bg]
 
 
 
 
 
 
 
824
 
825
- # Preview
826
- if background_option:
827
- preview_bg = load_background_image(background_option)
828
- preview_bg_resized = cv2.resize(preview_bg, (640, 360))
829
- st.image(preview_bg_resized, caption=selected_bg, use_container_width=True)
830
-
831
- # Process button
832
- if video_path and st.button("πŸš€ Process Video", type="primary", use_container_width=True):
833
-
834
- # Progress tracking
835
- progress_bar = st.progress(0)
836
- status_text = st.empty()
837
- time_text = st.empty()
838
-
839
- def update_progress(progress, message):
840
- progress_bar.progress(progress)
841
- status_text.text(message)
842
- elapsed = time.time() - start_time
843
- time_text.text(f"⏱️ Elapsed: {elapsed:.1f}s")
844
-
845
- # Process video
846
- start_time = time.time()
847
-
848
- result_path = process_video(
849
- video_path,
850
- background_option,
851
- speed_mode=speed_mode,
852
- progress_callback=update_progress
853
  )
854
 
855
- processing_time = time.time() - start_time
856
-
857
- if result_path and os.path.exists(result_path):
858
- # Success
859
- status_text.text(f"βœ… Complete in {processing_time:.1f} seconds!")
860
- time_text.text(f"πŸš€ Speed: {frames/processing_time:.1f} FPS")
861
-
862
- # Load result
863
- with open(result_path, 'rb') as f:
864
- result_data = f.read()
865
-
866
- st.markdown("### 🎬 Result")
867
- st.video(result_data)
868
-
869
- # Download button
870
- col1, col2, col3 = st.columns([1, 2, 1])
871
- with col2:
872
- st.download_button(
873
- label="πŸ’Ύ Download Video",
874
- data=result_data,
875
- file_name=f"backgroundfx_{uploaded_video.name}",
876
- mime="video/mp4",
877
- use_container_width=True
878
- )
879
-
880
- # Stats
881
- st.success(f"""
882
- ✨ **Processing Complete!**
883
- - Time: {processing_time:.1f} seconds
884
- - Speed: {frames/processing_time:.1f} FPS
885
- - Method: {processor.previous_result.method if processor.previous_result else 'Unknown'}
886
- - Mode: {speed_mode.replace('_', ' ').title()}
887
  """)
888
-
889
- # Cleanup
890
- os.unlink(result_path)
891
- else:
892
- st.error("❌ Processing failed! Please try again.")
893
-
894
- # Cleanup temp
895
- if video_path and os.path.exists(video_path):
896
- os.unlink(video_path)
897
 
 
898
  if __name__ == "__main__":
899
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ BackgroundFX - Enhanced SAM2 Video Background Replacer for Hugging Face Spaces
4
+ Professional video background replacement with optimized lazy loading and memory management
 
5
  """
6
 
7
+ import gradio as gr
8
  import cv2
9
  import numpy as np
10
  import tempfile
 
17
  import torch
18
  import time
19
  from pathlib import Path
20
+ import hashlib
 
21
 
22
  # Configure logging
23
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
  logger = logging.getLogger(__name__)
25
 
26
+ # Constants
27
+ MAX_VIDEO_DURATION = 300 # 5 minutes max for free tier
28
+ MAX_FRAMES_BATCH = 100 # Process in batches to manage memory
29
+ SUPPORTED_VIDEO_FORMATS = ['.mp4', '.avi', '.mov', '.mkv', '.webm']
30
 
31
+ # GPU Setup and Detection
32
+ def setup_gpu():
33
+ """Setup GPU with detailed information and optimization"""
34
+ if torch.cuda.is_available():
35
+ gpu_name = torch.cuda.get_device_name(0)
36
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
37
+ torch.cuda.init()
38
+ torch.cuda.set_device(0)
39
+ torch.backends.cudnn.benchmark = True
40
+
41
+ # Optimize for common GPU types
42
+ gpu_optimizations = {
43
+ "T4": {"use_half": True, "batch_size": 1},
44
+ "V100": {"use_half": False, "batch_size": 2},
45
+ "A10": {"use_half": True, "batch_size": 2},
46
+ "A100": {"use_half": False, "batch_size": 4}
47
+ }
48
+
49
+ gpu_type = None
50
+ for gpu in gpu_optimizations:
51
+ if gpu in gpu_name:
52
+ gpu_type = gpu
53
+ break
54
+
55
+ return True, gpu_name, gpu_memory, gpu_type
56
+ return False, None, 0, None
57
+
58
+ CUDA_AVAILABLE, GPU_NAME, GPU_MEMORY, GPU_TYPE = setup_gpu()
59
+ DEVICE = 'cuda' if CUDA_AVAILABLE else 'cpu'
60
+
61
+ logger.info(f"Device: {DEVICE} | GPU: {GPU_NAME} | Memory: {GPU_MEMORY:.1f}GB | Type: {GPU_TYPE}")
62
+
63
# Enhanced SAM2 Lazy Loader with Caching
class SAM2EnhancedLazy:
    """Lazy loader for SAM2 segmentation models.

    Downloads the requested checkpoint on first use, caches it under the
    system temp directory, keeps at most one model resident, and releases
    GPU memory explicitly via :meth:`clear_model`.
    """

    def __init__(self):
        self.predictor = None            # live SAM2ImagePredictor, or None
        self.current_model_size = None   # size key ("tiny"/"small"/"base") while loaded
        self.model_cache_dir = Path(tempfile.gettempdir()) / "sam2_cache"
        self.model_cache_dir.mkdir(exist_ok=True)

        # Checkpoint catalogue: download URL, Hydra config name, rough size.
        self.models = {
            "tiny": {
                "url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
                "config": "sam2_hiera_t.yaml",
                "size_mb": 38,
                "description": "Fastest, lowest memory"
            },
            "small": {
                "url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
                "config": "sam2_hiera_s.yaml",
                "size_mb": 185,
                "description": "Balanced speed/quality"
            },
            "base": {
                "url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
                "config": "sam2_hiera_b+.yaml",
                "size_mb": 320,
                "description": "Best quality, slower"
            }
        }

    def get_model_path(self, model_size):
        """Return the cache path for a given model size."""
        model_name = f"sam2_{model_size}.pt"
        return self.model_cache_dir / model_name

    def clear_model(self):
        """Drop the current model and release CUDA memory."""
        if self.predictor:
            del self.predictor
        self.predictor = None
        self.current_model_size = None

        if CUDA_AVAILABLE:
            torch.cuda.empty_cache()
        gc.collect()
        logger.info("SAM2 model cleared from memory")

    def download_model(self, model_size, progress_fn=None):
        """Download a checkpoint with progress tracking; return its path.

        Removes any partially written file on failure so a retry starts
        clean, then re-raises the original error.
        """
        model_info = self.models[model_size]
        model_path = self.get_model_path(model_size)

        if model_path.exists():
            logger.info(f"Model {model_size} already cached")
            return model_path

        try:
            logger.info(f"Downloading SAM2 {model_size} model...")
            # timeout added: a stalled connection previously hung forever;
            # the context manager guarantees the connection is released.
            with requests.get(model_info['url'], stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()

                total_size = int(response.headers.get('content-length', 0))
                downloaded = 0

                with open(model_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                            downloaded += len(chunk)
                            if progress_fn and total_size > 0:
                                progress = downloaded / total_size * 0.4  # 40% of total progress
                                progress_fn(progress, f"Downloading SAM2 {model_size} ({downloaded/1024/1024:.1f}MB/{total_size/1024/1024:.1f}MB)")

            logger.info(f"SAM2 {model_size} downloaded successfully")
            return model_path

        except Exception as e:
            logger.error(f"Failed to download SAM2 {model_size}: {e}")
            if model_path.exists():
                model_path.unlink()
            raise

    def load_model(self, model_size, progress_fn=None):
        """Build a SAM2 predictor on DEVICE, downloading the checkpoint if needed."""
        try:
            # Lazy import so the app still starts when SAM2 is not installed.
            try:
                from sam2.build_sam import build_sam2
                from sam2.sam2_image_predictor import SAM2ImagePredictor
            except ImportError as e:
                logger.error("SAM2 not available. Install with: pip install segment-anything-2")
                raise ImportError("SAM2 package not found") from e

            model_path = self.download_model(model_size, progress_fn)

            if progress_fn:
                progress_fn(0.5, f"Loading SAM2 {model_size} model...")

            # Build model
            model_config = self.models[model_size]["config"]
            sam2_model = build_sam2(model_config, str(model_path), device=DEVICE)

            # Half precision halves VRAM on memory-constrained cards.
            if CUDA_AVAILABLE and GPU_TYPE in ["T4", "A10"]:
                sam2_model = sam2_model.half()
                logger.info(f"Applied half precision for {GPU_TYPE}")

            self.predictor = SAM2ImagePredictor(sam2_model)
            self.current_model_size = model_size

            if progress_fn:
                progress_fn(0.6, f"SAM2 {model_size} loaded successfully!")

            logger.info(f"SAM2 {model_size} model loaded and ready")
            return self.predictor

        except Exception as e:
            logger.error(f"Failed to load SAM2 {model_size}: {e}")
            self.clear_model()
            raise

    def get_predictor(self, model_size="tiny", progress_fn=None):
        """Return a predictor for model_size, (re)loading when the size differs."""
        if self.predictor is None or self.current_model_size != model_size:
            self.clear_model()
            return self.load_model(model_size, progress_fn)
        return self.predictor

    def segment_image(self, image, model_size="tiny", progress_fn=None):
        """Segment the main subject of an RGB image.

        Returns:
            tuple: ``(mask, score)`` — a float mask in [0, 1] of shape (H, W)
            and the model's confidence, or ``(None, 0.0)`` on failure.
        """
        predictor = self.get_predictor(model_size, progress_fn)

        try:
            predictor.set_image(image)
            h, w = image.shape[:2]

            # Seed points biased toward the frame centre, where the subject
            # of a typical talking-head video sits.
            center_points = [
                [w // 2, h // 2],      # Center
                [w // 2, h // 3],      # Upper center
                [w // 2, 2 * h // 3],  # Lower center
                [w // 3, h // 2],      # Left center
                [2 * w // 3, h // 2]   # Right center
            ]

            point_coords = np.array(center_points)
            # Explicit int labels (1 == foreground); np.ones defaults to float64.
            point_labels = np.ones(len(point_coords), dtype=np.int32)

            masks, scores, logits = predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=True
            )

            # Select best mask
            best_mask_idx = scores.argmax()
            best_mask = masks[best_mask_idx]
            best_score = scores[best_mask_idx]

            # Close small holes, then feather the edge slightly.
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            best_mask = cv2.morphologyEx(best_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
            best_mask = cv2.GaussianBlur(best_mask.astype(np.float32), (3, 3), 1.0)

            return best_mask, float(best_score)

        except Exception as e:
            logger.error(f"Segmentation failed: {e}")
            return None, 0.0
231
+
232
# Global SAM2 loader — single shared instance so at most one model is
# resident in (GPU) memory at a time.
sam2_loader = SAM2EnhancedLazy()
234
+
235
# Video Validation
def validate_video(video_path):
    """Validate an uploaded video before processing.

    Checks existence, extension, openability, sane fps/frame-count, duration
    (<= MAX_VIDEO_DURATION) and resolution (<= 1920x1080).

    Args:
        video_path: filesystem path to the uploaded video, or None.

    Returns:
        tuple: ``(is_valid, message)`` — message describes the problem or,
        on success, the video's duration/resolution/fps.
    """
    if not video_path or not os.path.exists(video_path):
        return False, "No video file provided"

    # Check file extension
    file_ext = Path(video_path).suffix.lower()
    if file_ext not in SUPPORTED_VIDEO_FORMATS:
        return False, f"Unsupported format. Supported: {', '.join(SUPPORTED_VIDEO_FORMATS)}"

    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return False, "Cannot open video file"

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        if fps <= 0 or frame_count <= 0:
            return False, "Invalid video properties"

        duration = frame_count / fps

        # Check duration
        if duration > MAX_VIDEO_DURATION:
            return False, f"Video too long ({duration:.1f}s). Max: {MAX_VIDEO_DURATION}s"

        # Check resolution
        if width * height > 1920 * 1080:
            return False, "Resolution too high (max 1920x1080)"

        return True, f"Valid video: {duration:.1f}s, {width}x{height}, {fps:.1f}fps"

    except Exception as e:
        return False, f"Video validation error: {str(e)}"
    finally:
        # Previously the capture leaked on every early return after a raise;
        # release it on all paths.
        if cap is not None:
            cap.release()
 
 
 
 
 
276
 
277
# Background Creation
def create_gradient_background(width=1280, height=720, color1=(70, 130, 180), color2=(255, 140, 90)):
    """Create a smooth vertical gradient background.

    Args:
        width: output width in pixels.
        height: output height in pixels.
        color1: RGB tuple at the top row.
        color2: RGB tuple approached at the bottom row.

    Returns:
        np.ndarray of shape (height, width, 3), dtype uint8.
    """
    # Vectorized replacement for the per-row Python loop (O(height) Python
    # iterations -> one NumPy expression). Blend factor is y / height and the
    # per-channel truncation matches int() for these non-negative values.
    ratio = (np.arange(height, dtype=np.float64) / height)[:, None]
    top = np.asarray(color1, dtype=np.float64)[None, :]
    bottom = np.asarray(color2, dtype=np.float64)[None, :]
    rows = (top * (1 - ratio) + bottom * ratio).astype(np.uint8)  # (height, 3)
    return np.repeat(rows[:, None, :], width, axis=1)
289
 
290
def get_background_presets():
    """Return the catalogue of built-in backgrounds.

    Keys are "<kind>:<name>" where kind is "gradient" or "color"; values are
    ``(display_name, top_rgb, bottom_rgb)`` — the RGB pair is None for flat
    colors, whose actual pixel values live in create_background_from_preset.
    """
    gradients = {
        "gradient:ocean": ("Ocean Blue", (20, 120, 180), (135, 206, 235)),
        "gradient:sunset": ("Sunset Orange", (255, 94, 77), (255, 154, 0)),
        "gradient:forest": ("Forest Green", (34, 139, 34), (144, 238, 144)),
        "gradient:purple": ("Purple Haze", (128, 0, 128), (221, 160, 221)),
    }
    flat_colors = {
        "color:white": ("Pure White", None, None),
        "color:black": ("Pure Black", None, None),
        "color:green": ("Chroma Green", None, None),
        "color:blue": ("Chroma Blue", None, None),
    }
    return {**gradients, **flat_colors}
302
 
303
def create_background_from_preset(preset, width, height):
    """Build a (height, width, 3) uint8 RGB background for a preset key.

    Unknown keys (including "custom") fall back to the default gradient.

    Args:
        preset: key from get_background_presets(), e.g. "gradient:ocean".
        width: output width in pixels.
        height: output height in pixels.

    Returns:
        np.ndarray of shape (height, width, 3), dtype uint8.
    """
    presets = get_background_presets()

    if preset not in presets:
        return create_gradient_background(width, height)

    _name, color1, color2 = presets[preset]

    if preset.startswith("gradient:"):
        return create_gradient_background(width, height, color1, color2)
    if preset.startswith("color:"):
        color_map = {
            "white": [255, 255, 255],
            "black": [0, 0, 0],
            "green": [0, 255, 0],
            "blue": [0, 0, 255]
        }
        color_name = preset.split(":")[1]
        color = color_map.get(color_name, [255, 255, 255])
        return np.full((height, width, 3), color, dtype=np.uint8)

    # Defensive fallback: a catalogue key with an unrecognised prefix
    # previously fell off the end and returned None, which would crash the
    # compositing code downstream.
    return create_gradient_background(width, height)
324
 
325
def load_background_image(background_img, background_preset, target_width, target_height):
    """Return an RGB background array sized (target_height, target_width, 3).

    An uploaded PIL image takes precedence over the preset; any failure falls
    back to the default gradient so processing never aborts on a bad
    background.
    """
    try:
        if background_img is None:
            # No upload: render the named preset.
            canvas = create_background_from_preset(background_preset, target_width, target_height)
        else:
            # Uploaded image wins over the preset.
            canvas = np.array(background_img.convert('RGB'))

        # Match the video frame dimensions exactly.
        if canvas.shape[:2] != (target_height, target_width):
            canvas = cv2.resize(canvas, (target_width, target_height))

        return canvas

    except Exception as e:
        logger.error(f"Background loading failed: {e}")
        return create_gradient_background(target_width, target_height)
344
 
345
# Enhanced Video Processing
def process_video_enhanced(input_video, background_img, background_preset, model_size, edge_smoothing, progress=gr.Progress()):
    """Replace the background of a video using SAM2 segmentation.

    Args:
        input_video: path to the uploaded video file.
        background_img: optional PIL image overriding the preset.
        background_preset: key from get_background_presets() (or "custom").
        model_size: SAM2 checkpoint size ("tiny" | "small" | "base").
        edge_smoothing: Gaussian sigma for softening the matte edge (0 = off).
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        tuple: ``(output_path, status_message)``; output_path is None on
        failure and the message then carries the error.
    """
    if input_video is None:
        return None, "❌ Please upload a video file"

    # Validate video
    progress(0.02, desc="Validating video...")
    is_valid, validation_msg = validate_video(input_video)
    if not is_valid:
        return None, f"❌ {validation_msg}"

    logger.info(f"Video validation: {validation_msg}")

    cap = None
    out = None
    output_path = None

    try:
        # Get video properties
        progress(0.05, desc="Reading video properties...")
        cap = cv2.VideoCapture(input_video)
        if not cap.isOpened():
            # validate_video opened it successfully a moment ago, but the
            # file may have been removed in between — fail loudly.
            raise Exception("Cannot open input video")

        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 else 0

        logger.info(f"Video: {width}x{height}, {fps}fps, {total_frames} frames, {duration:.1f}s")

        # Prepare background
        progress(0.08, desc="Preparing background...")
        background_image = load_background_image(background_img, background_preset, width, height)

        # Setup output video. mkstemp replaces the deprecated/racy
        # tempfile.mktemp; the descriptor is closed so VideoWriter owns the file.
        fd, output_path = tempfile.mkstemp(suffix='.mp4')
        os.close(fd)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        if not out.isOpened():
            raise Exception("Failed to create output video")

        # Processing variables
        frame_count = 0
        last_mask = None
        processing_start_time = time.time()

        # SAM2 download/load progress maps to the 10%-40% band.
        def sam2_progress(progress_val, message):
            overall_progress = 0.1 + (progress_val * 0.3)
            progress(overall_progress, desc=message)

        # Process frames
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Segment frame with SAM2
            mask, confidence = sam2_loader.segment_image(frame_rgb, model_size, sam2_progress)

            if mask is not None and confidence > 0.5:
                current_mask = mask
                last_mask = current_mask
            elif last_mask is not None:
                # Reuse the last good mask rather than flashing the background.
                current_mask = last_mask
                logger.warning(f"Frame {frame_count}: Using previous mask (confidence: {confidence:.2f})")
            else:
                # Center-weighted radial fallback mask (the previous np.zeros
                # initialisation here was dead code — it was overwritten).
                center_x, center_y = width // 2, height // 2
                y, x = np.ogrid[:height, :width]
                mask_dist = np.sqrt((x - center_x)**2 + (y - center_y)**2)
                current_mask = np.clip(1 - mask_dist / (min(width, height) * 0.3), 0, 1)
                logger.warning(f"Frame {frame_count}: Using fallback mask")

            # Apply edge smoothing
            if edge_smoothing > 0:
                kernel_size = int(edge_smoothing * 2) + 1
                current_mask = cv2.GaussianBlur(current_mask, (kernel_size, kernel_size), edge_smoothing)

            # Composite frame: alpha must be (H, W, 1) for broadcasting.
            if current_mask.ndim == 2:
                alpha = np.expand_dims(current_mask, axis=2)
            else:
                alpha = current_mask
            alpha = np.clip(alpha, 0, 1)

            foreground = frame_rgb.astype(np.float32)
            background = background_image.astype(np.float32)

            composite = foreground * alpha + background * (1 - alpha)
            composite = np.clip(composite, 0, 255).astype(np.uint8)

            # Convert back to BGR for output
            composite_bgr = cv2.cvtColor(composite, cv2.COLOR_RGB2BGR)
            out.write(composite_bgr)

            frame_count += 1

            # Update progress + ETA every 5 frames.
            if frame_count % 5 == 0:
                frame_progress = frame_count / total_frames
                overall_progress = 0.4 + (frame_progress * 0.55)  # 40%-95%
                elapsed_time = time.time() - processing_start_time
                avg_time_per_frame = elapsed_time / frame_count
                remaining_time = avg_time_per_frame * (total_frames - frame_count)
                progress(overall_progress, desc=f"Processing frame {frame_count}/{total_frames} (ETA: {remaining_time:.0f}s)")

            # Periodic GPU memory housekeeping.
            if frame_count % 30 == 0 and CUDA_AVAILABLE:
                torch.cuda.empty_cache()

        progress(0.98, desc="Finalizing video...")

        # Cleanup
        cap.release()
        out.release()

        # Clear SAM2 model to free memory
        sam2_loader.clear_model()

        if CUDA_AVAILABLE:
            torch.cuda.empty_cache()
        gc.collect()

        processing_time = time.time() - processing_start_time
        logger.info(f"Processing completed in {processing_time:.1f}s")

        progress(1.0, desc="Complete!")

        return output_path, f"✅ Successfully processed {duration:.1f}s video ({total_frames} frames) in {processing_time:.1f}s"

    except Exception as e:
        error_msg = f"❌ Processing failed: {str(e)}"
        logger.error(error_msg)

        # Best-effort cleanup; never let cleanup mask the original error.
        try:
            if cap:
                cap.release()
            if out:
                out.release()
            if output_path and os.path.exists(output_path):
                os.unlink(output_path)
        except Exception:
            pass

        sam2_loader.clear_model()
        return None, error_msg
506
 
507
# Gradio Interface
def create_interface():
    """Build and return the Gradio Blocks UI.

    Layout: left column holds the input video plus background/AI settings;
    right column shows the processed video and a status box. The process
    button is wired to process_video_enhanced.
    """

    # Dropdown entries are (display label, preset key); "custom" means an
    # uploaded image overrides the preset.
    preset_choices = [("Custom (upload image)", "custom")]
    for key, (name, _, _) in get_background_presets().items():
        preset_choices.append((name, key))

    with gr.Blocks(
        title="BackgroundFX Pro - SAM2 Powered",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        .main-header {
            text-align: center;
            background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            background-clip: text;
        }
        """
    ) as demo:

        gr.Markdown("""
        # 🎥 BackgroundFX Pro - SAM2 Powered
        **Professional AI video background replacement with advanced segmentation**

        Upload your video and let SAM2 AI automatically detect and replace the background with precision.
        Optimized for Hugging Face Spaces with smart memory management and lazy loading.
        """, elem_classes=["main-header"])

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📤 Input Configuration")

                # NOTE(review): the `info` kwarg on gr.Video/gr.Image is
                # Gradio-version-dependent — confirm the pinned version
                # accepts it.
                video_input = gr.Video(
                    label="Upload Video",
                    height=300,
                    info="Supported: MP4, AVI, MOV, MKV, WebM (max 5 minutes)"
                )

                with gr.Tab("Background"):
                    background_preset = gr.Dropdown(
                        choices=preset_choices,
                        value="gradient:ocean",
                        label="Background Preset",
                        info="Choose a preset or upload custom image"
                    )

                    background_input = gr.Image(
                        label="Custom Background (optional)",
                        type="pil",
                        height=200,
                        info="Upload image to override preset"
                    )

                with gr.Accordion("⚙️ AI Settings", open=True):
                    model_size = gr.Radio(
                        choices=[
                            ("Tiny (38MB) - Fastest", "tiny"),
                            ("Small (185MB) - Balanced", "small"),
                            ("Base (320MB) - Best Quality", "base")
                        ],
                        value="tiny",
                        label="SAM2 Model Size",
                        info="Larger models = better quality but slower processing"
                    )

                    edge_smoothing = gr.Slider(
                        minimum=0,
                        maximum=5,
                        value=1.0,
                        step=0.5,
                        label="Edge Smoothing",
                        info="Softens edges around subject (0 = sharp, 5 = very soft)"
                    )

                process_btn = gr.Button(
                    "🚀 Replace Background",
                    variant="primary",
                    size="lg",
                    scale=2
                )

            with gr.Column(scale=1):
                gr.Markdown("### 📥 Output")

                video_output = gr.Video(
                    label="Processed Video",
                    height=400,
                    show_download_button=True
                )

                status_output = gr.Textbox(
                    label="Processing Status",
                    lines=3,
                    max_lines=5
                )

                gr.Markdown("""
                ### 💡 Pro Tips
                - **Best results:** Clear subject separation from background
                - **Lighting:** Even lighting works best
                - **Movement:** Minimal camera shake recommended
                - **Processing:** ~30-60 seconds per minute of video
                - **Memory:** Models auto-downloaded and cleared after use
                """)

        # System Information
        with gr.Row():
            with gr.Column():
                if CUDA_AVAILABLE:
                    gr.Markdown(f"🚀 **GPU Acceleration:** {GPU_NAME} ({GPU_MEMORY:.1f}GB) | Type: {GPU_TYPE}")
                else:
                    gr.Markdown("💻 **CPU Mode** (GPU recommended for faster processing)")

            with gr.Column():
                gr.Markdown("📦 **Storage:** 0MB persistent (True lazy loading)")

        # Processing event
        process_btn.click(
            fn=process_video_enhanced,
            inputs=[
                video_input,
                background_input,
                background_preset,
                model_size,
                edge_smoothing
            ],
            outputs=[video_output, status_output],
            show_progress=True
        )

        # Examples section
        with gr.Row():
            gr.Markdown("""
            ### 🎬 Examples & Use Cases
            - **Content Creation:** Remove messy backgrounds for professional videos
            - **Virtual Meetings:** Create custom backgrounds for video calls
            - **Education:** Clean backgrounds for instructional videos
            - **Social Media:** Eye-catching backgrounds for posts and stories
            """)

    return demo
 
 
 
 
 
 
 
654
 
655
# Main execution
if __name__ == "__main__":
    # Startup banner with the device configuration probed at import time.
    logger.info("Starting BackgroundFX Pro...")
    logger.info(f"Device: {DEVICE}")
    if CUDA_AVAILABLE:
        logger.info(f"GPU: {GPU_NAME} ({GPU_MEMORY:.1f}GB)")

    # Create and launch interface
    demo = create_interface()

    # NOTE(review): queue kwargs are Gradio-version-dependent —
    # `concurrency_count` was removed in Gradio 4.x; confirm the pinned version.
    demo.queue(
        concurrency_count=2,  # Max 2 concurrent processes
        max_size=10,  # Max 10 in queue
        api_open=False  # Disable API for security
    ).launch(
        server_name="0.0.0.0",  # bind all interfaces (required on HF Spaces)
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False
    )