MogensR commited on
Commit
8f6e77a
·
1 Parent(s): 10e40ec
Files changed (1) hide show
  1. app.py +291 -784
app.py CHANGED
@@ -1,7 +1,7 @@
1
  #!/usr/bin/env python3
2
  """
3
- 🎬 BackgroundFX Pro - Complete Fixed Version
4
- Professional video background replacement with SAM2 segmentation
5
  """
6
 
7
  import os
@@ -15,6 +15,7 @@
15
  import requests
16
  import tempfile
17
  import subprocess
 
18
  import numpy as np
19
  import io
20
  from PIL import Image
@@ -24,7 +25,12 @@
24
 
25
  import gradio as gr
26
 
27
- # Configure logging first
 
 
 
 
 
28
  logging.basicConfig(
29
  level=logging.INFO,
30
  format='%(asctime)s - %(levelname)s - %(message)s'
@@ -36,495 +42,212 @@
36
  SKLEARN_AVAILABLE = True
37
  except ImportError:
38
  SKLEARN_AVAILABLE = False
39
- logger.warning("⚠️ sklearn not available, using fallback color detection")
40
-
41
- # ===============================================================================
42
- # SYSTEM CONFIGURATION
43
- # ===============================================================================
44
-
45
- def setup_environment():
46
- """Configure environment and check system capabilities"""
47
- logger.info("Starting BackgroundFX Pro with SAM2...")
48
-
49
- # GPU detection
50
- if torch.cuda.is_available():
51
- device = torch.device("cuda")
52
- gpu_name = torch.cuda.get_device_name(0)
53
- gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
54
- logger.info(f"Device: cuda")
55
- logger.info(f"GPU: {gpu_name} ({gpu_memory:.1f}GB)")
56
-
57
- # GPU type detection for model selection
58
- gpu_type = "T4" if "T4" in gpu_name else "other"
59
- model_size = "small" if gpu_type == "T4" else "base_plus"
60
- else:
61
- device = torch.device("cpu")
62
- logger.info("Device: cpu")
63
- gpu_name = "CPU"
64
- gpu_memory = 0
65
- model_size = "small"
66
-
67
- return device, gpu_name, gpu_memory, model_size
68
-
69
- # Initialize system
70
- DEVICE, GPU_NAME, GPU_MEMORY, MODEL_SIZE = setup_environment()
71
-
72
- # ===============================================================================
73
- # SAM2 INTEGRATION
74
- # ===============================================================================
75
-
76
- def check_sam2_availability():
77
- """Check if SAM2 is available"""
78
- try:
79
- import sam2
80
- logger.info("✅ SAM2 is available")
81
- return True
82
- except ImportError:
83
- logger.warning("❌ SAM2 not available - using fallback methods")
84
- return False
85
-
86
- def check_matanyone_availability():
87
- """Check if MatAnyone is available"""
88
- try:
89
- from matanyone import InferenceCore
90
- logger.info("✅ MatAnyone is available")
91
- return True
92
- except ImportError:
93
- logger.warning("❌ MatAnyone not available - using fallback methods")
94
- return False
95
-
96
- SAM2_AVAILABLE = check_sam2_availability()
97
- MATANYONE_AVAILABLE = check_matanyone_availability()
98
-
99
- # Global model instances
100
- matanyone_processor = None
101
-
102
- class SAM2Segmenter:
103
- """SAM2 + MatAnyone professional video segmentation"""
104
-
105
- def __init__(self):
106
- self.sam2_model = None
107
- self.sam2_predictor = None
108
- self.matanyone_processor = None
109
-
110
- def load_models(self):
111
- """Load both SAM2 and MatAnyone models"""
112
- sam2_loaded = self.load_sam2_model()
113
- matanyone_loaded = self.load_matanyone_model()
114
-
115
- if sam2_loaded and matanyone_loaded:
116
- logger.info("✅ SAM2 + MatAnyone professional pipeline ready")
117
- return True
118
- elif sam2_loaded:
119
- logger.info("✅ SAM2 loaded, MatAnyone unavailable - using SAM2 + OpenCV")
120
- return True
121
- else:
122
- logger.warning("⚠️ Both SAM2 and MatAnyone unavailable - using fallback")
123
- return False
124
-
125
- def load_sam2_model(self):
126
- """Load SAM2 model with auto-download"""
127
- if not SAM2_AVAILABLE:
128
- return False
129
-
130
- try:
131
- # Ensure checkpoints directory exists
132
- os.makedirs("checkpoints", exist_ok=True)
133
-
134
- if MODEL_SIZE == "small":
135
- from sam2.build_sam import build_sam2_video_predictor
136
- checkpoint_file = "checkpoints/sam2_hiera_small.pt"
137
- config = "sam2_hiera_s.yaml"
138
- checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"
139
- else:
140
- from sam2.build_sam import build_sam2_video_predictor
141
- checkpoint_file = "checkpoints/sam2_hiera_base_plus.pt"
142
- config = "sam2_hiera_b+.yaml"
143
- checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"
144
-
145
- # Download checkpoint if it doesn't exist
146
- if not os.path.exists(checkpoint_file):
147
- logger.info(f"📥 Downloading SAM2 checkpoint: {MODEL_SIZE}")
148
- self.download_checkpoint(checkpoint_url, checkpoint_file)
149
-
150
- # Also need config file
151
- config_file = f"checkpoints/{config}"
152
- if not os.path.exists(config_file):
153
- logger.info(f"📥 Downloading SAM2 config: {config}")
154
- config_url = f"https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2/configs/{config}"
155
- self.download_checkpoint(config_url, config_file)
156
-
157
- self.sam2_predictor = build_sam2_video_predictor(config_file, checkpoint_file, device=DEVICE)
158
- logger.info(f"✅ SAM2 model loaded: {MODEL_SIZE}")
159
- return True
160
-
161
- except Exception as e:
162
- logger.error(f"❌ Failed to load SAM2 model: {e}")
163
- return False
164
-
165
- def load_matanyone_model(self):
166
- """Load MatAnyone model"""
167
- if not MATANYONE_AVAILABLE:
168
- return False
169
-
170
- try:
171
- from matanyone import InferenceCore
172
- self.matanyone_processor = InferenceCore("PeiqingYang/MatAnyone")
173
- logger.info("✅ MatAnyone processor loaded")
174
- return True
175
- except Exception as e:
176
- logger.error(f"❌ Failed to load MatAnyone: {e}")
177
- return False
178
-
179
- def download_checkpoint(self, url: str, filepath: str):
180
- """Download SAM2 checkpoint with progress"""
181
- try:
182
- response = requests.get(url, stream=True)
183
- response.raise_for_status()
184
-
185
- total_size = int(response.headers.get('content-length', 0))
186
- block_size = 8192
187
- downloaded = 0
188
-
189
- with open(filepath, 'wb') as f:
190
- for chunk in response.iter_content(chunk_size=block_size):
191
- if chunk:
192
- f.write(chunk)
193
- downloaded += len(chunk)
194
- if total_size > 0:
195
- progress = (downloaded / total_size) * 100
196
- if downloaded % (block_size * 100) == 0: # Log every ~800KB
197
- logger.info(f"📥 Download progress: {progress:.1f}%")
198
-
199
- logger.info(f"✅ Downloaded: {filepath}")
200
-
201
- except Exception as e:
202
- logger.error(f"❌ Download failed: {e}")
203
- raise
204
-
205
- def segment_video(self, video_path: str, output_path: str) -> Tuple[bool, str]:
206
- """Professional SAM2 + MatAnyone video segmentation"""
207
- try:
208
- if not self.sam2_predictor and not self.load_models():
209
- logger.warning("⚠️ Professional models unavailable, using fallback")
210
- return self.fallback_segmentation(video_path, output_path)
211
-
212
- if self.sam2_predictor and self.matanyone_processor:
213
- # Full professional pipeline: SAM2 mask + MatAnyone processing
214
- return self.professional_sam2_matanyone_pipeline(video_path, output_path)
215
- elif self.sam2_predictor:
216
- # SAM2 mask + OpenCV replacement
217
- return self.sam2_opencv_pipeline(video_path, output_path)
218
- else:
219
- # Fallback
220
- return self.fallback_segmentation(video_path, output_path)
221
-
222
- except Exception as e:
223
- logger.error(f"❌ Error in video segmentation: {e}")
224
- # Try fallback method
225
- logger.warning("⚠️ Trying fallback segmentation method...")
226
- return self.fallback_segmentation(video_path, output_path)
227
-
228
- def professional_sam2_matanyone_pipeline(self, video_path: str, output_path: str) -> Tuple[bool, str]:
229
- """Professional SAM2 + MatAnyone pipeline"""
230
- try:
231
- logger.info("🎬 Using PROFESSIONAL SAM2 + MatAnyone pipeline")
232
-
233
- # Step 1: Extract first frame for SAM2 analysis
234
- first_frame_path = self.extract_first_frame(video_path)
235
- if not first_frame_path:
236
- raise Exception("Failed to extract first frame")
237
-
238
- # Step 2: Generate high-quality mask with SAM2
239
- mask_path = self.generate_sam2_mask(first_frame_path)
240
- if not mask_path:
241
- raise Exception("Failed to generate SAM2 mask")
242
-
243
- # Step 3: Process with MatAnyone
244
- logger.info("⚡ Processing video with MatAnyone professional matting...")
245
-
246
- # Create temp directory for MatAnyone output
247
- temp_dir = tempfile.mkdtemp()
248
-
249
- try:
250
- # Use MatAnyone for professional video matting
251
- foreground_path, alpha_path = self.matanyone_processor.process_video(
252
- input_path=video_path,
253
- mask_path=mask_path,
254
- output_path=temp_dir
255
- )
256
-
257
- # For now, copy foreground to output (can add background compositing later)
258
- shutil.copy2(foreground_path, output_path)
259
-
260
- logger.info("✅ Professional SAM2 + MatAnyone processing completed")
261
- return True, "Professional SAM2 + MatAnyone segmentation completed successfully"
262
-
263
- finally:
264
- # Cleanup temp files
265
- try:
266
- shutil.rmtree(temp_dir)
267
- os.unlink(first_frame_path)
268
- os.unlink(mask_path)
269
- except:
270
- pass
271
-
272
- except Exception as e:
273
- logger.error(f"❌ Professional pipeline failed: {e}")
274
- return False, f"Professional pipeline error: {str(e)}"
275
-
276
- def sam2_opencv_pipeline(self, video_path: str, output_path: str) -> Tuple[bool, str]:
277
- """SAM2 mask + OpenCV replacement pipeline"""
278
  try:
279
- logger.info("🎯 Using SAM2 + OpenCV pipeline")
280
-
281
- # Extract first frame
282
- first_frame_path = self.extract_first_frame(video_path)
283
- if not first_frame_path:
284
- raise Exception("Failed to extract first frame")
285
-
286
- # Generate SAM2 mask
287
- mask_path = self.generate_sam2_mask(first_frame_path)
288
- if not mask_path:
289
- raise Exception("Failed to generate SAM2 mask")
290
-
291
- # Apply mask to video using OpenCV
292
- return self.apply_sam2_mask_to_video(video_path, mask_path, output_path)
293
-
294
  except Exception as e:
295
- logger.error(f"SAM2 + OpenCV pipeline failed: {e}")
296
- return False, f"SAM2 + OpenCV error: {str(e)}"
297
-
298
- def extract_first_frame(self, video_path: str) -> Optional[str]:
299
- """Extract first frame for SAM2 processing"""
 
 
 
300
  try:
301
- cap = cv2.VideoCapture(video_path)
302
- ret, frame = cap.read()
303
- cap.release()
304
-
305
- if not ret:
306
- return None
307
-
308
- # Save first frame
309
- with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
310
- cv2.imwrite(tmp.name, frame)
311
- return tmp.name
312
-
313
- except Exception as e:
314
- logger.error(f"Error extracting first frame: {e}")
315
- return None
316
-
317
- def generate_sam2_mask(self, frame_path: str) -> Optional[str]:
318
- """Generate person mask using SAM2"""
319
- try:
320
- if not self.sam2_predictor:
321
- return None
322
-
323
- # Load image
324
- image = cv2.imread(frame_path)
325
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
326
-
327
- # Set image for SAM2
328
- self.sam2_predictor.set_image(image_rgb)
329
-
330
- # Auto-detect person in center
331
- height, width = image_rgb.shape[:2]
332
- center_point = np.array([[width//2, height//2]])
333
- point_labels = np.array([1]) # 1 = foreground
334
-
335
- # Generate mask
336
- masks, scores, logits = self.sam2_predictor.predict(
337
- point_coords=center_point,
338
- point_labels=point_labels,
339
- multimask_output=False
340
  )
341
-
342
- # Save mask
343
- mask = masks[0].astype(np.uint8) * 255
344
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
345
- cv2.imwrite(tmp.name, mask)
346
- return tmp.name
347
-
348
  except Exception as e:
349
- logger.error(f"Error generating SAM2 mask: {e}")
350
- return None
351
-
352
- def apply_sam2_mask_to_video(self, video_path: str, mask_path: str, output_path: str) -> Tuple[bool, str]:
353
- """Apply SAM2 mask to video using OpenCV"""
354
- try:
355
- # Load mask
356
- mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
357
- if mask is None:
358
- raise Exception("Could not load mask")
359
-
360
- # Process video
361
- cap = cv2.VideoCapture(video_path)
362
- if not cap.isOpened():
363
- raise Exception("Could not open video")
364
-
365
- # Get video properties
366
- fps = int(cap.get(cv2.CAP_PROP_FPS))
367
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
368
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
369
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
370
-
371
- # Resize mask to match video
372
- mask_resized = cv2.resize(mask, (width, height))
373
-
374
- # Setup video writer
375
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
376
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
377
-
378
- frame_count = 0
379
- while True:
380
- ret, frame = cap.read()
381
- if not ret:
382
- break
383
-
384
- # Apply green screen using SAM2 mask
385
- green_bg = np.zeros_like(frame)
386
- green_bg[:, :] = [0, 255, 0]
387
-
388
- mask_3d = cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR).astype(np.float32) / 255.0
389
- result_frame = frame.astype(np.float32) * mask_3d + green_bg.astype(np.float32) * (1 - mask_3d)
390
-
391
- out.write(result_frame.astype(np.uint8))
392
-
393
- frame_count += 1
394
- if frame_count % 30 == 0:
395
- progress = (frame_count / total_frames) * 100
396
- logger.info(f"SAM2 processing: {progress:.1f}% ({frame_count}/{total_frames})")
397
-
398
  cap.release()
399
- out.release()
400
-
401
- return True, "SAM2 + OpenCV segmentation completed successfully"
402
-
403
- except Exception as e:
404
- logger.error(f"Error applying SAM2 mask: {e}")
405
- return False, f"SAM2 mask application error: {str(e)}"
 
 
 
 
 
 
 
 
 
406
 
407
- def fallback_segmentation(self, video_path: str, output_path: str) -> Tuple[bool, str]:
408
- """Simple but effective segmentation that works with ANY background"""
409
- try:
410
- logger.info("🎯 Using robust universal segmentation...")
411
-
412
- cap = cv2.VideoCapture(video_path)
413
- if not cap.isOpened():
414
- return False, "Could not open video file"
415
-
416
- # Get video properties
417
- fps = int(cap.get(cv2.CAP_PROP_FPS))
418
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
419
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
420
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
421
-
422
- # Setup video writer
423
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
424
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
425
-
426
- logger.info(f"📊 Video: {width}x{height}, {fps}fps, {total_frames} frames")
427
-
428
- frame_count = 0
429
- while True:
430
- ret, frame = cap.read()
431
- if not ret:
432
- break
433
-
434
- # Create person mask using multiple methods combined
435
- mask = self.create_universal_person_mask(frame, width, height)
436
-
437
- # Apply green screen
438
- result_frame = self.apply_green_screen_robust(frame, mask)
439
- out.write(result_frame)
440
-
441
- frame_count += 1
442
- if frame_count % 30 == 0:
443
- progress = (frame_count / total_frames) * 100
444
- logger.info(f"Universal processing: {progress:.1f}% ({frame_count}/{total_frames})")
445
-
446
- cap.release()
447
- out.release()
448
-
449
- logger.info(f"✅ Universal segmentation completed: {output_path}")
450
- return True, "Universal segmentation completed successfully"
451
-
452
- except Exception as e:
453
- logger.error(f"❌ Error in universal segmentation: {e}")
454
- return False, f"Universal segmentation error: {str(e)}"
455
 
456
- def create_universal_person_mask(self, frame, width, height) -> np.ndarray:
457
- """Create person mask using fast optimized method"""
458
-
459
- # Use ONLY the fastest method that still works well
460
- mask = self.fast_grabcut_segmentation(frame, width, height)
461
-
462
- # Quick cleanup only
463
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
464
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
465
-
466
- # Light smoothing
467
- mask = cv2.GaussianBlur(mask, (7, 7), 0)
468
-
469
- return mask
470
 
471
- def fast_grabcut_segmentation(self, frame, width, height) -> np.ndarray:
472
- """Fast GrabCut with minimal iterations"""
473
- try:
474
- # Create rectangle around likely person area (center 60% of frame)
475
- margin_x = int(width * 0.2)
476
- margin_y = int(height * 0.1)
477
- rect = (margin_x, margin_y, width - 2*margin_x, height - 2*margin_y)
478
-
479
- # Initialize mask
480
- mask = np.zeros((height, width), np.uint8)
481
- bgd_model = np.zeros((1, 65), np.float64)
482
- fgd_model = np.zeros((1, 65), np.float64)
483
-
484
- # Apply GrabCut with ONLY 2 iterations (much faster)
485
- cv2.grabCut(frame, mask, rect, bgd_model, fgd_model, 2, cv2.GC_INIT_WITH_RECT)
486
-
487
- # Create binary mask
488
- mask2 = np.where((mask == 2) | (mask == 0), 0, 255).astype('uint8')
489
-
490
- return mask2
491
-
492
- except Exception as e:
493
- logger.warning(f"Fast GrabCut failed: {e}, using simple fallback")
494
- # Ultra-simple fallback
495
- mask = np.zeros((height, width), dtype=np.uint8)
496
- margin_x = int(width * 0.25)
497
- margin_y = int(height * 0.15)
498
- mask[margin_y:height-margin_y, margin_x:width-margin_x] = 255
499
- return mask
500
 
501
- def apply_green_screen_robust(self, frame, mask) -> np.ndarray:
502
- """Apply green screen with robust blending"""
503
- # Create green background
504
- green_bg = np.zeros_like(frame)
505
- green_bg[:, :] = [0, 255, 0] # Green background (BGR format)
506
-
507
- # Ensure mask is 3-channel
508
- if len(mask.shape) == 2:
509
- mask_3d = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR).astype(np.float32) / 255.0
510
- else:
511
- mask_3d = mask.astype(np.float32) / 255.0
512
-
513
- # Robust blending with smooth transitions
514
- # Person (mask=1) keeps original color, background (mask=0) becomes green
515
- result_frame = frame.astype(np.float32) * mask_3d + green_bg.astype(np.float32) * (1 - mask_3d)
516
-
517
- return result_frame.astype(np.uint8)
518
 
519
- # Initialize SAM2 segmenter
520
- sam2_segmenter = SAM2Segmenter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
 
522
- # ===============================================================================
523
- # MYAVATAR API INTEGRATION
524
- # ===============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
525
 
526
  class MyAvatarAPI:
527
- """MyAvatar API integration for video management"""
528
 
529
  def __init__(self):
530
  self.api_base = "https://app.myavatar.dk/api"
@@ -534,7 +257,6 @@ def __init__(self):
534
  def fetch_videos(self) -> List[Dict[str, Any]]:
535
  """Fetch videos from MyAvatar API"""
536
  try:
537
- # Cache for 5 minutes
538
  if time.time() - self.last_refresh < 300 and self.videos_cache:
539
  return self.videos_cache
540
 
@@ -543,14 +265,14 @@ def fetch_videos(self) -> List[Dict[str, Any]]:
543
  data = response.json()
544
  self.videos_cache = data.get('videos', [])
545
  self.last_refresh = time.time()
546
- logger.info(f"Fetched {len(self.videos_cache)} videos from MyAvatar")
547
  return self.videos_cache
548
  else:
549
- logger.error(f"API error: {response.status_code}")
550
  return []
551
 
552
  except Exception as e:
553
- logger.error(f"Error fetching videos: {e}")
554
  return []
555
 
556
  def get_video_choices(self) -> List[str]:
@@ -574,11 +296,9 @@ def get_video_url(self, selection: str) -> Optional[str]:
574
  return None
575
 
576
  try:
577
- # Extract ID from selection
578
  if "(ID: " in selection:
579
  video_id = selection.split("(ID: ")[1].split(")")[0]
580
 
581
- # Find video in cache
582
  for video in self.videos_cache:
583
  if str(video.get('id')) == video_id:
584
  return video.get('video_url')
@@ -586,16 +306,12 @@ def get_video_url(self, selection: str) -> Optional[str]:
586
  return None
587
 
588
  except Exception as e:
589
- logger.error(f"Error extracting video URL: {e}")
590
  return None
591
 
592
- # Initialize MyAvatar API
593
  myavatar_api = MyAvatarAPI()
594
 
595
- # ===============================================================================
596
- # BACKGROUND PROCESSING
597
- # ===============================================================================
598
-
599
  def create_gradient_background(gradient_type: str, width: int, height: int) -> Image.Image:
600
  """Create gradient backgrounds"""
601
  try:
@@ -634,7 +350,6 @@ def create_gradient_background(gradient_type: str, width: int, height: int) -> I
634
 
635
  except Exception as e:
636
  logger.error(f"Error creating gradient: {e}")
637
- # Return solid blue as fallback
638
  img = np.full((height, width, 3), [70, 130, 180], dtype=np.uint8)
639
  return Image.fromarray(img)
640
 
@@ -655,178 +370,33 @@ def create_solid_color(color: str, width: int, height: int) -> Image.Image:
655
  img = np.full((height, width, 3), rgb, dtype=np.uint8)
656
  return Image.fromarray(img)
657
 
658
- def replace_green_screen(video_path: str, background_image: Image.Image, output_path: str) -> Tuple[bool, str]:
659
- """Replace green screen in video with new background using OpenCV only"""
660
- try:
661
- # Open video capture
662
- cap = cv2.VideoCapture(video_path)
663
- if not cap.isOpened():
664
- return False, "Could not open video file"
665
-
666
- # Get video properties
667
- fps = int(cap.get(cv2.CAP_PROP_FPS))
668
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
669
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
670
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
671
-
672
- # Resize background to match video dimensions
673
- background_resized = background_image.resize((width, height), Image.Resampling.LANCZOS)
674
- bg_array = np.array(background_resized)
675
-
676
- # Create temporary video without audio first
677
- temp_video_path = output_path.replace('.mp4', '_no_audio.mp4')
678
-
679
- # Setup video writer
680
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
681
- out = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height))
682
-
683
- frame_count = 0
684
- while True:
685
- ret, frame = cap.read()
686
- if not ret:
687
- break
688
-
689
- # Convert BGR to RGB for consistency
690
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
691
-
692
- # Convert to HSV for better green detection
693
- hsv = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2HSV)
694
-
695
- # Define green range (adjusted for green screen)
696
- lower_green = np.array([40, 50, 50])
697
- upper_green = np.array([80, 255, 255])
698
-
699
- # Create mask
700
- mask = cv2.inRange(hsv, lower_green, upper_green)
701
-
702
- # Improve mask with morphological operations
703
- kernel = np.ones((3, 3), np.uint8)
704
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
705
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
706
-
707
- # Apply Gaussian blur for smoother edges
708
- mask = cv2.GaussianBlur(mask, (5, 5), 0)
709
- mask = mask.astype(np.float32) / 255.0 # Normalize to 0-1
710
-
711
- # Create 3-channel mask
712
- mask_3d = np.stack([mask, mask, mask], axis=2)
713
-
714
- # Blend frame with background
715
- result = frame_rgb * (1 - mask_3d) + bg_array * mask_3d
716
- result = result.astype(np.uint8)
717
-
718
- # Convert back to BGR for video writer
719
- result_bgr = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
720
- out.write(result_bgr)
721
-
722
- frame_count += 1
723
- if frame_count % 30 == 0: # Log progress every 30 frames
724
- progress = (frame_count / total_frames) * 100
725
- logger.info(f"Processing: {progress:.1f}% ({frame_count}/{total_frames})")
726
-
727
- # Cleanup
728
- cap.release()
729
- out.release()
730
-
731
- # Step 4: Add audio back using ffmpeg
732
- logger.info("🔊 Adding audio back to final video...")
733
- success = add_audio_to_video(video_path, temp_video_path, output_path)
734
-
735
- # Cleanup temporary file
736
- try:
737
- os.unlink(temp_video_path)
738
- except:
739
- pass
740
-
741
- if success:
742
- logger.info(f"✅ Green screen replacement with audio completed: {output_path}")
743
- return True, "Background replacement with audio completed successfully"
744
- else:
745
- logger.warning("⚠️ Audio addition failed, but video processing completed")
746
- # Move temp file to final output as fallback
747
- try:
748
- os.rename(temp_video_path, output_path)
749
- except:
750
- pass
751
- return True, "Background replacement completed (audio may be missing)"
752
-
753
- except Exception as e:
754
- logger.error(f"❌ Error in green screen replacement: {e}")
755
- return False, f"Background replacement error: {str(e)}"
756
-
757
- def add_audio_to_video(source_video: str, video_no_audio: str, output_path: str) -> bool:
758
- """Add audio from source video to processed video using ffmpeg"""
759
- try:
760
- # Check if ffmpeg is available
761
- try:
762
- subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
763
- except (subprocess.CalledProcessError, FileNotFoundError):
764
- logger.warning("⚠️ ffmpeg not available, skipping audio")
765
- return False
766
-
767
- # FFmpeg command to combine video (no audio) with audio from original
768
- cmd = [
769
- 'ffmpeg', '-y', # -y to overwrite output file
770
- '-i', video_no_audio, # Video input (no audio)
771
- '-i', source_video, # Audio source
772
- '-c:v', 'copy', # Copy video codec
773
- '-c:a', 'aac', # Audio codec
774
- '-map', '0:v:0', # Use video from first input
775
- '-map', '1:a:0', # Use audio from second input
776
- '-shortest', # End when shortest stream ends
777
- output_path
778
- ]
779
-
780
- # Run ffmpeg
781
- result = subprocess.run(cmd, capture_output=True, text=True)
782
-
783
- if result.returncode == 0:
784
- logger.info("✅ Audio successfully added to video")
785
- return True
786
- else:
787
- logger.error(f"❌ ffmpeg error: {result.stderr}")
788
- return False
789
-
790
- except Exception as e:
791
- logger.error(f"❌ Error adding audio: {e}")
792
- return False
793
-
794
- # ===============================================================================
795
- # AI BACKGROUND GENERATION
796
- # ===============================================================================
797
-
798
  def generate_ai_background(prompt: str) -> Tuple[Optional[Image.Image], str]:
799
  """Generate AI background using Hugging Face Inference API"""
800
  try:
801
  if not prompt.strip():
802
  return None, "Please enter a prompt"
803
 
804
- # Try multiple AI models for image generation
805
  models = [
806
- "black-forest-labs/FLUX.1-schnell", # Fast FLUX model
807
- "stabilityai/stable-diffusion-xl-base-1.0", # SDXL
808
- "runwayml/stable-diffusion-v1-5" # SD 1.5 fallback
809
  ]
810
 
811
- # Enhanced prompt for backgrounds
812
  enhanced_prompt = f"professional video background, {prompt}, high quality, 16:9 aspect ratio, cinematic lighting, detailed"
813
 
814
  for model in models:
815
  try:
816
- logger.info(f"🎨 Trying AI generation with {model}...")
817
 
818
- # Hugging Face Inference API
819
  api_url = f"https://api-inference.huggingface.co/models/{model}"
820
-
821
  headers = {
822
  "Authorization": f"Bearer {os.getenv('HUGGINGFACE_TOKEN', 'hf_placeholder')}"
823
  }
824
-
825
  payload = {
826
  "inputs": enhanced_prompt,
827
  "parameters": {
828
  "width": 1024,
829
- "height": 576, # 16:9 aspect ratio
830
  "num_inference_steps": 20,
831
  "guidance_scale": 7.5
832
  }
@@ -835,27 +405,22 @@ def generate_ai_background(prompt: str) -> Tuple[Optional[Image.Image], str]:
835
  response = requests.post(api_url, headers=headers, json=payload, timeout=30)
836
 
837
  if response.status_code == 200:
838
- # Success! Convert response to image
839
  image = Image.open(io.BytesIO(response.content))
840
- logger.info(f"AI background generated successfully with {model}")
841
- return image, f"AI background generated: {prompt}"
842
-
843
  elif response.status_code == 503:
844
- # Model loading, try next
845
- logger.warning(f"⏳ Model {model} is loading, trying next...")
846
  continue
847
-
848
  else:
849
- logger.warning(f"⚠️ Error with {model}: {response.status_code}")
850
  continue
851
 
852
  except Exception as e:
853
- logger.warning(f"⚠️ Error with {model}: {e}")
854
  continue
855
 
856
- # If all AI models fail, create an intelligent gradient fallback
857
- logger.info("🔄 AI generation failed, creating intelligent gradient fallback...")
858
- return create_intelligent_gradient(prompt), f"✅ Created gradient background inspired by: {prompt}"
859
 
860
  except Exception as e:
861
  logger.error(f"Error in AI background generation: {e}")
@@ -865,47 +430,16 @@ def create_intelligent_gradient(prompt: str) -> Image.Image:
865
  """Create intelligent gradient based on prompt analysis"""
866
  prompt_lower = prompt.lower()
867
 
868
- # Analyze prompt for colors and themes
869
  if any(word in prompt_lower for word in ["sunset", "orange", "warm", "fire", "autumn"]):
870
  return create_gradient_background("sunset", 1920, 1080)
871
  elif any(word in prompt_lower for word in ["ocean", "sea", "blue", "water", "sky", "calm"]):
872
  return create_gradient_background("ocean", 1920, 1080)
873
  elif any(word in prompt_lower for word in ["forest", "green", "nature", "trees", "jungle"]):
874
  return create_gradient_background("forest", 1920, 1080)
875
- elif any(word in prompt_lower for word in ["night", "dark", "purple", "space", "cosmic"]):
876
- return create_cosmic_gradient(1920, 1080)
877
- elif any(word in prompt_lower for word in ["professional", "business", "corporate", "office"]):
878
- return create_professional_gradient(1920, 1080)
879
  else:
880
  return create_gradient_background("default", 1920, 1080)
881
 
882
- def create_cosmic_gradient(width: int, height: int) -> Image.Image:
883
- """Create a cosmic/space gradient"""
884
- img = np.zeros((height, width, 3), dtype=np.uint8)
885
- for i in range(height):
886
- ratio = i / height
887
- r = int(25 * (1 - ratio) + 75 * ratio)
888
- g = int(25 * (1 - ratio) + 0 * ratio)
889
- b = int(112 * (1 - ratio) + 130 * ratio)
890
- img[i, :] = [r, g, b]
891
- return Image.fromarray(img)
892
-
893
- def create_professional_gradient(width: int, height: int) -> Image.Image:
894
- """Create a professional business gradient"""
895
- img = np.zeros((height, width, 3), dtype=np.uint8)
896
- for i in range(height):
897
- ratio = i / height
898
- r = int(240 * (1 - ratio) + 200 * ratio)
899
- g = int(240 * (1 - ratio) + 200 * ratio)
900
- b = int(240 * (1 - ratio) + 200 * ratio)
901
- img[i, :] = [r, g, b]
902
- return Image.fromarray(img)
903
-
904
- # ===============================================================================
905
- # MAIN PROCESSING FUNCTIONS
906
- # ===============================================================================
907
-
908
- def process_video_with_background(
909
  input_video: Optional[str],
910
  myavatar_selection: str,
911
  background_type: str,
@@ -913,10 +447,16 @@ def process_video_with_background(
913
  solid_color: str,
914
  custom_background: Optional[str],
915
  ai_prompt: str
916
- ) -> Tuple[Optional[str], str]:
917
- """Main video processing function"""
 
 
 
918
  try:
919
- # Determine input video source
 
 
 
920
  video_path = None
921
  if input_video:
922
  video_path = input_video
@@ -924,7 +464,6 @@ def process_video_with_background(
924
  elif myavatar_selection and myavatar_selection != "No videos available":
925
  video_url = myavatar_api.get_video_url(myavatar_selection)
926
  if video_url:
927
- # Download video temporarily
928
  response = requests.get(video_url)
929
  if response.status_code == 200:
930
  temp_video = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
@@ -932,28 +471,15 @@ def process_video_with_background(
932
  temp_video.close()
933
  video_path = temp_video.name
934
  logger.info("Using MyAvatar video")
935
- else:
936
- return None, "❌ Failed to download MyAvatar video"
937
- else:
938
- return None, "❌ Could not get video URL from MyAvatar"
939
 
940
  if not video_path:
941
- return None, "No video provided"
942
-
943
- # Step 1: Create green screen version using SAM2
944
- with tempfile.NamedTemporaryFile(suffix='_greenscreen.mp4', delete=False) as tmp_green:
945
- green_video_path = tmp_green.name
946
 
947
- logger.info("🎬 Step 1: Creating green screen version with SAM2...")
948
- success, message = sam2_segmenter.segment_video(video_path, green_video_path)
949
 
950
- if not success:
951
- return None, f"❌ SAM2 segmentation failed: {message}"
952
-
953
- # Step 2: Generate background
954
- logger.info("🎨 Step 2: Generating background...")
955
  background_image = None
956
-
957
  if background_type == "gradient":
958
  background_image = create_gradient_background(gradient_type, 1920, 1080)
959
  elif background_type == "solid":
@@ -962,81 +488,75 @@ def process_video_with_background(
962
  background_image = Image.open(custom_background)
963
  elif background_type == "ai" and ai_prompt:
964
  bg_img, ai_msg = generate_ai_background(ai_prompt)
965
- if bg_img:
966
- background_image = bg_img
967
- else:
968
- return None, f"❌ AI background generation failed: {ai_msg}"
969
 
970
  if not background_image:
971
- return None, "No background generated"
 
 
 
 
 
 
972
 
973
- # Step 3: Replace green screen with background
974
- logger.info("🔄 Step 3: Replacing green screen with background...")
975
- with tempfile.NamedTemporaryFile(suffix='_final.mp4', delete=False) as tmp_final:
976
  final_video_path = tmp_final.name
977
 
978
- success, message = replace_green_screen(green_video_path, background_image, final_video_path)
979
 
980
- # Cleanup temporary files
981
  try:
982
- os.unlink(green_video_path)
983
- if video_path != input_video: # Don't delete uploaded file
984
  os.unlink(video_path)
985
  except:
986
  pass
987
 
988
- if success:
989
- logger.info("✅ Video processing completed successfully!")
990
- return final_video_path, "✅ Video processing completed successfully!"
991
  else:
992
- return None, f" Background replacement failed: {message}"
993
 
994
  except Exception as e:
995
- logger.error(f"Error in video processing: {e}")
996
- return None, f"Processing error: {str(e)}"
997
-
998
- # ===============================================================================
999
- # GRADIO INTERFACE
1000
- # ===============================================================================
1001
 
1002
  def create_interface():
1003
  """Create the Gradio interface"""
1004
  logger.info("Creating Gradio interface...")
1005
- logger.info(f"Device: {DEVICE} | GPU: {GPU_NAME} | Memory: {GPU_MEMORY:.1f}GB | Type: {MODEL_SIZE}")
1006
 
1007
- # Custom CSS
1008
  css = """
1009
  .main-container { max-width: 1200px; margin: 0 auto; }
1010
  .status-box { border: 2px solid #4CAF50; border-radius: 10px; padding: 15px; }
1011
  .gradient-preview { border: 2px solid #ddd; border-radius: 10px; }
1012
  """
1013
 
1014
- with gr.Blocks(css=css, title="🎬 BackgroundFX Pro") as app:
1015
 
1016
- # Header
1017
  gr.Markdown("""
1018
- # 🎬 BackgroundFX Pro
1019
- ### Professional Video Background Replacement with SAM2 Segmentation
1020
  """)
1021
 
1022
- # System Status
1023
  with gr.Row():
 
 
1024
  gr.Markdown(f"""
1025
- **System Status:** 🟢 Online | **GPU:** {GPU_NAME} | **SAM2:** {'✅ Ready' if SAM2_AVAILABLE else '❌ Not Available'}
1026
  """)
1027
 
1028
- # Main Interface
1029
  with gr.Row():
1030
- # Left Column - Input
1031
  with gr.Column(scale=1):
1032
- gr.Markdown("## 📹 Video Input")
1033
 
1034
  with gr.Tabs():
1035
- with gr.Tab("📁 Upload Video"):
1036
  video_upload = gr.Video(label="Upload Video File", height=300)
1037
 
1038
- with gr.Tab("📱 MyAvatar Videos"):
1039
- refresh_btn = gr.Button("🔄 Refresh Videos", size="sm")
1040
  myavatar_dropdown = gr.Dropdown(
1041
  label="Select MyAvatar Video",
1042
  choices=["Click refresh to load videos"],
@@ -1044,7 +564,7 @@ def create_interface():
1044
  )
1045
  video_preview = gr.Video(label="Preview", height=200)
1046
 
1047
- gr.Markdown("## 🎨 Background Options")
1048
 
1049
  background_type = gr.Radio(
1050
  choices=["gradient", "solid", "custom", "ai"],
@@ -1053,7 +573,6 @@ def create_interface():
1053
  )
1054
 
1055
  with gr.Group():
1056
- # Gradient options
1057
  gradient_type = gr.Dropdown(
1058
  choices=["sunset", "ocean", "forest", "default"],
1059
  value="sunset",
@@ -1062,7 +581,6 @@ def create_interface():
1062
  )
1063
  gradient_preview = gr.Image(label="Gradient Preview", height=150)
1064
 
1065
- # Solid color options
1066
  solid_color = gr.Dropdown(
1067
  choices=["white", "black", "blue", "green", "red", "purple", "orange", "yellow"],
1068
  value="blue",
@@ -1071,28 +589,26 @@ def create_interface():
1071
  )
1072
  color_preview = gr.Image(label="Color Preview", height=150, visible=False)
1073
 
1074
- # Custom background upload
1075
  custom_bg_upload = gr.Image(
1076
  label="Upload Custom Background",
1077
  type="filepath",
1078
  visible=False
1079
  )
1080
 
1081
- # AI generation
1082
  ai_prompt = gr.Textbox(
1083
  label="AI Background Prompt",
1084
  placeholder="Describe the background you want...",
1085
  visible=False
1086
  )
1087
- ai_generate_btn = gr.Button("🤖 Generate AI Background", visible=False)
1088
  ai_preview = gr.Image(label="AI Generated Background", height=150, visible=False)
1089
 
1090
- # Process button
1091
- process_btn = gr.Button("🎬 Process Video", variant="primary", size="lg")
 
1092
 
1093
- # Right Column - Output
1094
  with gr.Column(scale=1):
1095
- gr.Markdown("## 🎯 Results")
1096
 
1097
  result_video = gr.Video(label="Processed Video", height=400)
1098
 
@@ -1103,21 +619,18 @@ def create_interface():
1103
  elem_classes=["status-box"]
1104
  )
1105
 
1106
- # Processing info
1107
  gr.Markdown("""
1108
- ### 🔧 Processing Steps:
1109
- 1. **SAM2 Segmentation** - Extract person from video
1110
- 2. **Green Screen Creation** - Replace background with green
1111
- 3. **Background Replacement** - Apply your chosen background
1112
- 4. **Final Rendering** - Output processed video
1113
 
1114
- **Estimated Time:** 2-5 minutes depending on video length
1115
  """)
1116
 
1117
- # ===== EVENT HANDLERS (All defined after components) =====
1118
-
1119
  def update_background_options(bg_type):
1120
- """Update visible background options based on type"""
1121
  return {
1122
  gradient_type: gr.update(visible=(bg_type == "gradient")),
1123
  gradient_preview: gr.update(visible=(bg_type == "gradient")),
@@ -1130,23 +643,18 @@ def update_background_options(bg_type):
1130
  }
1131
 
1132
  def update_gradient_preview(grad_type):
1133
- """Update gradient preview"""
1134
  try:
1135
- img = create_gradient_background(grad_type, 400, 200)
1136
- return img
1137
  except:
1138
  return None
1139
 
1140
  def update_color_preview(color):
1141
- """Update solid color preview"""
1142
  try:
1143
- img = create_solid_color(color, 400, 200)
1144
- return img
1145
  except:
1146
  return None
1147
 
1148
  def refresh_myavatar_videos():
1149
- """Refresh MyAvatar video list"""
1150
  try:
1151
  choices = myavatar_api.get_video_choices()
1152
  return gr.update(choices=choices, value=None)
@@ -1155,7 +663,6 @@ def refresh_myavatar_videos():
1155
  return gr.update(choices=["Error loading videos"])
1156
 
1157
  def load_video_preview(selection):
1158
- """Load video preview from MyAvatar selection"""
1159
  try:
1160
  if not selection or selection == "No videos available":
1161
  return None
@@ -1167,7 +674,6 @@ def load_video_preview(selection):
1167
  return None
1168
 
1169
  def generate_ai_bg(prompt):
1170
- """Generate AI background"""
1171
  bg_img, message = generate_ai_background(prompt)
1172
  return bg_img
1173
 
@@ -1209,7 +715,7 @@ def generate_ai_bg(prompt):
1209
  )
1210
 
1211
  process_btn.click(
1212
- fn=process_video_with_background,
1213
  inputs=[
1214
  video_upload,
1215
  myavatar_dropdown,
@@ -1219,10 +725,14 @@ def generate_ai_bg(prompt):
1219
  custom_bg_upload,
1220
  ai_prompt
1221
  ],
1222
- outputs=[result_video, status_output]
 
 
 
 
 
1223
  )
1224
 
1225
- # Initialize gradient preview
1226
  app.load(
1227
  fn=lambda: create_gradient_background("sunset", 400, 200),
1228
  outputs=[gradient_preview]
@@ -1230,19 +740,16 @@ def generate_ai_bg(prompt):
1230
 
1231
  return app
1232
 
1233
- # ===============================================================================
1234
- # MAIN APPLICATION
1235
- # ===============================================================================
1236
-
1237
  def main():
1238
  """Main application entry point"""
1239
  try:
1240
- # Pre-load AI models
1241
- if SAM2_AVAILABLE or MATANYONE_AVAILABLE:
1242
- logger.info("Pre-loading AI models...")
1243
- sam2_segmenter.load_models()
 
 
1244
 
1245
- # Create and launch interface
1246
  app = create_interface()
1247
 
1248
  app.launch(
@@ -1254,7 +761,7 @@ def main():
1254
  )
1255
 
1256
  except Exception as e:
1257
- logger.error(f"Failed to start application: {e}")
1258
  sys.exit(1)
1259
 
1260
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
  """
3
+ BackgroundFX Pro - GPU Optimized Version
4
+ Professional video background replacement with SAM2 + MatAnyone
5
  """
6
 
7
  import os
 
15
  import requests
16
  import tempfile
17
  import subprocess
18
+ import threading
19
  import numpy as np
20
  import io
21
  from PIL import Image
 
25
 
26
  import gradio as gr
27
 
28
+ # Import optimized modules
29
+ from utils.accelerator import pick_device, torch_global_tuning, memory_checkpoint, cleanup
30
+ from models.sam2_loader import SAM2Predictor
31
+ from models.matanyone_loader import MatAnyoneSession
32
+
33
+ # Configure logging
34
  logging.basicConfig(
35
  level=logging.INFO,
36
  format='%(asctime)s - %(levelname)s - %(message)s'
 
42
  SKLEARN_AVAILABLE = True
43
  except ImportError:
44
  SKLEARN_AVAILABLE = False
45
+ logger.warning("sklearn not available, using fallback color detection")
46
+
47
# Global processing control.
# processing_active is the cooperative stop flag polled by the frame reader
# and the processing loops; stop_processing() clears it.
processing_active = False
processing_thread = None

# Initialize optimized system
device = pick_device()
torch_global_tuning()
GPU_NAME = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
GPU_MEMORY = torch.cuda.get_device_properties(0).total_memory / (1024**3) if torch.cuda.is_available() else 0
# NOTE(review): this selects "large" on a T4 and "base" elsewhere, whereas the
# code it replaces selected "small" on a T4 - confirm this is intended.
MODEL_SIZE = "large" if "T4" in GPU_NAME else "base"

logger.info(f"System initialized - Device: {device} | GPU: {GPU_NAME} | Memory: {GPU_MEMORY:.1f}GB")

# Environment variables for model control.
# ENABLE_SAM2 / ENABLE_MATANY: a model is enabled only when the value is exactly "1".
SAM2_ENABLED = os.environ.get("ENABLE_SAM2", "1") == "1"
MATANY_ENABLED = os.environ.get("ENABLE_MATANY", "1") == "1"
# MAX_SIDE: longest frame side (pixels) frames are downscaled to before processing.
MAX_SIDE = int(os.environ.get("MAX_SIDE", "1280"))
# FRAME_CHUNK: number of frames yielded per batch by iter_video_frames().
FRAME_CHUNK = int(os.environ.get("FRAME_CHUNK", "64"))

# Global optimized model instances (populated lazily by get_sam2 / get_matanyone)
sam2_predictor = None
matanyone_session = None
69
+
70
def get_sam2():
    """Get the SAM2 predictor, loading it on first use.

    The predictor is cached in the module-level ``sam2_predictor``. Loading is
    attempted only while ``SAM2_ENABLED`` is true and nothing is cached yet, so
    a failed load is retried on the next call.

    Returns:
        The loaded predictor, or ``None`` when SAM2 is disabled or loading failed.
    """
    global sam2_predictor
    if sam2_predictor is None and SAM2_ENABLED:
        try:
            sam2_predictor = SAM2Predictor(device).load()
            logger.info("SAM2 loaded with optimized pipeline")
        except Exception as e:
            # Best effort: callers continue without a seed mask.
            logger.error(f"SAM2 loading failed: {e}")
            sam2_predictor = None
    return sam2_predictor
81
+
82
def get_matanyone():
    """Get the MatAnyone session, loading it on first use.

    The session is cached in the module-level ``matanyone_session``. The
    checkpoint location comes from the ``MATANY_REPO_ID`` and
    ``MATANY_FILENAME`` environment variables, with the defaults below.
    Loading is attempted only while ``MATANY_ENABLED`` is true and nothing is
    cached yet, so a failed load is retried on the next call.

    Returns:
        The loaded session, or ``None`` when MatAnyone is disabled or loading
        failed.
    """
    global matanyone_session
    if matanyone_session is None and MATANY_ENABLED:
        try:
            repo_id = os.environ.get("MATANY_REPO_ID", "PeiqingYang/MatAnyone")
            filename = os.environ.get("MATANY_FILENAME", "matanyone_v1.0.pth")
            matanyone_session = MatAnyoneSession(device).load(
                repo_id=repo_id,
                filename=filename,
            )
            logger.info("MatAnyone loaded with optimized pipeline")
        except Exception as e:
            # Best effort: callers fall back to the cheap alpha path.
            logger.error(f"MatAnyone loading failed: {e}")
            matanyone_session = None
    return matanyone_session
98
+
99
def iter_video_frames(path, target_max_side=MAX_SIDE, chunk=FRAME_CHUNK):
    """Yield a video's frames in fixed-size batches.

    Frames are read sequentially with ``cv2.VideoCapture``, downscaled so the
    longest side does not exceed ``target_max_side`` (never upscaled, and
    scaled dimensions are rounded down to even numbers), and converted from
    BGR to RGB.

    Iteration ends silently, dropping any partially filled batch, as soon as
    the module-level ``processing_active`` flag is false.

    Args:
        path: Path of the video file to read.
        target_max_side: Maximum length in pixels of the longest frame side.
        chunk: Number of frames per yielded batch; the last batch may be shorter.

    Yields:
        ``(batch, fps, (w, h), (new_w, new_h))`` where ``batch`` is a list of
        RGB frames, ``fps`` is the reported frame rate (25.0 when the container
        reports none), ``(w, h)`` is the source size and ``(new_w, new_h)`` the
        size of the yielded frames.

    Raises:
        RuntimeError: If the video cannot be opened or reports no frame size.
    """
    import cv2

    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise RuntimeError("Cannot open video")

        # Get video properties
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0

        longest = max(w, h)
        if longest <= 0:
            # Would otherwise divide by zero when computing the scale.
            raise RuntimeError("Cannot determine video dimensions")

        # Scale down (never up) so the longest side fits target_max_side.
        scale = min(1.0, float(target_max_side) / float(longest))
        if scale >= 0.999:
            new_w, new_h = w, h
        else:
            # Even dimensions; clamp so extreme aspect ratios never reach 0.
            new_w = max(2, int(w * scale) // 2 * 2)
            new_h = max(2, int(h * scale) // 2 * 2)

        batch = []
        while True:
            if not processing_active:
                return

            ok, f = cap.read()
            if not ok:
                if batch:
                    yield batch, fps, (w, h), (new_w, new_h)
                break

            if new_w != w or new_h != h:
                f = cv2.resize(f, (new_w, new_h), interpolation=cv2.INTER_AREA)
            f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
            batch.append(f)

            if len(batch) >= chunk:
                yield batch, fps, (w, h), (new_w, new_h)
                batch = []
    finally:
        # Runs on normal exhaustion, on errors, and when the consumer abandons
        # or closes the generator, so the capture handle is never leaked.
        cap.release()
137
+
138
def composite_frame(frame_rgb, bg_rgb, alpha01):
    """Alpha-blend a frame over a background using NumPy.

    Computes ``frame * alpha + background * (1 - alpha)`` per pixel.

    Args:
        frame_rgb: Foreground frame as an (H, W, 3) array.
        bg_rgb: Background image array, or ``None`` for a flat grey
            (level 200) background. It is resized to the frame size when its
            height/width differ.
        alpha01: Foreground opacity as an (H, W) or (H, W, 1) array; values
            are clipped to [0, 1].

    Returns:
        The composited frame as a uint8 array with the frame's shape.
    """
    if bg_rgb is None:
        bg = np.full_like(frame_rgb, 200, dtype=np.uint8)
    else:
        bg = bg_rgb
        if bg.shape[:2] != frame_rgb.shape[:2]:
            # Imported here so the function does not depend on a module-level
            # cv2 import; only needed when the background must be resized.
            import cv2
            bg = cv2.resize(bg, (frame_rgb.shape[1], frame_rgb.shape[0]), interpolation=cv2.INTER_AREA)

    alpha = np.asarray(alpha01)
    if alpha.ndim == 2:
        # Add a channel axis so the mask broadcasts over R, G and B.
        alpha = alpha[..., None]
    alpha = np.clip(alpha, 0.0, 1.0)

    blended = frame_rgb.astype("float32") * alpha + bg.astype("float32") * (1.0 - alpha)
    return blended.astype("uint8")
150
+
151
def cheap_fallback_alpha(fr, seed_mask=None):
    """Produce an alpha matte without running any model.

    When a seed mask is supplied it is returned unchanged. Otherwise a soft,
    centre-weighted matte is built: alpha is 1.0 within 0.2 of the frame
    centre (distance measured as a fraction of the longest side) and falls
    linearly to 0.0 at a distance of 0.6.

    Args:
        fr: Frame array whose first two dimensions are (height, width).
        seed_mask: Optional precomputed mask to use as-is.

    Returns:
        ``seed_mask`` if given, else a float32 (H, W) array with values in [0, 1].
    """
    if seed_mask is not None:
        return seed_mask

    # Center-focused soft alpha
    H, W = fr.shape[:2]
    # Row/column coordinate vectors; broadcasting yields the (H, W) grid
    # without materialising two full-size index arrays.
    yy = np.arange(H, dtype="float32")[:, None]
    xx = np.arange(W, dtype="float32")[None, :]
    cx, cy = W / 2.0, H / 2.0
    r = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2) / max(W, H)
    a = 1.0 - np.clip((r - 0.2) / 0.4, 0.0, 1.0)
    return a.astype("float32")
163
+
164
def process_video_gpu_optimized(input_path, bg_image_rgb=None, out_path="output.mp4"):
    """Replace the background of a video and write the result to ``out_path``.

    Frames are read in chunks via ``iter_video_frames``. On the first chunk
    SAM2 (when available) supplies a seed mask. Each frame is then matted with
    MatAnyone when available, otherwise with ``cheap_fallback_alpha``, blended
    over the background with ``composite_frame`` and written as mp4v video.

    Processing stops early when the module-level ``processing_active`` flag
    becomes false.

    Args:
        input_path: Path of the source video.
        bg_image_rgb: Background as an RGB array, or ``None`` for the default
            grey background used by ``composite_frame``.
        out_path: Destination path of the output video.

    Returns:
        ``out_path`` when at least one frame was written, no error occurred
        and processing was not stopped; otherwise ``None``.
    """
    # Imported here so the function does not depend on a module-level cv2
    # import (iter_video_frames imports it the same way).
    import cv2

    writer = None
    seed_mask = None
    total = 0
    failed = False

    try:
        for frames, fps, _orig_hw, _new_hw in iter_video_frames(input_path, MAX_SIDE, FRAME_CHUNK):
            if not processing_active:
                logger.info("Processing stopped by user")
                break

            H, W = frames[0].shape[:2]
            if writer is None:
                writer = cv2.VideoWriter(
                    out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (W, H)
                )
                if not writer.isOpened():
                    # write() on an unopened writer fails silently; fail loudly.
                    raise RuntimeError(f"Cannot open video writer for {out_path}")

            # First frame: try SAM2 for seed mask
            if seed_mask is None:
                try:
                    sam2 = get_sam2()
                    if sam2:
                        seed_mask = sam2.first_frame_mask(frames[0].astype("float32") / 255.0)
                        # Smooth, then re-binarise the mask to 0.0 / 1.0.
                        seed_mask = (cv2.GaussianBlur(seed_mask, (0, 0), 1.0) > 0.5).astype("float32")
                        logger.info("SAM2 seed mask generated")
                except Exception as e:
                    logger.warning(f"SAM2 failed, continuing without: {e}")
                    seed_mask = None

            # Number of frames of this chunk already written, so the fallback
            # below resumes where MatAnyone stopped instead of duplicating frames.
            done = 0

            # Professional matting pipeline
            matany = get_matanyone()
            if matany and MATANY_ENABLED:
                try:
                    with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.float16, enabled=(device.type=="cuda")):
                        for i, fr in enumerate(frames):
                            if not processing_active:
                                break

                            # The seed mask is passed only with the very first frame.
                            alpha = matany.step(fr, seed_mask if total == 0 and i == 0 else None)
                            comp = composite_frame(fr, bg_image_rgb, alpha)
                            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
                            total += 1
                            done = i + 1

                            if total % 64 == 0:
                                cleanup()
                                memory_checkpoint(f"frames={total}")

                except Exception as e:
                    logger.warning(f"MatAnyone failed: {e}")
                    matany = None

            # Fallback if MatAnyone unavailable (or failed part-way through)
            if not matany:
                for fr in frames[done:]:
                    if not processing_active:
                        break

                    alpha = cheap_fallback_alpha(fr, seed_mask)
                    comp = composite_frame(fr, bg_image_rgb, alpha)
                    writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
                    total += 1

                    if total % 64 == 0:
                        cleanup()

            memory_checkpoint(f"processed={total}")

    except Exception as e:
        failed = True
        logger.error(f"Processing error: {e}")
    finally:
        if writer is not None:
            writer.release()
        cleanup()

    # Report success only for a usable output file.
    if failed or total == 0 or not processing_active:
        return None
    return out_path
242
+
243
def stop_processing():
    """Request that the running video job stop.

    Clears the module-level ``processing_active`` flag, which the frame reader
    and processing loops poll between frames.

    Returns:
        A Gradio update that hides the stop button, and a status message.
    """
    global processing_active
    processing_active = False
    return gr.update(visible=False), "Processing stopped by user"
248
 
249
  class MyAvatarAPI:
250
+ """MyAvatar API integration"""
251
 
252
  def __init__(self):
253
  self.api_base = "https://app.myavatar.dk/api"
 
257
  def fetch_videos(self) -> List[Dict[str, Any]]:
258
  """Fetch videos from MyAvatar API"""
259
  try:
 
260
  if time.time() - self.last_refresh < 300 and self.videos_cache:
261
  return self.videos_cache
262
 
 
265
  data = response.json()
266
  self.videos_cache = data.get('videos', [])
267
  self.last_refresh = time.time()
268
+ logger.info(f"Fetched {len(self.videos_cache)} videos from MyAvatar")
269
  return self.videos_cache
270
  else:
271
+ logger.error(f"API error: {response.status_code}")
272
  return []
273
 
274
  except Exception as e:
275
+ logger.error(f"Error fetching videos: {e}")
276
  return []
277
 
278
  def get_video_choices(self) -> List[str]:
 
296
  return None
297
 
298
  try:
 
299
  if "(ID: " in selection:
300
  video_id = selection.split("(ID: ")[1].split(")")[0]
301
 
 
302
  for video in self.videos_cache:
303
  if str(video.get('id')) == video_id:
304
  return video.get('video_url')
 
306
  return None
307
 
308
  except Exception as e:
309
+ logger.error(f"Error extracting video URL: {e}")
310
  return None
311
 
312
+ # Initialize API
313
  myavatar_api = MyAvatarAPI()
314
 
 
 
 
 
315
  def create_gradient_background(gradient_type: str, width: int, height: int) -> Image.Image:
316
  """Create gradient backgrounds"""
317
  try:
 
350
 
351
  except Exception as e:
352
  logger.error(f"Error creating gradient: {e}")
 
353
  img = np.full((height, width, 3), [70, 130, 180], dtype=np.uint8)
354
  return Image.fromarray(img)
355
 
 
370
  img = np.full((height, width, 3), rgb, dtype=np.uint8)
371
  return Image.fromarray(img)
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  def generate_ai_background(prompt: str) -> Tuple[Optional[Image.Image], str]:
374
  """Generate AI background using Hugging Face Inference API"""
375
  try:
376
  if not prompt.strip():
377
  return None, "Please enter a prompt"
378
 
 
379
  models = [
380
+ "black-forest-labs/FLUX.1-schnell",
381
+ "stabilityai/stable-diffusion-xl-base-1.0",
382
+ "runwayml/stable-diffusion-v1-5"
383
  ]
384
 
 
385
  enhanced_prompt = f"professional video background, {prompt}, high quality, 16:9 aspect ratio, cinematic lighting, detailed"
386
 
387
  for model in models:
388
  try:
389
+ logger.info(f"Trying AI generation with {model}...")
390
 
 
391
  api_url = f"https://api-inference.huggingface.co/models/{model}"
 
392
  headers = {
393
  "Authorization": f"Bearer {os.getenv('HUGGINGFACE_TOKEN', 'hf_placeholder')}"
394
  }
 
395
  payload = {
396
  "inputs": enhanced_prompt,
397
  "parameters": {
398
  "width": 1024,
399
+ "height": 576,
400
  "num_inference_steps": 20,
401
  "guidance_scale": 7.5
402
  }
 
405
  response = requests.post(api_url, headers=headers, json=payload, timeout=30)
406
 
407
  if response.status_code == 200:
 
408
  image = Image.open(io.BytesIO(response.content))
409
+ logger.info(f"AI background generated successfully with {model}")
410
+ return image, f"AI background generated: {prompt}"
 
411
  elif response.status_code == 503:
412
+ logger.warning(f"Model {model} is loading, trying next...")
 
413
  continue
 
414
  else:
415
+ logger.warning(f"Error with {model}: {response.status_code}")
416
  continue
417
 
418
  except Exception as e:
419
+ logger.warning(f"Error with {model}: {e}")
420
  continue
421
 
422
+ logger.info("AI generation failed, creating intelligent gradient fallback...")
423
+ return create_intelligent_gradient(prompt), f"Created gradient background inspired by: {prompt}"
 
424
 
425
  except Exception as e:
426
  logger.error(f"Error in AI background generation: {e}")
 
430
  """Create intelligent gradient based on prompt analysis"""
431
  prompt_lower = prompt.lower()
432
 
 
433
  if any(word in prompt_lower for word in ["sunset", "orange", "warm", "fire", "autumn"]):
434
  return create_gradient_background("sunset", 1920, 1080)
435
  elif any(word in prompt_lower for word in ["ocean", "sea", "blue", "water", "sky", "calm"]):
436
  return create_gradient_background("ocean", 1920, 1080)
437
  elif any(word in prompt_lower for word in ["forest", "green", "nature", "trees", "jungle"]):
438
  return create_gradient_background("forest", 1920, 1080)
 
 
 
 
439
  else:
440
  return create_gradient_background("default", 1920, 1080)
441
 
442
+ def process_video_with_background_stoppable(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  input_video: Optional[str],
444
  myavatar_selection: str,
445
  background_type: str,
 
447
  solid_color: str,
448
  custom_background: Optional[str],
449
  ai_prompt: str
450
+ ):
451
+ """Main processing function with stop capability"""
452
+ global processing_active
453
+ processing_active = True
454
+
455
  try:
456
+ # Show stop button, hide process button
457
+ yield gr.update(visible=False), gr.update(visible=True), None, "Starting processing..."
458
+
459
+ # Determine video source
460
  video_path = None
461
  if input_video:
462
  video_path = input_video
 
464
  elif myavatar_selection and myavatar_selection != "No videos available":
465
  video_url = myavatar_api.get_video_url(myavatar_selection)
466
  if video_url:
 
467
  response = requests.get(video_url)
468
  if response.status_code == 200:
469
  temp_video = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
 
471
  temp_video.close()
472
  video_path = temp_video.name
473
  logger.info("Using MyAvatar video")
 
 
 
 
474
 
475
  if not video_path:
476
+ yield gr.update(visible=True), gr.update(visible=False), None, "No video provided"
477
+ return
 
 
 
478
 
479
+ # Generate background
480
+ yield gr.update(visible=False), gr.update(visible=True), None, "Generating background..."
481
 
 
 
 
 
 
482
  background_image = None
 
483
  if background_type == "gradient":
484
  background_image = create_gradient_background(gradient_type, 1920, 1080)
485
  elif background_type == "solid":
 
488
  background_image = Image.open(custom_background)
489
  elif background_type == "ai" and ai_prompt:
490
  bg_img, ai_msg = generate_ai_background(ai_prompt)
491
+ background_image = bg_img
 
 
 
492
 
493
  if not background_image:
494
+ yield gr.update(visible=True), gr.update(visible=False), None, "No background generated"
495
+ return
496
+
497
+ # Process video
498
+ yield gr.update(visible=False), gr.update(visible=True), None, "Processing video with GPU optimization..."
499
+
500
+ bg_array = np.array(background_image.resize((1280, 720), Image.Resampling.LANCZOS))
501
 
502
+ with tempfile.NamedTemporaryFile(suffix='_processed.mp4', delete=False) as tmp_final:
 
 
503
  final_video_path = tmp_final.name
504
 
505
+ result_path = process_video_gpu_optimized(video_path, bg_array, final_video_path)
506
 
507
+ # Cleanup
508
  try:
509
+ if video_path != input_video:
 
510
  os.unlink(video_path)
511
  except:
512
  pass
513
 
514
+ if result_path and processing_active:
515
+ yield gr.update(visible=True), gr.update(visible=False), result_path, "Video processing completed successfully!"
 
516
  else:
517
+ yield gr.update(visible=True), gr.update(visible=False), None, "Processing was stopped or failed"
518
 
519
  except Exception as e:
520
+ logger.error(f"Error in video processing: {e}")
521
+ yield gr.update(visible=True), gr.update(visible=False), None, f"Processing error: {str(e)}"
522
+ finally:
523
+ processing_active = False
 
 
524
 
525
  def create_interface():
526
  """Create the Gradio interface"""
527
  logger.info("Creating Gradio interface...")
528
+ logger.info(f"Device: {device} | GPU: {GPU_NAME} | Memory: {GPU_MEMORY:.1f}GB")
529
 
 
530
  css = """
531
  .main-container { max-width: 1200px; margin: 0 auto; }
532
  .status-box { border: 2px solid #4CAF50; border-radius: 10px; padding: 15px; }
533
  .gradient-preview { border: 2px solid #ddd; border-radius: 10px; }
534
  """
535
 
536
+ with gr.Blocks(css=css, title="BackgroundFX Pro - GPU Optimized") as app:
537
 
 
538
  gr.Markdown("""
539
+ # BackgroundFX Pro - GPU Optimized
540
+ ### Professional Video Background Replacement with SAM2 + MatAnyone
541
  """)
542
 
 
543
  with gr.Row():
544
+ sam2_status = "Ready" if SAM2_ENABLED else "Disabled"
545
+ matany_status = "Ready" if MATANY_ENABLED else "Disabled"
546
  gr.Markdown(f"""
547
+ **System Status:** Online | **GPU:** {GPU_NAME} | **SAM2:** {sam2_status} | **MatAnyone:** {matany_status}
548
  """)
549
 
 
550
  with gr.Row():
 
551
  with gr.Column(scale=1):
552
+ gr.Markdown("## Video Input")
553
 
554
  with gr.Tabs():
555
+ with gr.Tab("Upload Video"):
556
  video_upload = gr.Video(label="Upload Video File", height=300)
557
 
558
+ with gr.Tab("MyAvatar Videos"):
559
+ refresh_btn = gr.Button("Refresh Videos", size="sm")
560
  myavatar_dropdown = gr.Dropdown(
561
  label="Select MyAvatar Video",
562
  choices=["Click refresh to load videos"],
 
564
  )
565
  video_preview = gr.Video(label="Preview", height=200)
566
 
567
+ gr.Markdown("## Background Options")
568
 
569
  background_type = gr.Radio(
570
  choices=["gradient", "solid", "custom", "ai"],
 
573
  )
574
 
575
  with gr.Group():
 
576
  gradient_type = gr.Dropdown(
577
  choices=["sunset", "ocean", "forest", "default"],
578
  value="sunset",
 
581
  )
582
  gradient_preview = gr.Image(label="Gradient Preview", height=150)
583
 
 
584
  solid_color = gr.Dropdown(
585
  choices=["white", "black", "blue", "green", "red", "purple", "orange", "yellow"],
586
  value="blue",
 
589
  )
590
  color_preview = gr.Image(label="Color Preview", height=150, visible=False)
591
 
 
592
  custom_bg_upload = gr.Image(
593
  label="Upload Custom Background",
594
  type="filepath",
595
  visible=False
596
  )
597
 
 
598
  ai_prompt = gr.Textbox(
599
  label="AI Background Prompt",
600
  placeholder="Describe the background you want...",
601
  visible=False
602
  )
603
+ ai_generate_btn = gr.Button("Generate AI Background", visible=False)
604
  ai_preview = gr.Image(label="AI Generated Background", height=150, visible=False)
605
 
606
+ with gr.Row():
607
+ process_btn = gr.Button("Process Video", variant="primary", size="lg")
608
+ stop_btn = gr.Button("Stop Processing", variant="stop", size="lg", visible=False)
609
 
 
610
  with gr.Column(scale=1):
611
+ gr.Markdown("## Results")
612
 
613
  result_video = gr.Video(label="Processed Video", height=400)
614
 
 
619
  elem_classes=["status-box"]
620
  )
621
 
 
622
  gr.Markdown("""
623
+ ### Processing Pipeline:
624
+ 1. **SAM2 Segmentation** - GPU-accelerated person detection
625
+ 2. **MatAnyone Matting** - Professional temporal consistency
626
+ 3. **GPU Compositing** - Real-time background replacement
627
+ 4. **Memory Optimization** - Chunked processing for efficiency
628
 
629
+ **Performance:** ~3-5 minutes per 1000 frames on T4 GPU
630
  """)
631
 
632
+ # Event handlers
 
633
  def update_background_options(bg_type):
 
634
  return {
635
  gradient_type: gr.update(visible=(bg_type == "gradient")),
636
  gradient_preview: gr.update(visible=(bg_type == "gradient")),
 
643
  }
644
 
645
  def update_gradient_preview(grad_type):
 
646
  try:
647
+ return create_gradient_background(grad_type, 400, 200)
 
648
  except:
649
  return None
650
 
651
  def update_color_preview(color):
 
652
  try:
653
+ return create_solid_color(color, 400, 200)
 
654
  except:
655
  return None
656
 
657
  def refresh_myavatar_videos():
 
658
  try:
659
  choices = myavatar_api.get_video_choices()
660
  return gr.update(choices=choices, value=None)
 
663
  return gr.update(choices=["Error loading videos"])
664
 
665
  def load_video_preview(selection):
 
666
  try:
667
  if not selection or selection == "No videos available":
668
  return None
 
674
  return None
675
 
676
  def generate_ai_bg(prompt):
 
677
  bg_img, message = generate_ai_background(prompt)
678
  return bg_img
679
 
 
715
  )
716
 
717
  process_btn.click(
718
+ fn=process_video_with_background_stoppable,
719
  inputs=[
720
  video_upload,
721
  myavatar_dropdown,
 
725
  custom_bg_upload,
726
  ai_prompt
727
  ],
728
+ outputs=[process_btn, stop_btn, result_video, status_output]
729
+ )
730
+
731
+ stop_btn.click(
732
+ fn=stop_processing,
733
+ outputs=[stop_btn, status_output]
734
  )
735
 
 
736
  app.load(
737
  fn=lambda: create_gradient_background("sunset", 400, 200),
738
  outputs=[gradient_preview]
 
740
 
741
  return app
742
 
 
 
 
 
743
  def main():
744
  """Main application entry point"""
745
  try:
746
+ # Pre-warm models
747
+ logger.info("Pre-warming GPU models...")
748
+ if SAM2_ENABLED:
749
+ get_sam2()
750
+ if MATANY_ENABLED:
751
+ get_matanyone()
752
 
 
753
  app = create_interface()
754
 
755
  app.launch(
 
761
  )
762
 
763
  except Exception as e:
764
+ logger.error(f"Failed to start application: {e}")
765
  sys.exit(1)
766
 
767
  if __name__ == "__main__":