MogensR commited on
Commit
424efea
Β·
1 Parent(s): dd32d7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +548 -465
app.py CHANGED
@@ -1,8 +1,8 @@
1
  #!/usr/bin/env python3
2
  """
3
  BackgroundFX - Professional Video Background Replacement
4
- Priority: MatAnyone > SAM2 > Rembg > OpenCV
5
- Optimized for HuggingFace Spaces L4 GPU
6
  """
7
 
8
  import streamlit as st
@@ -18,7 +18,8 @@
18
  import torch
19
  import time
20
  from pathlib import Path
21
- from tqdm import tqdm
 
22
 
23
  # Configure logging
24
  logging.basicConfig(level=logging.INFO)
@@ -29,10 +30,9 @@
29
  # ============================================
30
 
31
  def setup_gpu_environment():
32
- """Setup GPU environment with optimal settings for L4"""
33
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'
34
- os.environ['TORCH_CUDA_ARCH_LIST'] = '8.9' # L4 architecture
35
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
36
 
37
  try:
38
  if torch.cuda.is_available():
@@ -45,13 +45,20 @@ def setup_gpu_environment():
45
  torch.cuda.init()
46
  torch.cuda.set_device(0)
47
 
48
- # Enable TF32 for L4
49
- torch.backends.cuda.matmul.allow_tf32 = True
50
- torch.backends.cudnn.allow_tf32 = True
51
  torch.backends.cudnn.benchmark = True
 
 
 
 
 
 
 
 
 
52
 
53
  # Warm up
54
- dummy = torch.randn(512, 512, device='cuda')
55
  del dummy
56
  torch.cuda.empty_cache()
57
 
@@ -68,35 +75,143 @@ def setup_gpu_environment():
68
  DEVICE = 'cuda' if CUDA_AVAILABLE else 'cpu'
69
 
70
  # ============================================
71
- # MATANYONE - PRIMARY METHOD (BEST QUALITY)
72
  # ============================================
73
 
74
- class MatAnyoneProcessor:
75
- """MatAnyone for superior video matting with temporal consistency"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def __init__(self):
78
- self.model = None
79
- self.predictor = None
80
- self.loaded = False
81
- self.previous_alpha = None
82
- self.previous_trimap = None
 
 
 
83
  self.frame_count = 0
84
 
85
  @st.cache_resource
86
- def load_model(_self):
87
- """Load MatAnyone model with caching"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  try:
89
- # Try to import MatAnyone
90
  from matanyone import MatAnyoneModel, MatAnyonePredictor
91
 
92
  # Download model if needed
93
- model_path = _self._download_model_if_needed()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  # Load model
96
  model = MatAnyoneModel.from_pretrained(
97
- model_path,
98
- device=DEVICE,
99
- fp16=(DEVICE == 'cuda')
100
  )
101
 
102
  # Create predictor
@@ -108,249 +223,200 @@ def load_model(_self):
108
  )
109
 
110
  logger.info("βœ… MatAnyone loaded successfully")
111
- return model, predictor, True
112
 
113
- except ImportError:
114
- logger.warning("⚠️ MatAnyone not installed, falling back to other methods")
115
- return None, None, False
116
  except Exception as e:
117
- logger.error(f"❌ MatAnyone loading failed: {e}")
118
- return None, None, False
119
 
120
- def _download_model_if_needed(self):
121
- """Download MatAnyone model dynamically"""
122
- cache_dir = Path("/tmp/matanyone_models")
123
- cache_dir.mkdir(exist_ok=True)
124
-
125
- model_path = cache_dir / "matanyone_video.pth"
126
 
127
- if not model_path.exists():
128
- # MatAnyone model URL
129
- model_url = "https://huggingface.co/matanyone/matanyone-video/resolve/main/model.pth"
130
-
131
- with st.spinner("Downloading MatAnyone model (first time only)..."):
132
- response = requests.get(model_url, stream=True)
133
- total_size = int(response.headers.get('content-length', 0))
134
-
135
- progress_bar = st.progress(0)
136
- with open(model_path, 'wb') as f:
137
- downloaded = 0
138
- for chunk in response.iter_content(chunk_size=8192):
139
- f.write(chunk)
140
- downloaded += len(chunk)
141
- if total_size > 0:
142
- progress_bar.progress(downloaded / total_size)
143
-
144
- progress_bar.empty()
145
 
146
- return str(model_path)
147
 
148
- def process_frame(self, frame, use_temporal=True):
149
- """Process frame with MatAnyone"""
150
- if not self.loaded:
151
- self.model, self.predictor, self.loaded = self.load_model()
152
-
153
- if not self.loaded or self.predictor is None:
 
 
 
 
 
 
 
 
154
  return None
155
-
156
- try:
157
- # Generate or update trimap
158
- if use_temporal and self.previous_trimap is not None:
159
- trimap = self._update_trimap(self.previous_trimap, frame)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  else:
161
- trimap = self._generate_trimap(frame)
162
-
163
- # Process with temporal consistency
164
- if use_temporal and self.previous_alpha is not None:
165
- alpha = self.predictor.predict(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  image=frame,
167
  trimap=trimap,
168
- previous_alpha=self.previous_alpha,
169
- temporal_weight=0.3
170
  )
171
- else:
172
- alpha = self.predictor.predict(image=frame, trimap=trimap)
173
-
174
- # Refine alpha
175
- alpha = self._refine_alpha(alpha, frame)
176
-
177
- # Store for next frame
178
- self.previous_alpha = alpha.copy()
179
- self.previous_trimap = trimap.copy()
180
- self.frame_count += 1
 
 
 
 
 
 
 
 
181
 
182
- return alpha
 
 
 
183
 
184
- except Exception as e:
185
- logger.error(f"MatAnyone processing failed: {e}")
186
- return None
187
-
188
- def _generate_trimap(self, frame):
189
- """Generate initial trimap"""
190
- h, w = frame.shape[:2]
191
- trimap = np.zeros((h, w), dtype=np.uint8)
192
 
193
- # Create center region as unknown
194
- center_x, center_y = w // 2, h // 2
195
- radius_x, radius_y = w // 3, h // 2
196
 
197
- y, x = np.ogrid[:h, :w]
198
- mask = ((x - center_x)**2 / radius_x**2 + (y - center_y)**2 / radius_y**2) <= 1
199
- trimap[mask] = 128 # Unknown
 
 
 
 
 
200
 
201
- inner_mask = ((x - center_x)**2 / (radius_x*0.5)**2 + (y - center_y)**2 / (radius_y*0.5)**2) <= 1
202
- trimap[inner_mask] = 255 # Foreground
 
203
 
204
- return trimap
205
 
206
- def _update_trimap(self, prev_trimap, frame):
207
- """Update trimap with motion compensation"""
208
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
209
- unknown = (prev_trimap == 128).astype(np.uint8)
210
- unknown = cv2.dilate(unknown, kernel, iterations=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- trimap = prev_trimap.copy()
213
- trimap[unknown == 1] = 128
214
 
215
- return trimap
216
-
217
- def _refine_alpha(self, alpha, frame):
218
- """Refine alpha matte"""
219
- # Guided filter if available
220
- try:
221
- alpha = cv2.ximgproc.guidedFilter(frame, alpha, 5, 1e-4)
222
- except:
223
- # Fallback to Gaussian blur
224
- alpha = cv2.GaussianBlur(alpha, (5, 5), 0)
225
 
226
- return np.clip(alpha, 0, 1)
227
-
228
- def reset(self):
229
- """Reset for new video"""
230
- self.previous_alpha = None
231
- self.previous_trimap = None
232
- self.frame_count = 0
233
-
234
- # ============================================
235
- # SAM2 - SECONDARY METHOD (VIDEO OPTIMIZED)
236
- # ============================================
237
-
238
- class SAM2Processor:
239
- """SAM2 for video segmentation"""
240
-
241
- def __init__(self):
242
- self.predictor = None
243
- self.loaded = False
244
- self.previous_mask = None
245
 
246
- @st.cache_resource
247
- def load_model(_self):
248
- """Load SAM2 model dynamically"""
249
- try:
250
- from sam2.build_sam import build_sam2
251
- from sam2.sam2_image_predictor import SAM2ImagePredictor
252
-
253
- # Model configurations
254
- models = {
255
- 'large': ('sam2_hiera_l.yaml', 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt', 897),
256
- 'base': ('sam2_hiera_b+.yaml', 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt', 323),
257
- 'small': ('sam2_hiera_s.yaml', 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt', 155),
258
- 'tiny': ('sam2_hiera_t.yaml', 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt', 77)
259
- }
260
-
261
- # Select model based on GPU
262
- if CUDA_AVAILABLE and GPU_MEMORY > 20:
263
- model_key = 'large'
264
- elif CUDA_AVAILABLE and GPU_MEMORY > 10:
265
- model_key = 'base'
266
- else:
267
- model_key = 'tiny'
268
-
269
- config, url, size = models[model_key]
270
-
271
- # Download model
272
- cache_dir = Path("/tmp/sam2_models")
273
- cache_dir.mkdir(exist_ok=True)
274
- model_path = cache_dir / f"sam2_{model_key}.pt"
275
-
276
- if not model_path.exists():
277
- with st.spinner(f"Downloading SAM2 {model_key} model ({size}MB)..."):
278
- response = requests.get(url, stream=True)
279
- with open(model_path, 'wb') as f:
280
- for chunk in response.iter_content(chunk_size=8192):
281
- f.write(chunk)
282
-
283
- # Build model
284
- sam2_model = build_sam2(config, str(model_path), device=DEVICE)
285
- predictor = SAM2ImagePredictor(sam2_model)
286
-
287
- logger.info(f"βœ… SAM2 {model_key} loaded successfully")
288
- return predictor, True
289
-
290
- except ImportError:
291
- logger.warning("⚠️ SAM2 not installed")
292
- return None, False
293
- except Exception as e:
294
- logger.error(f"❌ SAM2 loading failed: {e}")
295
- return None, False
296
-
297
- def process_frame(self, frame, use_temporal=True):
298
- """Process frame with SAM2"""
299
- if not self.loaded:
300
- self.predictor, self.loaded = self.load_model()
301
-
302
- if not self.loaded or self.predictor is None:
303
- return None
304
-
305
- try:
306
- self.predictor.set_image(frame)
307
-
308
- h, w = frame.shape[:2]
309
-
310
- # Generate prompts
311
- if use_temporal and self.previous_mask is not None:
312
- y_coords, x_coords = np.where(self.previous_mask > 0.5)
313
- if len(y_coords) > 0:
314
- center_y = int(np.mean(y_coords))
315
- center_x = int(np.mean(x_coords))
316
- point_coords = np.array([[center_x, center_y]])
317
- else:
318
- point_coords = np.array([[w//2, h//2]])
319
- else:
320
- point_coords = np.array([[w//2, h//2], [w//2, h//3], [w//2, 2*h//3]])
321
-
322
- point_labels = np.ones(len(point_coords))
323
-
324
- # Predict
325
- masks, scores, _ = self.predictor.predict(
326
- point_coords=point_coords,
327
- point_labels=point_labels,
328
- multimask_output=True
329
- )
330
-
331
- mask = masks[np.argmax(scores)].astype(np.float32)
332
-
333
- # Temporal smoothing
334
- if use_temporal and self.previous_mask is not None:
335
- mask = 0.7 * mask + 0.3 * self.previous_mask
336
-
337
- # Refine
338
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
339
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
340
- mask = cv2.GaussianBlur(mask, (5, 5), 0)
341
-
342
- self.previous_mask = mask
343
- return mask
344
-
345
- except Exception as e:
346
- logger.error(f"SAM2 processing failed: {e}")
347
- return None
348
 
349
  def reset(self):
350
- self.previous_mask = None
 
 
 
351
 
352
  # ============================================
353
- # REMBG - TERTIARY METHOD (FAST)
354
  # ============================================
355
 
356
  REMBG_AVAILABLE = False
@@ -363,119 +429,108 @@ def reset(self):
363
  rembg_session = new_session('u2net_human_seg', providers=providers)
364
 
365
  # Warm up
366
- dummy_img = Image.new('RGB', (256, 256), color='white')
367
  _ = remove(dummy_img, session=rembg_session)
368
 
369
  REMBG_AVAILABLE = True
370
- logger.info(f"βœ… Rembg initialized with providers: {providers}")
371
 
372
  except Exception as e:
373
  logger.warning(f"⚠️ Rembg not available: {e}")
374
 
375
  def segment_with_rembg(frame):
376
- """Segment using Rembg"""
377
  if not REMBG_AVAILABLE:
378
  return None
379
 
380
  try:
381
  pil_image = Image.fromarray(frame)
382
- output = remove(
383
- pil_image,
384
- session=rembg_session,
385
- alpha_matting=True,
386
- alpha_matting_foreground_threshold=240,
387
- alpha_matting_background_threshold=10
388
- )
389
 
390
  output_array = np.array(output)
391
  if output_array.shape[2] == 4:
392
- mask = output_array[:, :, 3].astype(np.float32) / 255.0
393
- else:
394
- mask = np.ones((frame.shape[0], frame.shape[1]), dtype=np.float32)
395
-
396
- return mask
397
- except Exception as e:
398
- logger.error(f"Rembg segmentation failed: {e}")
399
  return None
400
-
401
- # ============================================
402
- # OPENCV - FALLBACK METHOD (ALWAYS WORKS)
403
- # ============================================
404
-
405
- def segment_with_opencv(frame):
406
- """Basic OpenCV segmentation"""
407
- try:
408
- hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV)
409
-
410
- lower_skin = np.array([0, 20, 70], dtype=np.uint8)
411
- upper_skin = np.array([20, 255, 255], dtype=np.uint8)
412
-
413
- mask = cv2.inRange(hsv, lower_skin, upper_skin)
414
-
415
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
416
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
417
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
418
-
419
- mask = mask.astype(np.float32) / 255.0
420
- mask = cv2.GaussianBlur(mask, (5, 5), 0)
421
-
422
- return mask
423
-
424
  except Exception as e:
425
- logger.error(f"OpenCV segmentation failed: {e}")
426
  return None
427
 
428
  # ============================================
429
  # BACKGROUND UTILITIES
430
  # ============================================
431
 
432
- def load_background_image(background_url):
433
- """Load background image from URL"""
434
- try:
435
- response = requests.get(background_url, timeout=10)
436
- response.raise_for_status()
437
- image = Image.open(BytesIO(response.content))
438
- return np.array(image.convert('RGB'))
439
- except Exception as e:
440
- logger.error(f"Failed to load background: {e}")
441
- return create_default_background()
442
-
443
- def create_default_background():
444
  """Create gradient background"""
445
- background = np.zeros((720, 1280, 3), dtype=np.uint8)
446
- for y in range(720):
447
- color_value = int(255 * (1 - y / 720))
448
- background[y, :] = [color_value, int(color_value * 0.7), int(color_value * 0.9)]
 
 
 
 
 
449
  return background
450
 
451
- def get_professional_backgrounds():
452
- """Professional background collection"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  return {
454
- "🏒 Modern Office": "https://images.unsplash.com/photo-1497366216548-37526070297c?w=1920&h=1080&fit=crop",
455
- "πŸŒ† City Skyline": "https://images.unsplash.com/photo-1449824913935-59a10b8d2000?w=1920&h=1080&fit=crop",
456
- "πŸ–οΈ Tropical Beach": "https://images.unsplash.com/photo-1507525428034-b723cf961d3e?w=1920&h=1080&fit=crop",
457
- "🌲 Forest Path": "https://images.unsplash.com/photo-1441974231531-c6227db76b6e?w=1920&h=1080&fit=crop",
458
- "🎨 Abstract Gradient": "https://images.unsplash.com/photo-1557683316-973673baf926?w=1920&h=1080&fit=crop",
459
- "πŸ”οΈ Mountain Vista": "https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=1920&h=1080&fit=crop",
460
- "πŸŒ… Sunset Sky": "https://images.unsplash.com/photo-1495616811223-4d98c6e9c869?w=1920&h=1080&fit=crop",
461
- "πŸ’Ό Conference Room": "https://images.unsplash.com/photo-1497366811353-6870744d04b2?w=1920&h=1080&fit=crop",
462
- "🎬 Studio Setup": "https://images.unsplash.com/photo-1565438222132-3654b8b88d4a?w=1920&h=1080&fit=crop",
463
- "πŸŒƒ Night City": "https://images.unsplash.com/photo-1519501025264-65ba15a82390?w=1920&h=1080&fit=crop"
464
  }
465
 
466
  # ============================================
467
  # VIDEO PROCESSING PIPELINE
468
  # ============================================
469
 
470
- # Initialize processors
471
- matanyone_processor = MatAnyoneProcessor()
472
- sam2_processor = SAM2Processor()
473
 
474
- def process_video(video_path, background_url, method='auto', progress_callback=None):
475
- """Process video with selected method"""
 
 
 
 
 
 
 
 
476
  try:
477
  # Load background
478
- background_image = load_background_image(background_url)
479
 
480
  # Open video
481
  cap = cv2.VideoCapture(video_path)
@@ -486,97 +541,108 @@ def process_video(video_path, background_url, method='auto', progress_callback=N
486
 
487
  logger.info(f"Processing video: {width}x{height}, {total_frames} frames, {fps} FPS")
488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
  # Create output
490
  output_path = tempfile.mktemp(suffix='.mp4')
491
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
492
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
493
 
494
- # Resize background
495
  background_resized = cv2.resize(background_image, (width, height))
496
 
497
- # Reset processors
498
- matanyone_processor.reset()
499
- sam2_processor.reset()
500
 
501
  frame_count = 0
 
502
  processing_times = []
 
503
 
504
  while True:
505
  ret, frame = cap.read()
506
  if not ret:
507
  break
508
 
509
- start_time = time.time()
510
-
511
  # Convert BGR to RGB
512
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
513
 
514
- # Select method and process
515
- mask = None
516
- method_used = "None"
517
-
518
- if method == 'auto' or method == 'matanyone':
519
- # Try MatAnyone first (BEST)
520
- mask = matanyone_processor.process_frame(frame_rgb, use_temporal=(frame_count > 0))
521
- if mask is not None:
522
- method_used = "MatAnyone"
523
-
524
- if mask is None and (method == 'auto' or method == 'sam2'):
525
- # Try SAM2 (GOOD)
526
- mask = sam2_processor.process_frame(frame_rgb, use_temporal=(frame_count > 0))
527
- if mask is not None:
528
- method_used = "SAM2"
529
-
530
- if mask is None and (method == 'auto' or method == 'rembg'):
531
- # Try Rembg (FAST)
532
- mask = segment_with_rembg(frame_rgb)
533
- if mask is not None:
534
- method_used = "Rembg"
535
-
536
- if mask is None:
537
- # Fallback to OpenCV
538
- mask = segment_with_opencv(frame_rgb)
539
- method_used = "OpenCV"
540
-
541
- # Apply mask and composite
542
- if mask is not None:
543
- if mask.ndim == 2:
544
- mask = np.expand_dims(mask, axis=2)
545
 
546
- # High-quality compositing
547
- foreground = frame_rgb.astype(np.float32)
548
- background = background_resized.astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
 
550
- composite = foreground * mask + background * (1 - mask)
551
- composite = np.clip(composite, 0, 255).astype(np.uint8)
552
  else:
553
- composite = frame_rgb
 
 
 
 
 
 
 
 
 
 
 
 
 
554
 
555
  # Convert back to BGR
556
  composite_bgr = cv2.cvtColor(composite, cv2.COLOR_RGB2BGR)
557
  out.write(composite_bgr)
558
 
559
- # Track time
560
- processing_time = time.time() - start_time
561
- processing_times.append(processing_time)
562
-
563
  frame_count += 1
564
 
565
  # Progress update
566
  if progress_callback:
567
  progress = frame_count / total_frames
568
- avg_time = np.mean(processing_times[-10:])
569
- eta = avg_time * (total_frames - frame_count)
 
 
 
570
  progress_callback(
571
  progress,
572
- f"{method_used}: Frame {frame_count}/{total_frames} | ETA: {eta:.1f}s"
573
  )
574
 
575
  # Memory cleanup
576
- if frame_count % 50 == 0 and CUDA_AVAILABLE:
577
  torch.cuda.empty_cache()
578
 
579
- # Release
580
  cap.release()
581
  out.release()
582
 
@@ -584,8 +650,11 @@ def process_video(video_path, background_url, method='auto', progress_callback=N
584
  torch.cuda.empty_cache()
585
  gc.collect()
586
 
587
- logger.info(f"βœ… Video processing complete: {output_path}")
588
- logger.info(f"Average time per frame: {np.mean(processing_times):.3f}s")
 
 
 
589
 
590
  return output_path
591
 
@@ -599,15 +668,15 @@ def process_video(video_path, background_url, method='auto', progress_callback=N
599
 
600
  def main():
601
  st.set_page_config(
602
- page_title="BackgroundFX - Professional Video Processing",
603
- page_icon="🎬",
604
  layout="wide",
605
  initial_sidebar_state="expanded"
606
  )
607
 
608
  # Header
609
- st.title("🎬 BackgroundFX - Professional Video Background Replacement")
610
- st.markdown("**Production-quality processing with MatAnyone, SAM2, and Rembg**")
611
 
612
  # System Status
613
  col1, col2, col3, col4 = st.columns(4)
@@ -617,18 +686,21 @@ def main():
617
  st.success(f"πŸš€ GPU: {GPU_NAME}")
618
  st.caption(f"VRAM: {GPU_MEMORY:.1f}GB")
619
  else:
620
- st.info("πŸ’» CPU Mode")
621
 
622
  with col2:
623
  methods = []
624
- if matanyone_processor.loaded:
625
- methods.append("MatAnyone")
626
- if sam2_processor.loaded:
627
  methods.append("SAM2")
 
 
628
  if REMBG_AVAILABLE:
629
  methods.append("Rembg")
630
- methods.append("OpenCV")
631
- st.info(f"πŸ“¦ Methods: {', '.join(methods)}")
 
 
 
632
 
633
  with col3:
634
  if CUDA_AVAILABLE:
@@ -638,79 +710,73 @@ def main():
638
  st.metric("Mode", "CPU")
639
 
640
  with col4:
641
- st.metric("Device", DEVICE.upper())
 
642
 
643
  # Sidebar
644
  with st.sidebar:
645
- st.markdown("### βš™οΈ Processing Options")
646
-
647
- # Method selection with quality indicators
648
- method_options = {
649
- 'auto': 'Auto (Best Available)',
650
- 'matanyone': 'MatAnyone (β˜…β˜…β˜…β˜…β˜… Production)',
651
- 'sam2': 'SAM2 (β˜…β˜…β˜…β˜… Video-Optimized)',
652
- 'rembg': 'Rembg (β˜…β˜…β˜… Fast)',
653
- 'opencv': 'OpenCV (β˜… Fallback)'
 
 
 
 
 
 
 
 
 
 
 
 
654
  }
 
655
 
656
- selected_method = st.selectbox(
657
- "Segmentation Method",
658
- options=list(method_options.keys()),
659
- format_func=lambda x: method_options[x],
660
- index=0
661
- )
662
 
663
- # Method info
664
- if selected_method == 'matanyone':
665
- st.info("""
666
- **MatAnyone Advantages:**
667
- β€’ Perfect hair/edge details
668
- β€’ Temporal consistency
669
- β€’ Alpha matting quality
670
- β€’ No flicker in video
671
- """)
672
- elif selected_method == 'sam2':
673
- st.info("""
674
- **SAM2 Advantages:**
675
- β€’ Designed for video
676
- β€’ Good temporal flow
677
- β€’ Automatic prompting
678
- """)
679
- elif selected_method == 'rembg':
680
- st.info("""
681
- **Rembg Advantages:**
682
- β€’ Fast processing
683
- β€’ Good for photos
684
- β€’ Easy to use
685
- """)
686
 
687
  st.markdown("---")
688
 
689
  # System info
690
- st.markdown("### πŸ“Š System Information")
691
 
692
  if CUDA_AVAILABLE:
693
  allocated = torch.cuda.memory_allocated() / 1024**3
694
  reserved = torch.cuda.memory_reserved() / 1024**3
695
- free = GPU_MEMORY - reserved if GPU_MEMORY else 0
696
 
697
- st.metric("GPU Memory", f"{allocated:.2f} / {GPU_MEMORY:.1f} GB")
698
 
699
  usage_percent = (allocated / GPU_MEMORY) * 100 if GPU_MEMORY else 0
700
  st.progress(min(usage_percent / 100, 1.0))
701
 
 
702
  with st.expander("GPU Details"):
703
  st.code(f"""
704
  Device: {GPU_NAME}
705
  VRAM: {GPU_MEMORY:.1f} GB
706
- Allocated: {allocated:.2f} GB
707
  Reserved: {reserved:.2f} GB
708
- Free: {free:.2f} GB
709
  PyTorch: {torch.__version__}
710
  CUDA: {torch.version.cuda if CUDA_AVAILABLE else 'N/A'}
711
  """)
712
- else:
713
- st.info("Running in CPU mode")
714
 
715
  # Main content
716
  col1, col2 = st.columns(2)
@@ -721,7 +787,7 @@ def main():
721
  uploaded_video = st.file_uploader(
722
  "Upload your video",
723
  type=['mp4', 'avi', 'mov', 'mkv'],
724
- help="Maximum recommended: 30 seconds for best performance"
725
  )
726
 
727
  if uploaded_video:
@@ -731,28 +797,36 @@ def main():
731
  video_path = tmp_file.name
732
 
733
  st.video(uploaded_video)
734
- st.success(f"βœ… Video ready: {uploaded_video.name}")
 
 
 
 
 
 
 
 
735
  else:
736
  video_path = None
737
 
738
  with col2:
739
- st.markdown("### πŸ–ΌοΈ Background Selection")
740
 
741
- backgrounds = get_professional_backgrounds()
742
- selected_bg_name = st.selectbox(
743
- "Choose a background",
 
744
  options=list(backgrounds.keys()),
745
  index=0
746
  )
747
 
748
- background_url = backgrounds[selected_bg_name]
749
 
750
  # Preview
751
- try:
752
- bg_image = load_background_image(background_url)
753
- st.image(bg_image, caption=selected_bg_name, use_container_width=True)
754
- except:
755
- st.error("Failed to load background preview")
756
 
757
  # Process button
758
  if video_path and st.button("πŸš€ Process Video", type="primary", use_container_width=True):
@@ -760,27 +834,30 @@ def main():
760
  # Progress tracking
761
  progress_bar = st.progress(0)
762
  status_text = st.empty()
 
763
 
764
  def update_progress(progress, message):
765
  progress_bar.progress(progress)
766
  status_text.text(message)
 
 
767
 
768
  # Process video
769
- with st.spinner("Processing video..."):
770
- start_time = time.time()
771
-
772
- result_path = process_video(
773
- video_path,
774
- background_url,
775
- method=selected_method,
776
- progress_callback=update_progress
777
- )
778
-
779
- processing_time = time.time() - start_time
780
 
781
  if result_path and os.path.exists(result_path):
782
  # Success
783
- status_text.text(f"βœ… Processing complete in {processing_time:.1f} seconds!")
 
784
 
785
  # Load result
786
  with open(result_path, 'rb') as f:
@@ -789,22 +866,28 @@ def update_progress(progress, message):
789
  st.markdown("### 🎬 Result")
790
  st.video(result_data)
791
 
792
- # Download
793
- st.download_button(
794
- label="πŸ’Ύ Download Processed Video",
795
- data=result_data,
796
- file_name=f"backgroundfx_{uploaded_video.name}",
797
- mime="video/mp4",
798
- use_container_width=True
799
- )
 
 
 
 
 
 
 
 
 
 
 
800
 
801
  # Cleanup
802
  os.unlink(result_path)
803
-
804
- # Stats
805
- if CUDA_AVAILABLE:
806
- allocated = torch.cuda.memory_allocated() / 1024**3
807
- st.info(f"Processing completed using {allocated:.1f}GB GPU memory")
808
  else:
809
  st.error("❌ Processing failed! Please try again.")
810
 
 
1
  #!/usr/bin/env python3
2
  """
3
  BackgroundFX - Professional Video Background Replacement
4
+ Combined Pipeline: SAM2 (segmentation) + MatAnyone (matting refinement)
5
+ Optimized for HuggingFace Spaces T4 GPU (16GB VRAM)
6
  """
7
 
8
  import streamlit as st
 
18
  import torch
19
  import time
20
  from pathlib import Path
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Dict, Tuple
23
 
24
  # Configure logging
25
  logging.basicConfig(level=logging.INFO)
 
30
  # ============================================
31
 
32
  def setup_gpu_environment():
33
+ """Setup GPU environment optimized for T4"""
34
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'
35
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
 
36
 
37
  try:
38
  if torch.cuda.is_available():
 
45
  torch.cuda.init()
46
  torch.cuda.set_device(0)
47
 
48
+ # T4 optimizations
 
 
49
  torch.backends.cudnn.benchmark = True
50
+ torch.backends.cudnn.deterministic = False
51
+
52
+ # T4 doesn't support TF32
53
+ if 'T4' in gpu_name:
54
+ torch.backends.cuda.matmul.allow_tf32 = False
55
+ torch.backends.cudnn.allow_tf32 = False
56
+ else:
57
+ torch.backends.cuda.matmul.allow_tf32 = True
58
+ torch.backends.cudnn.allow_tf32 = True
59
 
60
  # Warm up
61
+ dummy = torch.randn(256, 256, device='cuda')
62
  del dummy
63
  torch.cuda.empty_cache()
64
 
 
75
  DEVICE = 'cuda' if CUDA_AVAILABLE else 'cpu'
76
 
77
  # ============================================
78
+ # DATA STRUCTURES
79
  # ============================================
80
 
81
+ @dataclass
82
+ class ProcessingResult:
83
+ """Container for processing results"""
84
+ alpha: np.ndarray # Final alpha matte
85
+ sam2_mask: Optional[np.ndarray] = None # SAM2 coarse mask
86
+ trimap: Optional[np.ndarray] = None # Generated trimap
87
+ method: str = "unknown"
88
+ processing_time: float = 0.0
89
+
90
+ # ============================================
91
+ # COMBINED SAM2 + MATANYONE PROCESSOR
92
+ # ============================================
93
+
94
+ class CombinedProcessor:
95
+ """
96
+ Combines SAM2 and MatAnyone for ultimate quality
97
+ SAM2: Initial segmentation (find the person)
98
+ MatAnyone: Alpha matting refinement (perfect edges)
99
+ """
100
 
101
  def __init__(self):
102
+ self.sam2_predictor = None
103
+ self.matanyone_model = None
104
+ self.sam2_loaded = False
105
+ self.matanyone_loaded = False
106
+ self.device = DEVICE
107
+
108
+ # Temporal consistency
109
+ self.previous_result = None
110
  self.frame_count = 0
111
 
112
  @st.cache_resource
113
+ def load_sam2(_self):
114
+ """Load SAM2 model for segmentation"""
115
+ try:
116
+ from sam2.build_sam import build_sam2
117
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
118
+
119
+ # Model selection based on available VRAM
120
+ if GPU_MEMORY >= 15:
121
+ model_config = {
122
+ 'name': 'base_plus',
123
+ 'config': 'sam2_hiera_b+.yaml',
124
+ 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt',
125
+ 'size': 323
126
+ }
127
+ elif GPU_MEMORY >= 8:
128
+ model_config = {
129
+ 'name': 'small',
130
+ 'config': 'sam2_hiera_s.yaml',
131
+ 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt',
132
+ 'size': 155
133
+ }
134
+ else:
135
+ model_config = {
136
+ 'name': 'tiny',
137
+ 'config': 'sam2_hiera_t.yaml',
138
+ 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt',
139
+ 'size': 77
140
+ }
141
+
142
+ # Download model if needed
143
+ cache_dir = Path("/tmp/sam2_models")
144
+ cache_dir.mkdir(exist_ok=True)
145
+ model_path = cache_dir / f"sam2_{model_config['name']}.pt"
146
+
147
+ if not model_path.exists():
148
+ with st.spinner(f"Downloading SAM2 {model_config['name']} ({model_config['size']}MB)..."):
149
+ response = requests.get(model_config['url'], stream=True)
150
+ total_size = int(response.headers.get('content-length', 0))
151
+
152
+ progress_bar = st.progress(0)
153
+ with open(model_path, 'wb') as f:
154
+ downloaded = 0
155
+ for chunk in response.iter_content(chunk_size=8192):
156
+ f.write(chunk)
157
+ downloaded += len(chunk)
158
+ if total_size > 0:
159
+ progress_bar.progress(downloaded / total_size)
160
+ progress_bar.empty()
161
+
162
+ # Build model
163
+ sam2_model = build_sam2(
164
+ config_file=model_config['config'],
165
+ ckpt_path=str(model_path),
166
+ device=_self.device
167
+ )
168
+
169
+ # Use half precision on T4
170
+ if CUDA_AVAILABLE and 'T4' in GPU_NAME:
171
+ sam2_model = sam2_model.half()
172
+
173
+ predictor = SAM2ImagePredictor(sam2_model)
174
+
175
+ logger.info(f"βœ… SAM2 {model_config['name']} loaded successfully")
176
+ return predictor, True
177
+
178
+ except Exception as e:
179
+ logger.error(f"❌ SAM2 loading failed: {e}")
180
+ return None, False
181
+
182
+ @st.cache_resource
183
+ def load_matanyone(_self):
184
+ """Load MatAnyone model for edge refinement"""
185
  try:
 
186
  from matanyone import MatAnyoneModel, MatAnyonePredictor
187
 
188
  # Download model if needed
189
+ cache_dir = Path("/tmp/matanyone_models")
190
+ cache_dir.mkdir(exist_ok=True)
191
+ model_path = cache_dir / "matanyone_video.pth"
192
+
193
+ if not model_path.exists():
194
+ model_url = "https://huggingface.co/matanyone/matanyone-video/resolve/main/model.pth"
195
+
196
+ with st.spinner("Downloading MatAnyone model..."):
197
+ response = requests.get(model_url, stream=True)
198
+ total_size = int(response.headers.get('content-length', 0))
199
+
200
+ progress_bar = st.progress(0)
201
+ with open(model_path, 'wb') as f:
202
+ downloaded = 0
203
+ for chunk in response.iter_content(chunk_size=8192):
204
+ f.write(chunk)
205
+ downloaded += len(chunk)
206
+ if total_size > 0:
207
+ progress_bar.progress(downloaded / total_size)
208
+ progress_bar.empty()
209
 
210
  # Load model
211
  model = MatAnyoneModel.from_pretrained(
212
+ str(model_path),
213
+ device=_self.device,
214
+ fp16=(CUDA_AVAILABLE) # Use FP16 on GPU
215
  )
216
 
217
  # Create predictor
 
223
  )
224
 
225
  logger.info("βœ… MatAnyone loaded successfully")
226
+ return predictor, True
227
 
 
 
 
228
  except Exception as e:
229
+ logger.warning(f"⚠️ MatAnyone not available: {e}")
230
+ return None, False
231
 
232
+ def initialize(self):
233
+ """Initialize both models"""
234
+ if not self.sam2_loaded:
235
+ self.sam2_predictor, self.sam2_loaded = self.load_sam2()
 
 
236
 
237
+ if not self.matanyone_loaded:
238
+ self.matanyone_model, self.matanyone_loaded = self.load_matanyone()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
+ return self.sam2_loaded # At minimum need SAM2
241
 
242
def process_frame(self, frame: np.ndarray, use_temporal: bool = True) -> ProcessingResult:
    """
    Process a single RGB frame with the SAM2 + MatAnyone pipeline.

    Pipeline:
        1. SAM2 generates an initial segmentation from point prompts.
        2. The SAM2 mask is converted to a trimap (fg / bg / unknown band).
        3. MatAnyone (if loaded) refines the trimap into a soft alpha matte.
        4. The alpha is clipped and packaged for the caller.

    Args:
        frame: HxWx3 uint8 RGB image.
        use_temporal: Reuse the previous frame's result for prompt placement,
            mask smoothing, and MatAnyone's temporal prior (video mode).

    Returns:
        ProcessingResult with alpha/sam2_mask/trimap/method/processing_time,
        or None if SAM2 could not be initialized.
    """
    start_time = time.time()

    if not self.initialize():
        return None

    h, w = frame.shape[:2]

    # ---- Step 1: SAM2 segmentation -------------------------------------
    self.sam2_predictor.set_image(frame)

    prev = self.previous_result if use_temporal else None
    if prev is not None and prev.sam2_mask is not None:
        # Cluster point prompts around the previous mask's centroid so the
        # subject is tracked consistently between frames.
        ys, xs = np.where(prev.sam2_mask > 0.5)
        if len(ys) > 0:
            cy, cx = int(np.mean(ys)), int(np.mean(xs))
            point_coords = np.array([
                [cx, cy],
                [cx - w // 40, cy],
                [cx + w // 40, cy],
                [cx, cy - h // 40],
                [cx, cy + h // 40],
            ])
        else:
            point_coords = self._get_default_points(w, h)
    else:
        point_coords = self._get_default_points(w, h)

    point_labels = np.ones(len(point_coords))

    # BUG FIX: the original passed return_logits=True, so `masks` contained
    # raw (unbounded) logits which were then thresholded at 0.5 and blended
    # as if they were alpha values in [0, 1].  Request thresholded binary
    # masks instead (the predictor's default).
    masks, scores, _ = self.sam2_predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )

    best_idx = int(np.argmax(scores))
    sam2_mask = masks[best_idx].astype(np.float32)

    # Temporal smoothing: exponential blend with the previous frame's mask.
    if prev is not None and prev.sam2_mask is not None:
        sam2_mask = np.clip(0.7 * sam2_mask + 0.3 * prev.sam2_mask, 0, 1)

    # ---- Step 2: trimap from the SAM2 mask -----------------------------
    trimap = self._create_trimap_from_mask(sam2_mask)

    # ---- Step 3: MatAnyone refinement (optional) -----------------------
    if self.matanyone_loaded and self.matanyone_model:
        try:
            refined_alpha = self.matanyone_model.predict(
                image=frame,
                trimap=trimap,
                previous_alpha=prev.alpha if prev is not None else None,
                temporal_weight=0.3 if use_temporal else 0.0,
            )
            method = "SAM2+MatAnyone"

            # BUG FIX: guided filtering used to run inside the same `try` as
            # MatAnyone, so a missing opencv-contrib build (no cv2.ximgproc)
            # silently discarded the MatAnyone matte.  Keep it best-effort.
            try:
                refined_alpha = cv2.ximgproc.guidedFilter(
                    guide=frame,
                    src=refined_alpha,
                    radius=3,
                    eps=1e-4,
                )
            except Exception as filter_err:
                logger.debug(f"Guided filter unavailable, keeping MatAnyone alpha: {filter_err}")

        except Exception as e:
            logger.warning(f"MatAnyone refinement failed, using SAM2 only: {e}")
            refined_alpha = sam2_mask
            method = "SAM2"
    else:
        # SAM2-only path: basic morphological cleanup + edge softening.
        refined_alpha = sam2_mask
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        refined_alpha = cv2.morphologyEx(refined_alpha, cv2.MORPH_CLOSE, kernel)
        refined_alpha = cv2.GaussianBlur(refined_alpha, (5, 5), 0)
        method = "SAM2"

    # ---- Step 4: final post-processing ---------------------------------
    refined_alpha = np.clip(refined_alpha, 0, 1)

    result = ProcessingResult(
        alpha=refined_alpha,
        sam2_mask=sam2_mask,
        trimap=trimap,
        method=method,
        processing_time=time.time() - start_time,
    )

    # Keep state for temporal consistency on the next frame.
    self.previous_result = result
    self.frame_count += 1

    return result
374
 
375
+ def _get_default_points(self, w: int, h: int) -> np.ndarray:
376
+ """Get default point prompts for initial detection"""
377
+ return np.array([
378
+ [w//2, h//2], # Center
379
+ [w//2, h//3], # Head area
380
+ [w//2, 2*h//3], # Body area
381
+ [w//3, h//2], # Left
382
+ [2*w//3, h//2], # Right
383
+ [w//2, h//4], # Upper
384
+ [w//2, 3*h//4] # Lower
385
+ ])
386
+
387
def _create_trimap_from_mask(self, mask: np.ndarray, unknown_width: int = 20) -> np.ndarray:
    """Turn a soft SAM2 mask into a three-level trimap for matting.

    Levels: 0 = background, 128 = unknown band, 255 = definite foreground.
    The unknown band is the ring between an eroded (confident) core and a
    dilated (possible) extent of the thresholded mask.

    Args:
        mask: Float mask in [0, 1].
        unknown_width: Kernel size controlling the width of the unknown band.

    Returns:
        uint8 trimap of the same spatial shape as `mask`.
    """
    hard_mask = (mask > 0.5).astype(np.uint8)

    # Confident foreground: shrink the hard mask so only core pixels remain.
    erode_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    sure_fg = cv2.erode(hard_mask, erode_kernel, iterations=2)

    # Possible foreground: grow the hard mask to carve out the unknown ring.
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (unknown_width, unknown_width))
    maybe_fg = cv2.dilate(hard_mask, dilate_kernel, iterations=2)

    # Everything outside `maybe_fg` stays 0 (background) from the zeros init.
    trimap = np.zeros_like(mask, dtype=np.uint8)
    trimap[sure_fg == 1] = 255
    trimap[(maybe_fg == 1) & (sure_fg == 0)] = 128
    return trimap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
 
412
def reset(self):
    """Clear per-video temporal state so the next video starts fresh."""
    self.previous_result = None
    self.frame_count = 0
    # Same logger instance as the module-level `logger` (getLogger(__name__)).
    logging.getLogger(__name__).info("Processor reset for new video")
417
 
418
  # ============================================
419
+ # FALLBACK: REMBG PROCESSOR
420
  # ============================================
421
 
422
  REMBG_AVAILABLE = False
 
429
  rembg_session = new_session('u2net_human_seg', providers=providers)
430
 
431
  # Warm up
432
+ dummy_img = Image.new('RGB', (128, 128), color='white')
433
  _ = remove(dummy_img, session=rembg_session)
434
 
435
  REMBG_AVAILABLE = True
436
+ logger.info("βœ… Rembg initialized as fallback")
437
 
438
  except Exception as e:
439
  logger.warning(f"⚠️ Rembg not available: {e}")
440
 
441
def segment_with_rembg(frame):
    """Fallback segmentation using Rembg.

    Args:
        frame: HxWx3 uint8 RGB image.

    Returns:
        Float alpha mask in [0, 1], or None when Rembg is unavailable,
        produced no alpha channel, or raised.
    """
    if not REMBG_AVAILABLE:
        return None

    try:
        rgba = remove(Image.fromarray(frame), session=rembg_session)
        rgba_array = np.array(rgba)
        if rgba_array.shape[2] == 4:
            # Alpha channel, normalized to [0, 1].
            return rgba_array[:, :, 3].astype(np.float32) / 255.0
        return None
    except Exception as e:
        logger.error(f"Rembg failed: {e}")
        return None
457
 
458
  # ============================================
459
  # BACKGROUND UTILITIES
460
  # ============================================
461
 
462
def create_gradient_background(width=1280, height=720, color1=(70, 130, 180), color2=(255, 140, 90)):
    """Create a vertical linear-gradient RGB background.

    The top row is `color1`, blending linearly toward `color2` at the bottom.

    Args:
        width: Output width in pixels.
        height: Output height in pixels.
        color1: Top RGB color tuple.
        color2: Bottom RGB color tuple.

    Returns:
        (height, width, 3) uint8 RGB array.
    """
    # PERF: vectorized replacement for the original per-row Python loop.
    # ratio = y / height for each row, same as the loop's `y / height`.
    ratios = np.arange(height, dtype=np.float64)[:, None] / height  # (H, 1)
    top = np.asarray(color1, dtype=np.float64)
    bottom = np.asarray(color2, dtype=np.float64)
    # astype(np.uint8) truncates like the original int() (values are >= 0).
    row_colors = (top * (1.0 - ratios) + bottom * ratios).astype(np.uint8)  # (H, 3)
    # Broadcast each row color across the full width.
    return np.repeat(row_colors[:, None, :], width, axis=1)
474
 
475
def load_background_image(background_option):
    """Load or create a background based on an option string.

    Supported forms:
      - "gradient:<blue|sunset|ocean>" -> generated vertical gradient
      - "color:<green|blue|white>"     -> solid 1280x720 color (unknown
                                          names fall back to white)
      - anything else                  -> treated as an image URL; falls
                                          back to the default gradient on
                                          any failure

    Returns:
        HxWx3 uint8 RGB array.
    """
    if background_option.startswith("gradient:"):
        gradient_type = background_option.split(":")[1]
        if gradient_type == "blue":
            return create_gradient_background(color1=(70, 130, 180), color2=(135, 206, 235))
        elif gradient_type == "sunset":
            return create_gradient_background(color1=(255, 94, 77), color2=(255, 154, 0))
        else:  # ocean
            return create_gradient_background(color1=(0, 119, 190), color2=(0, 180, 216))
    elif background_option.startswith("color:"):
        color_name = background_option.split(":")[1]
        colors = {"green": [0, 255, 0], "blue": [0, 0, 255], "white": [255, 255, 255]}
        return np.full((720, 1280, 3), colors.get(color_name, [255, 255, 255]), dtype=np.uint8)
    else:
        try:
            response = requests.get(background_option, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            return np.array(image.convert('RGB'))
        except Exception as e:
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and hid every failure silently.
            logger.warning(f"Background download failed ({e}); using default gradient")
            return create_gradient_background()
498
+
499
def get_background_options():
    """Preset background labels mapped to option strings.

    Values are in the formats understood by load_background_image():
    "gradient:*", "color:*", or a direct image URL.
    """
    unsplash = "https://images.unsplash.com/{pid}?w=1280&h=720&fit=crop"

    options = {
        "πŸŒ… Blue Gradient": "gradient:blue",
        "πŸŒ‡ Sunset Gradient": "gradient:sunset",
        "🌊 Ocean Gradient": "gradient:ocean",
        "πŸ’š Green Screen": "color:green",
        "πŸ’™ Blue Screen": "color:blue",
        "βšͺ White Background": "color:white",
    }

    # Stock photo backgrounds (Unsplash, pre-sized to 1280x720).
    photo_ids = {
        "🏒 Office": "photo-1497366216548-37526070297c",
        "πŸŒ† City": "photo-1449824913935-59a10b8d2000",
        "πŸ–οΈ Beach": "photo-1507525428034-b723cf961d3e",
        "🌲 Nature": "photo-1441974231531-c6227db76b6e",
    }
    options.update({label: unsplash.format(pid=pid) for label, pid in photo_ids.items()})

    return options
513
 
514
  # ============================================
515
  # VIDEO PROCESSING PIPELINE
516
  # ============================================
517
 
518
+ # Initialize processor globally
519
+ processor = CombinedProcessor()
 
520
 
521
+ def process_video(video_path, background_option, speed_mode='balanced', progress_callback=None):
522
+ """
523
+ Process video with SAM2 + MatAnyone combined pipeline
524
+
525
+ Args:
526
+ video_path: Input video path
527
+ background_option: Background type/URL
528
+ speed_mode: 'ultra_fast', 'fast', 'balanced', 'quality'
529
+ progress_callback: Progress update function
530
+ """
531
  try:
532
  # Load background
533
+ background_image = load_background_image(background_option)
534
 
535
  # Open video
536
  cap = cv2.VideoCapture(video_path)
 
541
 
542
  logger.info(f"Processing video: {width}x{height}, {total_frames} frames, {fps} FPS")
543
 
544
+ # Determine frame skip based on speed mode
545
+ if speed_mode == 'ultra_fast':
546
+ frame_skip = 3 # Process every 3rd frame
547
+ interpolate = True
548
+ elif speed_mode == 'fast':
549
+ frame_skip = 2 # Process every 2nd frame
550
+ interpolate = True
551
+ elif speed_mode == 'balanced':
552
+ frame_skip = 1 # Process all frames
553
+ interpolate = False
554
+ else: # quality
555
+ frame_skip = 1
556
+ interpolate = False
557
+
558
  # Create output
559
  output_path = tempfile.mktemp(suffix='.mp4')
560
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
561
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
562
 
563
+ # Resize background once
564
  background_resized = cv2.resize(background_image, (width, height))
565
 
566
+ # Reset processor for new video
567
+ processor.reset()
 
568
 
569
  frame_count = 0
570
+ processed_count = 0
571
  processing_times = []
572
+ last_alpha = None
573
 
574
  while True:
575
  ret, frame = cap.read()
576
  if not ret:
577
  break
578
 
 
 
579
  # Convert BGR to RGB
580
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
581
 
582
+ # Process frame or use interpolation
583
+ if frame_count % frame_skip == 0:
584
+ start_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
+ # Process with combined pipeline
587
+ result = processor.process_frame(frame_rgb, use_temporal=(processed_count > 0))
588
+
589
+ if result:
590
+ alpha = result.alpha
591
+ last_alpha = alpha
592
+ method_used = result.method
593
+ processing_times.append(result.processing_time)
594
+ else:
595
+ # Fallback to rembg
596
+ alpha = segment_with_rembg(frame_rgb)
597
+ if alpha is not None:
598
+ last_alpha = alpha
599
+ method_used = "Rembg"
600
+ else:
601
+ alpha = last_alpha if last_alpha is not None else np.ones((height, width), dtype=np.float32)
602
+ method_used = "Previous/Fallback"
603
+
604
+ processed_count += 1
605
 
 
 
606
  else:
607
+ # Use last alpha for skipped frames
608
+ alpha = last_alpha if last_alpha is not None else np.ones((height, width), dtype=np.float32)
609
+ method_used = "Interpolated"
610
+
611
+ # Apply alpha and composite
612
+ if alpha.ndim == 2:
613
+ alpha = np.expand_dims(alpha, axis=2)
614
+
615
+ # High-quality compositing
616
+ foreground = frame_rgb.astype(np.float32)
617
+ background = background_resized.astype(np.float32)
618
+
619
+ composite = foreground * alpha + background * (1 - alpha)
620
+ composite = np.clip(composite, 0, 255).astype(np.uint8)
621
 
622
  # Convert back to BGR
623
  composite_bgr = cv2.cvtColor(composite, cv2.COLOR_RGB2BGR)
624
  out.write(composite_bgr)
625
 
 
 
 
 
626
  frame_count += 1
627
 
628
  # Progress update
629
  if progress_callback:
630
  progress = frame_count / total_frames
631
+ if processing_times:
632
+ avg_time = np.mean(processing_times[-10:])
633
+ eta = avg_time * ((total_frames - frame_count) / frame_skip)
634
+ else:
635
+ eta = 0
636
  progress_callback(
637
  progress,
638
+ f"{method_used} | Frame {frame_count}/{total_frames} | ETA: {eta:.1f}s"
639
  )
640
 
641
  # Memory cleanup
642
+ if frame_count % 30 == 0 and CUDA_AVAILABLE:
643
  torch.cuda.empty_cache()
644
 
645
+ # Release resources
646
  cap.release()
647
  out.release()
648
 
 
650
  torch.cuda.empty_cache()
651
  gc.collect()
652
 
653
+ # Log statistics
654
+ if processing_times:
655
+ logger.info(f"βœ… Processing complete: {output_path}")
656
+ logger.info(f"Average processing time: {np.mean(processing_times):.3f}s per frame")
657
+ logger.info(f"Total processed frames: {processed_count}/{total_frames}")
658
 
659
  return output_path
660
 
 
668
 
669
  def main():
670
  st.set_page_config(
671
+ page_title="BackgroundFX - Lightning Fast",
672
+ page_icon="πŸš€",
673
  layout="wide",
674
  initial_sidebar_state="expanded"
675
  )
676
 
677
  # Header
678
+ st.title("πŸš€ BackgroundFX - Lightning-Fast Video Background Replacement")
679
+ st.markdown("**Professional quality in seconds, not minutes! Powered by SAM2 + MatAnyone**")
680
 
681
  # System Status
682
  col1, col2, col3, col4 = st.columns(4)
 
686
  st.success(f"πŸš€ GPU: {GPU_NAME}")
687
  st.caption(f"VRAM: {GPU_MEMORY:.1f}GB")
688
  else:
689
+ st.warning("πŸ’» CPU Mode")
690
 
691
  with col2:
692
  methods = []
693
+ if processor.sam2_loaded:
 
 
694
  methods.append("SAM2")
695
+ if processor.matanyone_loaded:
696
+ methods.append("MatAnyone")
697
  if REMBG_AVAILABLE:
698
  methods.append("Rembg")
699
+
700
+ if methods:
701
+ st.info(f"βœ… Ready: {', '.join(methods)}")
702
+ else:
703
+ st.warning("⏳ Loading models...")
704
 
705
  with col3:
706
  if CUDA_AVAILABLE:
 
710
  st.metric("Mode", "CPU")
711
 
712
  with col4:
713
+ # Speed indicator
714
+ st.metric("Status", "Ready" if processor.sam2_loaded else "Loading")
715
 
716
  # Sidebar
717
  with st.sidebar:
718
+ st.markdown("### ⚑ Speed Settings")
719
+
720
+ # Speed mode selection
721
+ speed_mode = st.select_slider(
722
+ "Processing Speed",
723
+ options=['ultra_fast', 'fast', 'balanced', 'quality'],
724
+ value='balanced',
725
+ format_func=lambda x: {
726
+ 'ultra_fast': '⚑⚑⚑ Ultra Fast (3x)',
727
+ 'fast': '⚑⚑ Fast (2x)',
728
+ 'balanced': '⚑ Balanced',
729
+ 'quality': '🎨 Quality'
730
+ }[x]
731
+ )
732
+
733
+ # Speed mode info
734
+ speed_info = {
735
+ 'ultra_fast': "Process every 3rd frame\n~5 sec for 10 sec video",
736
+ 'fast': "Process every 2nd frame\n~10 sec for 10 sec video",
737
+ 'balanced': "Process all frames\n~15 sec for 10 sec video",
738
+ 'quality': "Full processing\n~20 sec for 10 sec video"
739
  }
740
+ st.info(speed_info[speed_mode])
741
 
742
+ st.markdown("---")
 
 
 
 
 
743
 
744
+ # Processing info
745
+ st.markdown("### 🎯 Pipeline")
746
+
747
+ if processor.sam2_loaded and processor.matanyone_loaded:
748
+ st.success("SAM2 + MatAnyone Combined")
749
+ st.caption("Best quality mode active")
750
+ elif processor.sam2_loaded:
751
+ st.info("SAM2 Only")
752
+ st.caption("Good quality, fast processing")
753
+ else:
754
+ st.warning("Initializing...")
 
 
 
 
 
 
 
 
 
 
 
 
755
 
756
  st.markdown("---")
757
 
758
  # System info
759
+ st.markdown("### πŸ“Š System")
760
 
761
  if CUDA_AVAILABLE:
762
  allocated = torch.cuda.memory_allocated() / 1024**3
763
  reserved = torch.cuda.memory_reserved() / 1024**3
 
764
 
765
+ st.metric("Memory", f"{allocated:.1f}/{GPU_MEMORY:.0f} GB")
766
 
767
  usage_percent = (allocated / GPU_MEMORY) * 100 if GPU_MEMORY else 0
768
  st.progress(min(usage_percent / 100, 1.0))
769
 
770
+ # GPU details
771
  with st.expander("GPU Details"):
772
  st.code(f"""
773
  Device: {GPU_NAME}
774
  VRAM: {GPU_MEMORY:.1f} GB
775
+ Used: {allocated:.2f} GB
776
  Reserved: {reserved:.2f} GB
 
777
  PyTorch: {torch.__version__}
778
  CUDA: {torch.version.cuda if CUDA_AVAILABLE else 'N/A'}
779
  """)
 
 
780
 
781
  # Main content
782
  col1, col2 = st.columns(2)
 
787
  uploaded_video = st.file_uploader(
788
  "Upload your video",
789
  type=['mp4', 'avi', 'mov', 'mkv'],
790
+ help="Recommended: 10-30 seconds for best performance"
791
  )
792
 
793
  if uploaded_video:
 
797
  video_path = tmp_file.name
798
 
799
  st.video(uploaded_video)
800
+
801
+ # Get video info
802
+ cap = cv2.VideoCapture(video_path)
803
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
804
+ frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
805
+ duration = frames / fps if fps > 0 else 0
806
+ cap.release()
807
+
808
+ st.success(f"βœ… Ready: {duration:.1f}s @ {fps} FPS")
809
  else:
810
  video_path = None
811
 
812
  with col2:
813
+ st.markdown("### 🎨 Background")
814
 
815
+ # Quick background selection
816
+ backgrounds = get_background_options()
817
+ selected_bg = st.selectbox(
818
+ "Choose background",
819
  options=list(backgrounds.keys()),
820
  index=0
821
  )
822
 
823
+ background_option = backgrounds[selected_bg]
824
 
825
  # Preview
826
+ if background_option:
827
+ preview_bg = load_background_image(background_option)
828
+ preview_bg_resized = cv2.resize(preview_bg, (640, 360))
829
+ st.image(preview_bg_resized, caption=selected_bg, use_container_width=True)
 
830
 
831
  # Process button
832
  if video_path and st.button("πŸš€ Process Video", type="primary", use_container_width=True):
 
834
  # Progress tracking
835
  progress_bar = st.progress(0)
836
  status_text = st.empty()
837
+ time_text = st.empty()
838
 
839
  def update_progress(progress, message):
840
  progress_bar.progress(progress)
841
  status_text.text(message)
842
+ elapsed = time.time() - start_time
843
+ time_text.text(f"⏱️ Elapsed: {elapsed:.1f}s")
844
 
845
  # Process video
846
+ start_time = time.time()
847
+
848
+ result_path = process_video(
849
+ video_path,
850
+ background_option,
851
+ speed_mode=speed_mode,
852
+ progress_callback=update_progress
853
+ )
854
+
855
+ processing_time = time.time() - start_time
 
856
 
857
  if result_path and os.path.exists(result_path):
858
  # Success
859
+ status_text.text(f"βœ… Complete in {processing_time:.1f} seconds!")
860
+ time_text.text(f"πŸš€ Speed: {frames/processing_time:.1f} FPS")
861
 
862
  # Load result
863
  with open(result_path, 'rb') as f:
 
866
  st.markdown("### 🎬 Result")
867
  st.video(result_data)
868
 
869
+ # Download button
870
+ col1, col2, col3 = st.columns([1, 2, 1])
871
+ with col2:
872
+ st.download_button(
873
+ label="πŸ’Ύ Download Video",
874
+ data=result_data,
875
+ file_name=f"backgroundfx_{uploaded_video.name}",
876
+ mime="video/mp4",
877
+ use_container_width=True
878
+ )
879
+
880
+ # Stats
881
+ st.success(f"""
882
+ ✨ **Processing Complete!**
883
+ - Time: {processing_time:.1f} seconds
884
+ - Speed: {frames/processing_time:.1f} FPS
885
+ - Method: {processor.previous_result.method if processor.previous_result else 'Unknown'}
886
+ - Mode: {speed_mode.replace('_', ' ').title()}
887
+ """)
888
 
889
  # Cleanup
890
  os.unlink(result_path)
 
 
 
 
 
891
  else:
892
  st.error("❌ Processing failed! Please try again.")
893