MogensR committed on
Commit
82a2981
·
1 Parent(s): 5353980

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +330 -536
app.py CHANGED
@@ -1,622 +1,416 @@
1
  #!/usr/bin/env python3
2
  """
3
- VideoBackgroundFX - SAM2 GPU-Optimized Video Background Replacement
4
- HuggingFace Space Deployment with L4 GPU Support
5
- Updated: 2025-08-13 - SAM2 Integration
6
  """
7
 
8
- import streamlit as st
9
- import cv2
10
- import numpy as np
11
- import tempfile
12
  import os
13
- from PIL import Image
14
- import requests
15
- from io import BytesIO
16
- import logging
17
- import base64
18
- import gc
19
  import torch
20
- import psutil
 
 
 
 
 
 
21
 
22
- # Configure logging
23
- logging.basicConfig(level=logging.INFO)
24
  logger = logging.getLogger(__name__)
25
 
26
- # GPU Environment Setup
27
- def setup_gpu_environment():
28
- """Setup GPU environment for L4 optimization"""
29
- os.environ['OMP_NUM_THREADS'] = '8'
30
- os.environ['ORT_PROVIDERS'] = 'CUDAExecutionProvider,CPUExecutionProvider'
31
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
32
- os.environ['TORCH_CUDA_ARCH_LIST'] = '8.9' # L4 architecture
33
- os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
34
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
35
-
36
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  if torch.cuda.is_available():
38
- device_count = torch.cuda.device_count()
39
- gpu_name = torch.cuda.get_device_name(0)
40
  gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- logger.info(f"🚀 GPU: {gpu_name} ({gpu_memory:.1f}GB)")
43
-
44
- # Initialize CUDA context
45
- torch.cuda.init()
46
- torch.cuda.set_device(0)
47
 
48
- # Enable optimizations for L4
49
- torch.backends.cuda.matmul.allow_tf32 = True
50
- torch.backends.cudnn.allow_tf32 = True
51
- torch.backends.cudnn.benchmark = True
52
- torch.backends.cudnn.deterministic = False
53
 
54
- # Set memory fraction
55
- torch.cuda.set_per_process_memory_fraction(0.8)
 
 
 
 
 
 
 
 
56
 
57
- # Warm up GPU
58
- dummy = torch.randn(1024, 1024, device='cuda')
59
- dummy = dummy @ dummy.T
60
- del dummy
61
- torch.cuda.empty_cache()
62
 
63
- return True, gpu_name, gpu_memory
64
  else:
65
- logger.warning("⚠️ CUDA not available")
66
- return False, None, 0
67
- except Exception as e:
68
- logger.error(f"GPU setup failed: {e}")
69
- return False, None, 0
70
-
71
- # Initialize GPU
72
- CUDA_AVAILABLE, GPU_NAME, GPU_MEMORY = setup_gpu_environment()
73
-
74
- # SAM2 Integration
75
- try:
76
- from segment_anything import sam_model_registry, SamPredictor
77
- SAM_AVAILABLE = True
78
- logger.info("✅ SAM loaded successfully")
79
-
80
- # Initialize SAM with downloaded checkpoint
81
- sam_checkpoint = "sam_vit_h_4b8939.pth"
82
- model_type = "vit_h"
83
-
84
- if os.path.exists(sam_checkpoint) and CUDA_AVAILABLE:
85
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
86
- sam.to(device='cuda')
87
- sam_predictor = SamPredictor(sam)
88
- logger.info("✅ SAM2 GPU predictor initialized")
89
- else:
90
- sam_predictor = None
91
- if not os.path.exists(sam_checkpoint):
92
- logger.warning(f"⚠️ SAM checkpoint not found: {sam_checkpoint}")
93
 
94
- except ImportError as e:
95
- SAM_AVAILABLE = False
96
- sam_predictor = None
97
- logger.warning(f"⚠️ SAM not available: {e}")
98
-
99
- # Rembg with GPU optimization
100
- try:
101
- from rembg import remove, new_session
102
- import onnxruntime as ort
103
-
104
- REMBG_AVAILABLE = True
105
- logger.info("✅ Rembg loaded")
106
-
107
- if CUDA_AVAILABLE:
108
- providers = [
109
- ('CUDAExecutionProvider', {
110
- 'device_id': 0,
111
- 'arena_extend_strategy': 'kSameAsRequested',
112
- 'gpu_mem_limit': 20 * 1024 * 1024 * 1024, # 20GB for L4
113
- 'cudnn_conv_algo_search': 'HEURISTIC',
114
- }),
115
- 'CPUExecutionProvider'
116
- ]
117
 
118
- rembg_session = new_session('u2net_human_seg', providers=providers)
 
 
 
 
 
119
 
120
- # Warm up
121
- dummy_img = Image.new('RGB', (512, 512), color='white')
122
- with torch.cuda.amp.autocast():
123
- _ = remove(dummy_img, session=rembg_session)
124
 
125
- logger.info("✅ Rembg GPU session initialized")
126
- else:
127
- rembg_session = new_session('u2net_human_seg')
128
- logger.info("✅ Rembg CPU session initialized")
129
-
130
- except ImportError as e:
131
- REMBG_AVAILABLE = False
132
- rembg_session = None
133
- logger.warning(f"⚠️ Rembg not available: {e}")
134
-
135
- # OpenCV GPU check
136
- try:
137
- if cv2.cuda.getCudaEnabledDeviceCount() > 0:
138
- logger.info(f"✅ OpenCV CUDA devices: {cv2.cuda.getCudaEnabledDeviceCount()}")
139
- OPENCV_GPU = True
140
- else:
141
- OPENCV_GPU = False
142
- logger.warning("⚠️ OpenCV CUDA not available")
143
- except:
144
- OPENCV_GPU = False
145
- logger.warning("⚠️ OpenCV CUDA not available")
146
-
147
- # Memory management
148
- def optimize_memory():
149
- """Optimize memory usage"""
150
- if CUDA_AVAILABLE:
151
- torch.cuda.empty_cache()
152
- torch.cuda.synchronize()
153
- gc.collect()
154
-
155
- def get_memory_usage():
156
- """Get current memory usage"""
157
- stats = {}
158
- if CUDA_AVAILABLE:
159
- stats['gpu_allocated'] = torch.cuda.memory_allocated() / 1024**3
160
- stats['gpu_reserved'] = torch.cuda.memory_reserved() / 1024**3
161
- stats['gpu_free'] = GPU_MEMORY - stats['gpu_reserved']
162
- else:
163
- stats['gpu_allocated'] = 0
164
- stats['gpu_reserved'] = 0
165
- stats['gpu_free'] = 0
166
-
167
- # System RAM
168
- ram = psutil.virtual_memory()
169
- stats['ram_used'] = ram.used / 1024**3
170
- stats['ram_total'] = ram.total / 1024**3
171
- stats['ram_percent'] = ram.percent
172
-
173
- return stats
174
-
175
- # Background loading
176
- def load_background_image(background_url):
177
- """Load background image from URL"""
178
- try:
179
- if background_url == "default_brick":
180
- return create_default_background()
181
 
182
- response = requests.get(background_url)
183
- response.raise_for_status()
184
- image = Image.open(BytesIO(response.content))
185
- return np.array(image.convert('RGB'))
186
  except Exception as e:
187
- logger.error(f"Failed to load background image: {e}")
188
- return create_default_background()
 
 
 
 
189
 
190
- def create_default_background():
191
- """Create a default brick wall background"""
192
- background = np.zeros((720, 1280, 3), dtype=np.uint8)
193
- background[:, :] = [139, 69, 19] # Brown color
 
194
 
195
- # Add brick pattern
196
- for y in range(0, 720, 60):
197
- for x in range(0, 1280, 120):
198
- cv2.rectangle(background, (x, y), (x+115, y+55), (160, 82, 45), -1)
199
- cv2.rectangle(background, (x, y), (x+115, y+55), (101, 67, 33), 2)
 
 
 
 
 
 
 
 
 
 
 
200
 
201
- return background
202
-
203
- def get_professional_backgrounds():
204
- """Get professional background collection"""
205
- return {
206
- "🏢 Modern Office": "https://images.unsplash.com/photo-1497366216548-37526070297c?w=1920&h=1080&fit=crop",
207
- "🌆 City Skyline": "https://images.unsplash.com/photo-1449824913935-59a10b8d2000?w=1920&h=1080&fit=crop",
208
- "🏖️ Tropical Beach": "https://images.unsplash.com/photo-1507525428034-b723cf961d3e?w=1920&h=1080&fit=crop",
209
- "🌲 Forest Path": "https://images.unsplash.com/photo-1441974231531-c6227db76b6e?w=1920&h=1080&fit=crop",
210
- "🎨 Abstract Blue": "https://images.unsplash.com/photo-1557683316-973673baf926?w=1920&h=1080&fit=crop",
211
- "🏔️ Mountain View": "https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=1920&h=1080&fit=crop",
212
- "🌅 Sunset Gradient": "https://images.unsplash.com/photo-1495616811223-4d98c6e9c869?w=1920&h=1080&fit=crop",
213
- "💼 Executive Suite": "https://images.unsplash.com/photo-1497366811353-6870744d04b2?w=1920&h=1080&fit=crop"
214
- }
215
-
216
- # SAM2 Segmentation
217
- def segment_person_sam2(frame):
218
- """SAM2 GPU-accelerated segmentation"""
219
- try:
220
- if SAM_AVAILABLE and sam_predictor and CUDA_AVAILABLE:
221
- # Set image for SAM
222
- sam_predictor.set_image(frame)
223
-
224
- # Get image center as prompt (simple heuristic)
225
- h, w = frame.shape[:2]
226
- input_point = np.array([[w//2, h//2]])
227
- input_label = np.array([1])
228
-
229
- # Predict mask
230
- with torch.no_grad():
231
- masks, scores, logits = sam_predictor.predict(
232
- point_coords=input_point,
233
- point_labels=input_label,
234
- multimask_output=True,
235
- )
236
-
237
- # Use best mask
238
- best_mask = masks[np.argmax(scores)]
239
- return best_mask.astype(np.float32)
240
 
241
- return None
242
- except Exception as e:
243
- logger.error(f"SAM2 segmentation failed: {e}")
244
- return None
245
-
246
- # Rembg Segmentation
247
- def segment_person_rembg(frame):
248
- """Rembg GPU-optimized segmentation"""
249
- try:
250
- if REMBG_AVAILABLE and rembg_session:
251
- pil_image = Image.fromarray(frame)
 
 
252
 
253
- if CUDA_AVAILABLE:
254
- with torch.cuda.amp.autocast():
255
- output = remove(
256
- pil_image,
257
- session=rembg_session,
258
- alpha_matting=True,
259
- alpha_matting_foreground_threshold=240,
260
- alpha_matting_background_threshold=10,
261
- alpha_matting_erode_size=10
262
- )
263
- else:
264
- output = remove(pil_image, session=rembg_session, alpha_matting=True)
265
 
266
- output_array = np.array(output)
267
- if output_array.shape[2] == 4:
268
- mask = output_array[:, :, 3].astype(np.float32) / 255.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  else:
270
- mask = np.ones((frame.shape[0], frame.shape[1]), dtype=np.float32)
 
 
 
 
 
 
 
271
 
272
- return mask
273
- return None
274
- except Exception as e:
275
- logger.error(f"Rembg segmentation failed: {e}")
276
- return None
277
-
278
- # OpenCV GPU Segmentation
279
- def segment_person_opencv_gpu(frame):
280
- """OpenCV GPU segmentation"""
281
- try:
282
- if OPENCV_GPU:
283
- gpu_frame = cv2.cuda_GpuMat()
284
- gpu_frame.upload(frame)
285
 
286
- gpu_hsv = cv2.cuda.cvtColor(gpu_frame, cv2.COLOR_RGB2HSV)
 
 
 
 
 
 
287
 
288
- lower_skin = np.array([0, 20, 70])
289
- upper_skin = np.array([20, 255, 255])
 
290
 
291
- gpu_mask = cv2.cuda.inRange(gpu_hsv, lower_skin, upper_skin)
 
 
 
 
 
292
 
 
 
293
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
294
- gpu_mask = cv2.cuda.morphologyEx(gpu_mask, cv2.MORPH_CLOSE, kernel)
295
- gpu_mask = cv2.cuda.morphologyEx(gpu_mask, cv2.MORPH_OPEN, kernel)
296
 
297
- mask = gpu_mask.download()
 
298
 
299
- del gpu_frame, gpu_hsv, gpu_mask
 
 
300
 
301
- return mask.astype(float) / 255
302
- else:
303
- return segment_person_fallback_cpu(frame)
304
- except Exception as e:
305
- logger.error(f"OpenCV GPU segmentation failed: {e}")
306
- return segment_person_fallback_cpu(frame)
 
 
 
 
 
307
 
308
- def segment_person_fallback_cpu(frame):
309
- """CPU fallback segmentation"""
310
- try:
311
- hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV)
312
- lower_skin = np.array([0, 20, 70])
313
- upper_skin = np.array([20, 255, 255])
314
- mask = cv2.inRange(hsv, lower_skin, upper_skin)
315
-
316
- kernel = np.ones((5, 5), np.uint8)
317
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
318
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
319
-
320
- return mask.astype(float) / 255
321
- except Exception as e:
322
- logger.error(f"CPU fallback segmentation failed: {e}")
323
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
- # Video Processing
326
- def process_video_gpu_optimized(video_path, background_url, progress_callback=None):
327
- """GPU-optimized video processing"""
 
 
 
 
328
  try:
329
- background_image = load_background_image(background_url)
330
-
331
  cap = cv2.VideoCapture(video_path)
332
 
 
333
  fps = int(cap.get(cv2.CAP_PROP_FPS))
334
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
335
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
336
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
337
 
338
- logger.info(f"Processing video: {width}x{height}, {total_frames} frames, {fps} FPS")
339
-
340
- output_path = tempfile.mktemp(suffix='.mp4')
341
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
342
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
343
 
 
344
  background_resized = cv2.resize(background_image, (width, height))
345
 
 
 
 
346
  frame_count = 0
347
- batch_size = 4 if CUDA_AVAILABLE else 1
348
- frame_batch = []
349
 
350
  while True:
351
  ret, frame = cap.read()
352
  if not ret:
353
- if frame_batch:
354
- processed_batch = process_frame_batch(frame_batch, background_resized)
355
- for processed_frame in processed_batch:
356
- out.write(processed_frame)
357
  break
358
 
 
359
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
360
- frame_batch.append(frame_rgb)
361
 
362
- if len(frame_batch) >= batch_size:
363
- processed_batch = process_frame_batch(frame_batch, background_resized)
 
 
 
 
 
364
 
365
- for processed_frame in processed_batch:
366
- out.write(processed_frame)
367
- frame_count += 1
368
-
369
- if progress_callback:
370
- progress = frame_count / total_frames
371
- memory_stats = get_memory_usage()
372
- progress_callback(
373
- progress,
374
- f"GPU Processing: {frame_count}/{total_frames} | "
375
- f"GPU: {memory_stats['gpu_allocated']:.1f}GB | "
376
- f"RAM: {memory_stats['ram_percent']:.1f}%"
377
- )
378
 
379
- frame_batch = []
380
- optimize_memory()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
  cap.release()
383
  out.release()
384
- optimize_memory()
385
 
386
- logger.info(f"Video processing complete: {output_path}")
387
  return output_path
388
 
389
  except Exception as e:
390
- logger.error(f"GPU video processing failed: {e}")
391
  return None
392
 
393
- def process_frame_batch(frame_batch, background_resized):
394
- """Process batch of frames"""
395
- processed_frames = []
396
-
397
- for frame in frame_batch:
398
- person_mask = None
399
- method_used = "None"
400
-
401
- # Try SAM2 first
402
- if SAM_AVAILABLE and CUDA_AVAILABLE:
403
- person_mask = segment_person_sam2(frame)
404
- if person_mask is not None:
405
- method_used = "SAM2-GPU"
406
-
407
- # Try Rembg
408
- if person_mask is None and REMBG_AVAILABLE:
409
- person_mask = segment_person_rembg(frame)
410
- if person_mask is not None:
411
- method_used = "Rembg-GPU"
412
-
413
- # Try OpenCV GPU
414
- if person_mask is None and OPENCV_GPU:
415
- person_mask = segment_person_opencv_gpu(frame)
416
- if person_mask is not None:
417
- method_used = "OpenCV-GPU"
418
-
419
- # CPU fallback
420
- if person_mask is None:
421
- person_mask = segment_person_fallback_cpu(frame)
422
- method_used = "CPU-Fallback"
423
-
424
- if person_mask is not None:
425
- if person_mask.ndim == 2:
426
- person_mask = np.expand_dims(person_mask, axis=2)
427
-
428
- final_frame = frame * person_mask + background_resized * (1 - person_mask)
429
- final_frame = final_frame.astype(np.uint8)
430
- else:
431
- final_frame = frame
432
-
433
- final_frame_bgr = cv2.cvtColor(final_frame, cv2.COLOR_RGB2BGR)
434
- processed_frames.append(final_frame_bgr)
435
-
436
- return processed_frames
437
 
438
- # Streamlit UI
439
  def main():
440
- st.set_page_config(
441
- page_title="VideoBackgroundFX - SAM2 GPU",
442
- page_icon="🚀",
443
- layout="wide",
444
- initial_sidebar_state="expanded"
445
- )
446
 
447
- st.title("🚀 VideoBackgroundFX - SAM2 GPU-Optimized")
448
- st.markdown("**High-performance video background replacement with SAM2 & GPU acceleration**")
449
-
450
- # GPU Status Dashboard
451
- col1, col2, col3, col4 = st.columns(4)
452
 
453
  with col1:
454
- if CUDA_AVAILABLE:
455
- st.success(f"🚀 GPU: {GPU_NAME}")
456
- st.caption(f"{GPU_MEMORY:.1f}GB VRAM")
457
- else:
458
- st.warning("⚠️ CPU Mode")
459
-
460
- with col2:
461
- if SAM_AVAILABLE and CUDA_AVAILABLE:
462
- st.success("✅ SAM2-GPU")
463
- elif REMBG_AVAILABLE:
464
- st.success("✅ Rembg-GPU")
465
- else:
466
- st.warning("⚠️ Basic Mode")
467
-
468
- with col3:
469
- if OPENCV_GPU:
470
- st.success("✅ OpenCV-GPU")
471
- else:
472
- st.info("ℹ️ OpenCV-CPU")
473
-
474
- with col4:
475
- memory_stats = get_memory_usage()
476
- if CUDA_AVAILABLE:
477
- st.metric("GPU Memory", f"{memory_stats['gpu_allocated']:.1f}GB")
478
  else:
479
- st.info("CPU Processing")
480
 
481
- # Sidebar monitoring
482
- with st.sidebar:
483
- st.markdown("### 🚀 System Performance")
484
-
485
- memory_stats = get_memory_usage()
486
-
487
- if CUDA_AVAILABLE:
488
- st.metric("GPU Allocated", f"{memory_stats['gpu_allocated']:.2f}GB")
489
- st.metric("GPU Reserved", f"{memory_stats['gpu_reserved']:.2f}GB")
490
- st.metric("GPU Free", f"{memory_stats['gpu_free']:.2f}GB")
491
-
492
- usage_percent = (memory_stats['gpu_reserved'] / GPU_MEMORY) * 100
493
- st.progress(usage_percent / 100)
494
- st.caption(f"{usage_percent:.1f}% GPU Memory Used")
495
-
496
- st.metric("RAM Used", f"{memory_stats['ram_used']:.1f}GB")
497
- st.metric("RAM Total", f"{memory_stats['ram_total']:.1f}GB")
498
- st.progress(memory_stats['ram_percent'] / 100)
499
- st.caption(f"{memory_stats['ram_percent']:.1f}% RAM Used")
500
-
501
- st.markdown("---")
502
- st.markdown("### 🛠️ Processing Methods")
503
- methods = []
504
-
505
- if SAM_AVAILABLE and CUDA_AVAILABLE:
506
- methods.append("🚀 SAM2-GPU (Ultra Precise)")
507
- if REMBG_AVAILABLE:
508
- methods.append("✅ Rembg-GPU (High Quality)")
509
- if OPENCV_GPU:
510
- methods.append("⚡ OpenCV-GPU (Fast)")
511
- methods.append("💻 CPU Fallback")
512
-
513
- for method in methods:
514
- st.markdown(method)
515
-
516
- # Main interface
517
- col1, col2 = st.columns(2)
518
-
519
- # Initialize session state
520
- if 'video_path' not in st.session_state:
521
- st.session_state.video_path = None
522
- if 'video_bytes' not in st.session_state:
523
- st.session_state.video_bytes = None
524
- if 'video_name' not in st.session_state:
525
- st.session_state.video_name = None
526
-
527
- with col1:
528
- st.markdown("### 📹 Upload Video")
529
- uploaded_video = st.file_uploader(
530
- "Choose a video file",
531
- type=['mp4', 'avi', 'mov', 'mkv'],
532
- help="Upload video for SAM2 GPU processing"
533
- )
534
-
535
  if uploaded_video:
536
- if st.session_state.video_name != uploaded_video.name:
537
- st.success(f"✅ Video uploaded: {uploaded_video.name}")
538
-
539
- video_bytes = uploaded_video.read()
540
-
541
- with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
542
- tmp_file.write(video_bytes)
543
- video_path = tmp_file.name
544
-
545
- st.session_state.video_path = video_path
546
- st.session_state.video_bytes = video_bytes
547
- st.session_state.video_name = uploaded_video.name
548
-
549
- if st.session_state.video_bytes is not None:
550
- st.video(st.session_state.video_bytes)
551
-
552
- elif st.session_state.video_path:
553
- st.success(f"✅ Video ready: {st.session_state.video_name}")
554
- st.video(st.session_state.video_bytes)
555
-
556
- with col2:
557
- st.markdown("### 🖼️ Background Selection")
558
-
559
- background_options = get_professional_backgrounds()
560
- selected_background = st.selectbox(
561
- "Choose background",
562
- options=list(background_options.keys()),
563
- index=0
564
- )
565
-
566
- background_url = background_options[selected_background]
567
-
568
- try:
569
- background_image = load_background_image(background_url)
570
- st.image(background_image, caption=f"Background: {selected_background}", use_container_width=True)
571
- except:
572
- st.error("Failed to load background image")
573
-
574
- # Processing button
575
- if (uploaded_video or st.session_state.video_path) and st.button("🚀 Process with SAM2", type="primary"):
576
- video_path = st.session_state.video_path
577
-
578
- if video_path and os.path.exists(video_path):
579
- progress_bar = st.progress(0)
580
- status_text = st.empty()
581
-
582
- def update_progress(progress, message):
583
- progress_bar.progress(progress)
584
- status_text.text(message)
585
 
586
- try:
587
- result_path = process_video_gpu_optimized(
588
- video_path,
589
- background_url,
590
- update_progress
591
- )
592
-
593
- if result_path and os.path.exists(result_path):
594
- status_text.text("✅ SAM2 processing complete!")
595
-
596
- with open(result_path, 'rb') as f:
597
- result_video = f.read()
598
-
599
- st.video(result_video)
600
-
601
- st.download_button(
602
- "💾 Download SAM2 Processed Video",
603
- data=result_video,
604
- file_name="sam2_backgroundfx_result.mp4",
605
- mime="video/mp4"
606
- )
607
-
608
- final_stats = get_memory_usage()
609
- st.success(f"🚀 SAM2 processing complete! GPU: {final_stats['gpu_allocated']:.2f}GB, RAM: {final_stats['ram_percent']:.1f}%")
610
-
611
- os.unlink(result_path)
612
- else:
613
- st.error("❌ SAM2 processing failed!")
614
-
615
- except Exception as e:
616
- st.error(f"❌ Error during SAM2 processing: {str(e)}")
617
- logger.error(f"SAM2 processing error: {e}")
618
- else:
619
- st.error("Video file not found. Please upload again.")
620
 
621
  if __name__ == "__main__":
622
- main()
 
1
  #!/usr/bin/env python3
2
  """
3
+ SAM2 (Segment Anything Model 2) for Video
4
+ Correct implementation with dynamic model loading
5
+ Optimized for video processing
6
  """
7
 
 
 
 
 
8
  import os
 
 
 
 
 
 
9
  import torch
10
+ import numpy as np
11
+ import streamlit as st
12
+ from pathlib import Path
13
+ import logging
14
+ import requests
15
+ from tqdm import tqdm
16
+ import cv2
17
 
 
 
18
  logger = logging.getLogger(__name__)
19
 
20
+ # ============================================
21
+ # SAM2 DYNAMIC LOADER FOR VIDEO
22
+ # ============================================
23
+
24
+ @st.cache_resource(show_spinner=False)
25
+ def load_sam2_model_dynamic():
26
+ """
27
+ Download and load SAM2 model dynamically
28
+ SAM2 is specifically designed for video segmentation
29
+ """
30
  try:
31
+ # Import SAM2 (not SAM1!)
32
+ from sam2.build_sam import build_sam2
33
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
34
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
35
+
36
+ # SAM2 Model URLs (these are the NEW video-optimized models)
37
+ MODEL_URLS = {
38
+ 'sam2_hiera_large': {
39
+ 'config': 'sam2_hiera_l.yaml',
40
+ 'checkpoint': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt',
41
+ 'size': '897MB',
42
+ 'quality': 'Best for video'
43
+ },
44
+ 'sam2_hiera_base_plus': {
45
+ 'config': 'sam2_hiera_b+.yaml',
46
+ 'checkpoint': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt',
47
+ 'size': '323MB',
48
+ 'quality': 'Balanced'
49
+ },
50
+ 'sam2_hiera_small': {
51
+ 'config': 'sam2_hiera_s.yaml',
52
+ 'checkpoint': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt',
53
+ 'size': '155MB',
54
+ 'quality': 'Fast'
55
+ },
56
+ 'sam2_hiera_tiny': {
57
+ 'config': 'sam2_hiera_t.yaml',
58
+ 'checkpoint': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt',
59
+ 'size': '77MB',
60
+ 'quality': 'Fastest'
61
+ }
62
+ }
63
+
64
+ # Choose model based on GPU
65
  if torch.cuda.is_available():
 
 
66
  gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
67
+ if gpu_memory > 20: # L4 has 24GB
68
+ model_name = 'sam2_hiera_large'
69
+ elif gpu_memory > 10:
70
+ model_name = 'sam2_hiera_base_plus'
71
+ elif gpu_memory > 6:
72
+ model_name = 'sam2_hiera_small'
73
+ else:
74
+ model_name = 'sam2_hiera_tiny'
75
+ else:
76
+ model_name = 'sam2_hiera_tiny' # CPU = smallest
77
+
78
+ logger.info(f"Selected SAM2 model: {model_name} ({MODEL_URLS[model_name]['quality']})")
79
+
80
+ # Setup cache directory
81
+ cache_dir = Path("/tmp/sam2_models")
82
+ cache_dir.mkdir(exist_ok=True)
83
+
84
+ model_path = cache_dir / f"{model_name}.pt"
85
+ config_name = MODEL_URLS[model_name]['config']
86
+
87
+ # Download if not cached
88
+ if not model_path.exists():
89
+ logger.info(f"Downloading SAM2 {model_name} ({MODEL_URLS[model_name]['size']})...")
90
 
91
+ # Show progress in Streamlit
92
+ progress_text = st.empty()
93
+ progress_bar = st.progress(0)
 
 
94
 
95
+ # Download with progress
96
+ response = requests.get(MODEL_URLS[model_name]['checkpoint'], stream=True)
97
+ total_size = int(response.headers.get('content-length', 0))
 
 
98
 
99
+ with open(model_path, 'wb') as f:
100
+ downloaded = 0
101
+ for chunk in response.iter_content(chunk_size=8192):
102
+ f.write(chunk)
103
+ downloaded += len(chunk)
104
+
105
+ if total_size > 0:
106
+ progress = downloaded / total_size
107
+ progress_bar.progress(progress)
108
+ progress_text.text(f"Downloading SAM2: {downloaded/(1024**2):.1f}MB / {total_size/(1024**2):.1f}MB")
109
 
110
+ progress_text.empty()
111
+ progress_bar.empty()
 
 
 
112
 
113
+ logger.info(f"✅ SAM2 model downloaded to {model_path}")
114
  else:
115
+ logger.info(f" Using cached SAM2 model from {model_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
+ # Build SAM2 model
118
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
+ sam2_model = build_sam2(
121
+ config_file=config_name,
122
+ ckpt_path=str(model_path),
123
+ device=device,
124
+ apply_postprocessing=True
125
+ )
126
 
127
+ # Create predictor for frame-by-frame processing
128
+ predictor = SAM2ImagePredictor(sam2_model)
 
 
129
 
130
+ logger.info(f"✅ SAM2 loaded successfully on {device}")
131
+ return predictor, model_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ except ImportError as e:
134
+ logger.error(f"SAM2 not installed. Install with: pip install sam-2")
135
+ return None, None
 
136
  except Exception as e:
137
+ logger.error(f"Failed to load SAM2 model: {e}")
138
+ return None, None
139
+
140
+ # ============================================
141
+ # SAM2 VIDEO PROCESSOR
142
+ # ============================================
143
 
144
+ class SAM2VideoProcessor:
145
+ """
146
+ SAM2 optimized for video processing
147
+ Handles temporal consistency across frames
148
+ """
149
 
150
+ def __init__(self):
151
+ self.predictor = None
152
+ self.model_name = None
153
+ self.loaded = False
154
+ self.previous_mask = None
155
+ self.frame_count = 0
156
+
157
+ def load_model(self):
158
+ """Load SAM2 model if not already loaded"""
159
+ if not self.loaded:
160
+ with st.spinner("🎬 Loading SAM2 Video Model..."):
161
+ self.predictor, self.model_name = load_sam2_model_dynamic()
162
+ self.loaded = True
163
+ if self.predictor:
164
+ logger.info(f"SAM2 Video Processor ready with {self.model_name}")
165
+ return self.predictor is not None
166
 
167
+ def segment_frame(self, frame, use_previous=True):
168
+ """
169
+ Segment a single frame with temporal consistency
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
+ Args:
172
+ frame: Input frame (H, W, 3) numpy array
173
+ use_previous: Use previous frame's mask for consistency
174
+
175
+ Returns:
176
+ mask: Segmentation mask (H, W) float32
177
+ """
178
+ if not self.load_model():
179
+ return None
180
+
181
+ try:
182
+ # Set the image
183
+ self.predictor.set_image(frame)
184
 
185
+ h, w = frame.shape[:2]
 
 
 
 
 
 
 
 
 
 
 
186
 
187
+ # Generate point prompts
188
+ if use_previous and self.previous_mask is not None:
189
+ # Use previous mask to guide current segmentation
190
+ # Find center of mass of previous mask
191
+ y_coords, x_coords = np.where(self.previous_mask > 0.5)
192
+ if len(y_coords) > 0:
193
+ center_y = int(np.mean(y_coords))
194
+ center_x = int(np.mean(x_coords))
195
+
196
+ # Add points around previous center
197
+ point_coords = np.array([
198
+ [center_x, center_y],
199
+ [center_x, center_y - h//8], # Above
200
+ [center_x, center_y + h//8], # Below
201
+ ])
202
+ else:
203
+ # Fallback to center points
204
+ point_coords = np.array([
205
+ [w//2, h//2],
206
+ [w//2, h//3],
207
+ [w//2, 2*h//3]
208
+ ])
209
  else:
210
+ # Initial frame - use center points
211
+ point_coords = np.array([
212
+ [w//2, h//2], # Center
213
+ [w//2, h//3], # Upper (head)
214
+ [w//2, 2*h//3], # Lower (body)
215
+ [w//3, h//2], # Left
216
+ [2*w//3, h//2], # Right
217
+ ])
218
 
219
+ point_labels = np.ones(len(point_coords)) # All foreground
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
+ # Generate masks with SAM2
222
+ masks, scores, logits = self.predictor.predict(
223
+ point_coords=point_coords,
224
+ point_labels=point_labels,
225
+ multimask_output=True,
226
+ return_logits=True
227
+ )
228
 
229
+ # Select best mask
230
+ best_idx = np.argmax(scores)
231
+ mask = masks[best_idx].astype(np.float32)
232
 
233
+ # Apply temporal smoothing if we have previous mask
234
+ if use_previous and self.previous_mask is not None:
235
+ # Blend with previous mask for temporal consistency
236
+ alpha = 0.3 # Smoothing factor
237
+ mask = (1 - alpha) * mask + alpha * self.previous_mask
238
+ mask = np.clip(mask, 0, 1)
239
 
240
+ # Post-processing for better quality
241
+ # Morphological operations
242
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
243
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
244
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
245
 
246
+ # Gaussian blur for smooth edges
247
+ mask = cv2.GaussianBlur(mask, (7, 7), 0)
248
 
249
+ # Store for next frame
250
+ self.previous_mask = mask.copy()
251
+ self.frame_count += 1
252
 
253
+ return mask
254
+
255
+ except Exception as e:
256
+ logger.error(f"SAM2 segmentation failed: {e}")
257
+ return None
258
+
259
+ def reset(self):
260
+ """Reset temporal state for new video"""
261
+ self.previous_mask = None
262
+ self.frame_count = 0
263
+ logger.info("SAM2 Video Processor reset for new video")
264
 
265
+ # ============================================
266
+ # LAZY LOADER FOR SAM2
267
+ # ============================================
268
+
269
+ class SAM2LazyLoader:
270
+ """
271
+ Lazy loading for SAM2 - only loads when needed
272
+ """
273
+ def __init__(self):
274
+ self.processor = SAM2VideoProcessor()
275
+
276
+ def segment_frame(self, frame, use_temporal=True):
277
+ """
278
+ Segment frame with lazy loading
279
+ Model loads on first call
280
+ """
281
+ return self.processor.segment_frame(frame, use_previous=use_temporal)
282
+
283
+ def reset(self):
284
+ """Reset for new video"""
285
+ self.processor.reset()
286
+
287
+ @property
288
+ def is_available(self):
289
+ """Check if SAM2 can be loaded"""
290
+ try:
291
+ import sam2
292
+ return True
293
+ except ImportError:
294
+ return False
295
+
296
+ @property
297
+ def is_loaded(self):
298
+ """Check if model is already loaded"""
299
+ return self.processor.loaded
300
+
301
+ # ============================================
302
+ # INTEGRATION WITH VIDEO PROCESSING
303
+ # ============================================
304
 
305
+ # Global SAM2 instance
306
+ SAM2_VIDEO = SAM2LazyLoader()
307
+
308
+ def process_video_with_sam2(video_path, background_image, progress_callback=None):
309
+ """
310
+ Process video using SAM2 with temporal consistency
311
+ """
312
  try:
313
+ # Open video
 
314
  cap = cv2.VideoCapture(video_path)
315
 
316
+ # Get video properties
317
  fps = int(cap.get(cv2.CAP_PROP_FPS))
318
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
319
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
320
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
321
 
322
+ # Create output writer
323
+ output_path = '/tmp/output_sam2.mp4'
 
324
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
325
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
326
 
327
+ # Resize background
328
  background_resized = cv2.resize(background_image, (width, height))
329
 
330
+ # Reset SAM2 for new video
331
+ SAM2_VIDEO.reset()
332
+
333
  frame_count = 0
 
 
334
 
335
  while True:
336
  ret, frame = cap.read()
337
  if not ret:
 
 
 
 
338
  break
339
 
340
+ # Convert BGR to RGB
341
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
342
 
343
+ # Segment with SAM2 (with temporal consistency)
344
+ mask = SAM2_VIDEO.segment_frame(frame_rgb, use_temporal=(frame_count > 0))
345
+
346
+ if mask is not None:
347
+ # Apply mask
348
+ if mask.ndim == 2:
349
+ mask = np.expand_dims(mask, axis=2)
350
 
351
+ # Composite
352
+ composite = frame_rgb * mask + background_resized * (1 - mask)
353
+ composite = composite.astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
354
 
355
+ # Convert back to BGR
356
+ composite_bgr = cv2.cvtColor(composite, cv2.COLOR_RGB2BGR)
357
+ else:
358
+ composite_bgr = frame
359
+
360
+ out.write(composite_bgr)
361
+ frame_count += 1
362
+
363
+ # Progress callback
364
+ if progress_callback:
365
+ progress = frame_count / total_frames
366
+ progress_callback(progress, f"SAM2 Processing: {frame_count}/{total_frames}")
367
+
368
+ # Memory cleanup every 50 frames
369
+ if frame_count % 50 == 0 and torch.cuda.is_available():
370
+ torch.cuda.empty_cache()
371
 
372
  cap.release()
373
  out.release()
 
374
 
375
+ logger.info(f" SAM2 video processing complete: {frame_count} frames")
376
  return output_path
377
 
378
  except Exception as e:
379
+ logger.error(f"SAM2 video processing failed: {e}")
380
  return None
381
 
382
+ # ============================================
383
+ # EXAMPLE USAGE
384
+ # ============================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
 
386
  def main():
387
+ st.title("🎬 Video Background Replacer with SAM2")
 
 
 
 
 
388
 
389
+ # Status display
390
+ col1, col2, col3 = st.columns(3)
 
 
 
391
 
392
  with col1:
393
+ if SAM2_VIDEO.is_available:
394
+ if SAM2_VIDEO.is_loaded:
395
+ st.success(" SAM2 Loaded")
396
+ else:
397
+ st.info("🎯 SAM2 Ready (loads on demand)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  else:
399
+ st.error(" SAM2 not installed")
400
 
401
+ # Process button
402
+ if st.button("Process with SAM2"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  if uploaded_video:
404
+ # This triggers model download on first use
405
+ result = process_video_with_sam2(
406
+ video_path,
407
+ background_image,
408
+ progress_callback=update_progress
409
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
+ if result:
412
+ st.success("✅ Video processed with SAM2!")
413
+ st.video(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
  if __name__ == "__main__":
416
+ main()