|
|
|
|
|
""" |
|
|
BackgroundFX Pro - SAM2 + MatAnyone Professional Video Background Replacer |
|
|
State-of-the-art video background replacement with professional alpha matting |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import tempfile |
|
|
import os |
|
|
from PIL import Image |
|
|
import requests |
|
|
from io import BytesIO |
|
|
import logging |
|
|
import gc |
|
|
import torch |
|
|
import time |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Root logging configuration for the whole app; every module-level logger below
# inherits this format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Hard input limits enforced by validate_video().
MAX_VIDEO_DURATION = 300  # seconds (5 minutes)
SUPPORTED_VIDEO_FORMATS = ['.mp4', '.avi', '.mov', '.mkv', '.webm']
|
|
|
|
|
|
|
|
def setup_gpu():
    """Probe and configure CUDA.

    Returns:
        tuple: (cuda_available, gpu_name, gpu_memory_gb, gpu_type) where
        gpu_type is one of the known architecture keys ("T4", "V100", "A10",
        "A100") substring-matched against the device name, or None.
    """
    if not torch.cuda.is_available():
        return False, None, 0, None

    device_name = torch.cuda.get_device_name(0)
    memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3

    # Bind device 0 and let cuDNN pick the fastest kernels for fixed shapes.
    torch.cuda.init()
    torch.cuda.set_device(0)
    torch.backends.cudnn.benchmark = True

    # Per-architecture tuning table (keys are matched in insertion order, so
    # an "A100" device name matches "A10" first — preserved behavior).
    gpu_optimizations = {
        "T4": {"use_half": True, "batch_size": 1},
        "V100": {"use_half": False, "batch_size": 2},
        "A10": {"use_half": True, "batch_size": 2},
        "A100": {"use_half": False, "batch_size": 4}
    }

    matched_type = next((key for key in gpu_optimizations if key in device_name), None)

    return True, device_name, memory_gb, matched_type
|
|
|
|
|
# Probe the GPU once at import time; all processing code keys off these globals.
CUDA_AVAILABLE, GPU_NAME, GPU_MEMORY, GPU_TYPE = setup_gpu()
DEVICE = 'cuda' if CUDA_AVAILABLE else 'cpu'

# On CPU-only hosts this logs "GPU: None | Memory: 0.0GB | Type: None".
logger.info(f"Device: {DEVICE} | GPU: {GPU_NAME} | Memory: {GPU_MEMORY:.1f}GB | Type: {GPU_TYPE}")
|
|
|
|
|
|
|
|
class SAM2WithPersonDetection:
    """Lazy SAM2 segmenter seeded by a heuristic person locator.

    Downloads SAM2 checkpoints on demand into a temp-dir cache, keeps at most
    one model variant in memory at a time, and prompts SAM2 with points
    derived from a rough subject bounding box found via classical OpenCV
    edge analysis.
    """

    def __init__(self):
        # Loaded SAM2ImagePredictor instance (None until load_model() succeeds).
        self.predictor = None
        # Size key ("tiny"/"small"/"base") of the currently loaded model.
        self.current_model_size = None
        # OpenCV background subtractor loaded as a best-effort "person detector".
        self.person_detector = None
        # Checkpoint cache under the system temp dir; survives across runs only
        # as long as the OS keeps the temp directory.
        self.model_cache_dir = Path(tempfile.gettempdir()) / "sam2_cache"
        self.model_cache_dir.mkdir(exist_ok=True)

        # Registry of downloadable SAM2 variants: official checkpoint URL,
        # Hydra config name, approximate download size, and a UI description.
        self.models = {
            "tiny": {
                "url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
                "config": "sam2_hiera_t.yaml",
                "size_mb": 38,
                "description": "Fastest, lowest memory"
            },
            "small": {
                "url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
                "config": "sam2_hiera_s.yaml",
                "size_mb": 185,
                "description": "Balanced speed/quality"
            },
            "base": {
                "url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
                "config": "sam2_hiera_b+.yaml",
                "size_mb": 320,
                "description": "Best quality, slower"
            }
        }

    def get_model_path(self, model_size):
        """Return the local cache path for the given model size key."""
        model_name = f"sam2_{model_size}.pt"
        return self.model_cache_dir / model_name

    def clear_model(self):
        """Release the SAM2 predictor and person detector and reclaim memory."""
        if self.predictor:
            del self.predictor
            self.predictor = None
            self.current_model_size = None

        if self.person_detector:
            del self.person_detector
            self.person_detector = None

        if CUDA_AVAILABLE:
            torch.cuda.empty_cache()
        gc.collect()
        logger.info("SAM2 model and person detector cleared from memory")

    def load_person_detector(self, progress_fn=None):
        """Load the lightweight "person detector" (best-effort; may return None).

        Despite the name this is just an OpenCV MOG2 background subtractor;
        the actual subject localization happens in detect_person_bbox().
        Failure is tolerated — segmentation falls back to a fixed point grid.
        """
        if self.person_detector is not None:
            return self.person_detector

        try:
            if progress_fn:
                progress_fn(0.05, "Loading person detector...")

            import cv2  # NOTE(review): redundant — cv2 is already imported at module level

            self.person_detector = cv2.createBackgroundSubtractorMOG2(detectShadows=True)

            if progress_fn:
                progress_fn(0.1, "Person detector loaded!")

            logger.info("Person detector loaded successfully")
            return self.person_detector

        except Exception as e:
            logger.warning(f"Failed to load person detector: {e}")
            self.person_detector = None
            return None

    def detect_person_bbox(self, image, progress_fn=None):
        """Heuristically locate the main subject in an RGB frame.

        Uses Canny edges + largest external contour, then rejects boxes that
        are implausibly small/large relative to the frame or not taller than
        they are wide. `progress_fn` is currently unused.

        Returns:
            [x1, y1, x2, y2] or None when no plausible subject is found.
        """
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

            blurred = cv2.GaussianBlur(gray, (5, 5), 0)

            edges = cv2.Canny(blurred, 50, 150)

            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            if not contours:
                return None

            largest_contour = max(contours, key=cv2.contourArea)

            x, y, w, h = cv2.boundingRect(largest_contour)

            # Reject implausible subjects: less than 5% or more than 80% of frame area.
            image_area = image.shape[0] * image.shape[1]
            bbox_area = w * h

            if bbox_area < image_area * 0.05 or bbox_area > image_area * 0.8:
                return None

            # People are usually at least roughly as tall as wide; reject squat boxes.
            if h < w * 0.8:
                return None

            return [x, y, x + w, y + h]

        except Exception as e:
            logger.warning(f"Person detection failed: {e}")
            return None

    def get_smart_points_from_bbox(self, bbox, image_shape):
        """Generate positive SAM2 prompt points.

        With a bbox: points along the vertical center line plus left/right and
        upper-torso offsets. Without: a fixed 3x3 grid over the middle of the
        frame (assumes the subject is roughly centered).
        """
        if bbox is None:
            # Fallback grid spanning the central third-to-two-thirds band.
            h, w = image_shape[:2]
            return [
                [w//4, h//3], [w//2, h//3], [3*w//4, h//3],
                [w//4, h//2], [w//2, h//2], [3*w//4, h//2],
                [w//4, 2*h//3], [w//2, 2*h//3], [3*w//4, 2*h//3]
            ]

        x1, y1, x2, y2 = bbox
        center_x = (x1 + x2) // 2
        center_y = (y1 + y2) // 2
        width = x2 - x1
        height = y2 - y1

        # Points biased toward the head/torso region of the detected box.
        points = [
            [center_x, center_y],
            [center_x, y1 + height//4],
            [center_x, y1 + height//2],
            [center_x, y1 + 3*height//4],
            [x1 + width//4, center_y],
            [x2 - width//4, center_y],
            [center_x - width//6, y1 + height//3],
            [center_x + width//6, y1 + height//3],
        ]

        return points

    def download_model(self, model_size, progress_fn=None):
        """Download a checkpoint into the cache (no-op if already cached).

        Streams in 8KB chunks with progress callbacks mapped into the
        0.10-0.25 band of overall progress. Deletes any partial file on
        failure and re-raises.
        """
        model_info = self.models[model_size]
        model_path = self.get_model_path(model_size)

        if model_path.exists():
            logger.info(f"Model {model_size} already cached")
            return model_path

        try:
            logger.info(f"Downloading SAM2 {model_size} model...")
            # NOTE(review): no timeout set — a stalled connection can hang here.
            response = requests.get(model_info['url'], stream=True)
            response.raise_for_status()

            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0

            with open(model_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)
                        if progress_fn and total_size > 0:
                            progress = downloaded / total_size * 0.15
                            progress_fn(0.1 + progress, f"Downloading SAM2 {model_size} ({downloaded/1024/1024:.1f}MB/{total_size/1024/1024:.1f}MB)")

            logger.info(f"SAM2 {model_size} downloaded successfully")
            return model_path

        except Exception as e:
            logger.error(f"Failed to download SAM2 {model_size}: {e}")
            if model_path.exists():
                # Drop the partial download so the next attempt starts clean.
                model_path.unlink()
            raise

    def load_model(self, model_size, progress_fn=None):
        """Build and retain a SAM2 predictor for `model_size`.

        Raises:
            ImportError: if the sam2 package is not installed.
            Exception: any other failure, after clearing partial state.
        """
        try:
            # Person detector is optional; load_person_detector() swallows failures.
            self.load_person_detector(progress_fn)

            try:
                from sam2.build_sam import build_sam2
                from sam2.sam2_image_predictor import SAM2ImagePredictor
            except ImportError as e:
                logger.error("SAM2 not available. Install with: pip install segment-anything-2")
                raise ImportError("SAM2 package not found") from e

            model_path = self.download_model(model_size, progress_fn)

            if progress_fn:
                progress_fn(0.25, f"Loading SAM2 {model_size} model...")

            model_config = self.models[model_size]["config"]
            sam2_model = build_sam2(model_config, str(model_path), device=DEVICE)

            # Half precision only on architectures flagged use_half in setup_gpu().
            if CUDA_AVAILABLE and GPU_TYPE in ["T4", "A10"]:
                sam2_model = sam2_model.half()
                logger.info(f"Applied half precision for {GPU_TYPE}")

            self.predictor = SAM2ImagePredictor(sam2_model)
            self.current_model_size = model_size

            if progress_fn:
                progress_fn(0.3, f"SAM2 {model_size} with person detection ready!")

            logger.info(f"SAM2 {model_size} model with person detection loaded and ready")
            return self.predictor

        except Exception as e:
            logger.error(f"Failed to load SAM2 {model_size}: {e}")
            self.clear_model()
            raise

    def get_predictor(self, model_size="small", progress_fn=None):
        """Return a predictor for `model_size`, (re)loading on size change."""
        if self.predictor is None or self.current_model_size != model_size:
            self.clear_model()
            return self.load_model(model_size, progress_fn)
        return self.predictor

    def segment_image_smart(self, image, model_size="small", progress_fn=None):
        """Segment the main subject of an RGB frame.

        Returns:
            (mask, confidence): mask is a float32 HxW soft alpha in [0, 1]
            (after morphological close and a light Gaussian blur), or
            (None, 0.0) on failure. Confidence is boosted (x1.5, capped at
            1.0) when a person bbox corroborated the prompt points.
        """
        predictor = self.get_predictor(model_size, progress_fn)

        try:
            if progress_fn:
                progress_fn(0.32, "Finding person in image...")

            person_bbox = self.detect_person_bbox(image, progress_fn)

            if progress_fn:
                if person_bbox:
                    progress_fn(0.35, f"Person found! Segmenting with high precision...")
                else:
                    progress_fn(0.35, f"Using grid search for segmentation...")

            smart_points = self.get_smart_points_from_bbox(person_bbox, image.shape)

            predictor.set_image(image)

            point_coords = np.array(smart_points)
            point_labels = np.ones(len(point_coords))  # all prompts are positive

            if progress_fn:
                progress_fn(0.38, f"SAM2 segmenting with {len(smart_points)} smart points...")

            masks, scores, logits = predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=True
            )

            # Keep the highest-scoring of SAM2's candidate masks.
            best_mask_idx = scores.argmax()
            best_mask = masks[best_mask_idx]
            best_score = scores[best_mask_idx]

            # Close small holes, then soften edges; result is float32 in [0, 1].
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            best_mask = cv2.morphologyEx(best_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)

            best_mask = cv2.GaussianBlur(best_mask.astype(np.float32), (3, 3), 1.0)

            # Reward agreement between the bbox heuristic and SAM2 (capped at 1.0).
            if person_bbox and best_score > 0.3:
                best_score = min(best_score * 1.5, 1.0)

            logger.info(f"Smart segmentation complete: confidence={best_score:.3f}, person_detected={person_bbox is not None}")

            return best_mask, float(best_score)

        except Exception as e:
            logger.error(f"Smart segmentation failed: {e}")
            return None, 0.0
|
|
|
|
|
|
|
|
class MatAnyoneLazy:
    """Lazy wrapper around the MatAnyone video-matting model.

    `available` is False until a successful load; every failure degrades to
    None/(None, None) results so callers can fall back to SAM2-only
    processing instead of crashing.
    """

    def __init__(self):
        # matanyone.InferenceCore once loaded (None until then).
        self.processor = None
        # True only after a successful load.
        self.available = False

    def load_model(self, progress_fn=None):
        """Load MatAnyone once; returns the processor or None if unavailable."""
        if self.processor is not None:
            return self.processor

        try:
            if progress_fn:
                progress_fn(0.3, "Loading MatAnyone professional matting...")

            try:
                from matanyone import InferenceCore

                # Pretrained weights are pulled from the Hugging Face hub by name.
                self.processor = InferenceCore("PeiqingYang/MatAnyone")
                self.available = True

                if progress_fn:
                    progress_fn(0.4, "MatAnyone loaded successfully!")

                logger.info("MatAnyone model loaded for professional video matting")
                return self.processor

            except ImportError as e:
                # Optional dependency: absence is a warning, not an error.
                logger.warning(f"MatAnyone not available: {e}")
                self.available = False
                return None

        except Exception as e:
            logger.error(f"Failed to load MatAnyone: {e}")
            self.available = False
            return None

    def process_video_with_mask(self, video_path, mask_path, progress_fn=None):
        """Run MatAnyone over a video, seeded with a first-frame mask image.

        Returns:
            (foreground_path, alpha_path) or (None, None) on any failure.

        BUGFIX: the original returned early when `self.available` was False,
        but `available` only becomes True inside load_model() — which this
        method was supposed to trigger. The guard made the MatAnyone path
        unreachable dead code; we now attempt the lazy load first.
        """
        try:
            processor = self.load_model(progress_fn)
            if processor is None:
                return None, None

            if progress_fn:
                progress_fn(0.5, "MatAnyone processing video...")

            foreground_path, alpha_path = processor.process_video(
                input_path=video_path,
                mask_path=mask_path
            )

            if progress_fn:
                progress_fn(0.8, "MatAnyone processing complete!")

            return foreground_path, alpha_path

        except Exception as e:
            # Best-effort: callers fall back to SAM2-only compositing.
            logger.warning(f"MatAnyone processing failed: {e}")
            return None, None

    def clear_model(self):
        """Drop the processor and reclaim GPU/host memory."""
        if self.processor:
            del self.processor
            self.processor = None
        if CUDA_AVAILABLE:
            torch.cuda.empty_cache()
        gc.collect()
|
|
|
|
|
|
|
|
class SAM2MatAnyonePipeline:
    """Bundle of the SAM2 segmenter and the MatAnyone matting loader."""

    def __init__(self):
        # Both loaders are lazy: construction is cheap, models load on demand.
        self.sam2_loader = SAM2WithPersonDetection()
        self.matanyone_loader = MatAnyoneLazy()

    def clear_models(self):
        """Release every loaded model and reclaim GPU/host memory."""
        for loader in (self.sam2_loader, self.matanyone_loader):
            loader.clear_model()

        if CUDA_AVAILABLE:
            torch.cuda.empty_cache()
        gc.collect()
        logger.info("All models cleared from memory")
|
|
|
|
|
|
|
|
# Module-level singleton shared by all Gradio request handlers.
professional_pipeline = SAM2MatAnyonePipeline()
|
|
|
|
|
|
|
|
def validate_video(video_path):
    """Validate a video file before processing.

    Checks existence, extension, readability, and sane properties: positive
    fps/frame count, duration <= MAX_VIDEO_DURATION, resolution <= 1920x1080.

    Args:
        video_path: filesystem path to the candidate video (may be None).

    Returns:
        tuple[bool, str]: (is_valid, human-readable message).
    """
    if not video_path or not os.path.exists(video_path):
        return False, "No video file provided"

    # Cheap extension check before opening the file with OpenCV.
    file_ext = Path(video_path).suffix.lower()
    if file_ext not in SUPPORTED_VIDEO_FORMATS:
        return False, f"Unsupported format. Supported: {', '.join(SUPPORTED_VIDEO_FORMATS)}"

    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return False, "Cannot open video file"

        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        if fps <= 0 or frame_count <= 0:
            return False, "Invalid video properties"

        duration = frame_count / fps

        if duration > MAX_VIDEO_DURATION:
            return False, f"Video too long ({duration:.1f}s). Max: {MAX_VIDEO_DURATION}s"

        if width * height > 1920 * 1080:
            return False, "Resolution too high (max 1920x1080)"

        return True, f"Valid video: {duration:.1f}s, {width}x{height}, {fps:.1f}fps"

    except Exception as e:
        return False, f"Video validation error: {str(e)}"
    finally:
        # BUGFIX: always release the capture — the original leaked it on every
        # early return path and on the exception path.
        if cap is not None:
            cap.release()
|
|
|
|
|
|
|
|
def create_gradient_background(width=1280, height=720, color1=(70, 130, 180), color2=(255, 140, 90)):
    """Create a vertical linear gradient from color1 (top) toward color2 (bottom).

    Vectorized with NumPy instead of the original per-row Python loop; the
    per-row interpolation `int(c1 * (1 - y/h) + c2 * y/h)` is reproduced
    exactly (astype(uint8) truncates like int() for non-negative values).

    Returns:
        np.ndarray: (height, width, 3) uint8 RGB image.
    """
    # Row ratio runs 0 .. (height-1)/height, so the top row is pure color1.
    ratio = np.arange(height, dtype=np.float64)[:, None] / height
    c1 = np.asarray(color1, dtype=np.float64)
    c2 = np.asarray(color2, dtype=np.float64)

    # One interpolated RGB triple per row, shape (height, 3).
    row_colors = (c1[None, :] * (1.0 - ratio) + c2[None, :] * ratio).astype(np.uint8)

    # Broadcast each row color across the full width.
    return np.repeat(row_colors[:, None, :], width, axis=1)
|
|
|
|
|
def get_background_presets():
    """Return the background preset registry.

    Maps "kind:name" keys to (label, top_color, bottom_color) tuples.
    Gradient presets carry two RGB endpoints; solid-color presets carry
    (label, None, None) and are resolved in create_background_from_preset().
    NOTE(review): the labels contain emoji that appear mojibake-encoded here —
    confirm the file's encoding against the original.
    """
    return {
        "gradient:ocean": ("π Ocean Blue", (20, 120, 180), (135, 206, 235)),
        "gradient:sunset": ("π Sunset Orange", (255, 94, 77), (255, 154, 0)),
        "gradient:forest": ("π² Forest Green", (34, 139, 34), (144, 238, 144)),
        "gradient:purple": ("π Purple Haze", (128, 0, 128), (221, 160, 221)),
        "color:white": ("βͺ Pure White", None, None),
        "color:black": ("β« Pure Black", None, None),
        "color:green": ("π Chroma Green", None, None),
        "color:blue": ("π Chroma Blue", None, None)
    }
|
|
|
|
|
def create_background_from_preset(preset, width, height):
    """Build a (height, width, 3) uint8 RGB background from a preset key.

    Unknown keys — and, defensively, known keys with an unexpected prefix —
    fall back to the default gradient. (The original silently returned None
    when a key matched neither "gradient:" nor "color:".)
    """
    presets = get_background_presets()

    if preset not in presets:
        return create_gradient_background(width, height)

    name, color1, color2 = presets[preset]

    if preset.startswith("gradient:"):
        return create_gradient_background(width, height, color1, color2)

    if preset.startswith("color:"):
        # Solid chroma/neutral fills; unknown color names default to white.
        color_map = {
            "white": [255, 255, 255],
            "black": [0, 0, 0],
            "green": [0, 255, 0],
            "blue": [0, 0, 255]
        }
        color_name = preset.split(":")[1]
        color = color_map.get(color_name, [255, 255, 255])
        return np.full((height, width, 3), color, dtype=np.uint8)

    # BUGFIX: defensive fallback instead of an implicit None return.
    return create_gradient_background(width, height)
|
|
|
|
|
def load_background_image(background_img, background_preset, target_width, target_height):
    """Resolve the background to use: a user upload if given, else a preset.

    Always yields an RGB uint8 array of exactly (target_height, target_width, 3);
    any failure falls back to the default gradient at the target size.
    """
    try:
        if background_img is None:
            # No upload: synthesize the selected preset at the target size.
            background = create_background_from_preset(background_preset, target_width, target_height)
        else:
            # A user-supplied image always wins over the preset.
            background = np.array(background_img.convert('RGB'))

        # Force the exact frame geometry expected by the compositor.
        if background.shape[:2] != (target_height, target_width):
            background = cv2.resize(background, (target_width, target_height))

        return background

    except Exception as e:
        logger.error(f"Background loading failed: {e}")
        return create_gradient_background(target_width, target_height)
|
|
|
|
|
|
|
|
def process_video_professional(input_video, background_img, background_preset, model_size,
                               edge_smoothing, use_matanyone, progress=gr.Progress()):
    """Top-level Gradio handler: replace a video's background.

    Pipeline: validate input -> read properties -> resolve background ->
    either the SAM2 + MatAnyone path (a first-frame SAM2 mask seeds
    MatAnyone, falling back to SAM2-only if MatAnyone yields nothing) or the
    SAM2-only per-frame path. Models are cleared after each request.

    Returns:
        tuple: (output_video_path_or_None, status_message).
    """

    if input_video is None:
        return None, "β Please upload a video file"

    progress(0.02, desc="Validating video...")
    is_valid, validation_msg = validate_video(input_video)
    if not is_valid:
        return None, f"β {validation_msg}"

    logger.info(f"Video validation: {validation_msg}")

    try:
        # Probe geometry/timing once; validate_video() already guaranteed fps > 0.
        progress(0.05, desc="Reading video properties...")
        cap = cv2.VideoCapture(input_video)

        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 else 0

        cap.release()

        logger.info(f"Video: {width}x{height}, {fps}fps, {total_frames} frames, {duration:.1f}s")

        progress(0.08, desc="Preparing background...")
        background_image = load_background_image(background_img, background_preset, width, height)

        if use_matanyone:
            progress(0.1, desc="Starting SAM2 + MatAnyone professional pipeline...")

            # MatAnyone is seeded from a segmentation of the first frame only.
            cap = cv2.VideoCapture(input_video)
            ret, first_frame = cap.read()
            cap.release()

            if not ret:
                return None, "β Cannot read first frame"

            # Map SAM2's internal progress into the 0.10-0.25 band.
            def sam2_progress(prog, msg):
                progress(0.1 + prog * 0.15, desc=msg)

            first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
            mask, confidence = professional_pipeline.sam2_loader.segment_image_smart(
                first_frame_rgb, model_size, sam2_progress
            )

            if mask is None or confidence < 0.3:
                return None, f"β SAM2 segmentation failed (confidence: {confidence:.2f})"

            # Persist the soft mask as an 8-bit PNG for MatAnyone.
            # NOTE(review): tempfile.mktemp is deprecated/race-prone; mkstemp
            # would be safer here.
            temp_mask_path = tempfile.mktemp(suffix='.png')
            mask_uint8 = (mask * 255).astype(np.uint8)
            cv2.imwrite(temp_mask_path, mask_uint8)

            # MatAnyone occupies the 0.25-0.75 band of overall progress.
            def matanyone_progress(prog, msg):
                progress(0.25 + prog * 0.5, desc=msg)

            foreground_path, alpha_path = professional_pipeline.matanyone_loader.process_video_with_mask(
                input_video, temp_mask_path, matanyone_progress
            )

            if os.path.exists(temp_mask_path):
                os.unlink(temp_mask_path)

            if foreground_path is None:
                # MatAnyone unavailable or failed: degrade to SAM2-only.
                return process_video_sam2_only(input_video, background_image, model_size, edge_smoothing, progress)

            progress(0.8, desc="Compositing with new background...")
            output_path = composite_matanyone_result(foreground_path, alpha_path, background_image, fps)

        else:
            output_path = process_video_sam2_only(input_video, background_image, model_size, edge_smoothing, progress)

        # Free GPU/host memory between requests.
        professional_pipeline.clear_models()

        if CUDA_AVAILABLE:
            torch.cuda.empty_cache()
        gc.collect()

        progress(1.0, desc="Complete!")

        quality_info = "Professional MatAnyone" if use_matanyone else "Standard SAM2"
        return output_path, f"β {quality_info} processing: {duration:.1f}s video completed successfully!"

    except Exception as e:
        error_msg = f"β Processing failed: {str(e)}"
        logger.error(error_msg)
        professional_pipeline.clear_models()
        return None, error_msg
|
|
|
|
|
def process_video_sam2_only(input_video, background_image, model_size, edge_smoothing, progress):
    """SAM2-only processing pipeline.

    Segments every frame with SAM2 (reusing the previous frame's alpha when a
    frame's segmentation fails) and composites it over `background_image`.

    Args:
        input_video: path to the source video (validated by the caller, so
            fps/frame count are positive in practice).
        background_image: (H, W, 3) uint8 RGB array matching the video size.
        model_size: SAM2 variant key ("tiny"/"small"/"base").
        edge_smoothing: Gaussian sigma for softening mask edges (0 disables).
        progress: Gradio progress callback.

    Returns:
        str: path to the written .mp4 file.
    """
    cap = cv2.VideoCapture(input_video)
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # BUGFIX: use mkstemp instead of the deprecated, race-prone tempfile.mktemp.
    fd, output_path = tempfile.mkstemp(suffix='.mp4')
    os.close(fd)  # only the path is needed; VideoWriter opens it itself
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    frame_count = 0
    last_alpha = None  # reused when a frame's segmentation fails

    # Map SAM2's internal (model-loading) progress into the 0.3-0.5 band.
    def sam2_progress(prog, msg):
        overall_prog = 0.3 + (prog * 0.2)
        progress(overall_prog, desc=msg)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            alpha, confidence = professional_pipeline.sam2_loader.segment_image_smart(
                frame_rgb, model_size, sam2_progress
            )

            if alpha is not None and confidence > 0.3:
                current_alpha = alpha
                last_alpha = current_alpha
            elif last_alpha is not None:
                current_alpha = last_alpha
            else:
                # No usable mask yet: fall back to a mostly-opaque constant alpha.
                current_alpha = np.ones((height, width), dtype=np.float32) * 0.8

            if edge_smoothing > 0:
                # BUGFIX: cv2.GaussianBlur requires an ODD aperture. The original
                # int(edge_smoothing * 2) + 1 produced an even kernel (and an
                # OpenCV error) for edge_smoothing values like 0.5 or the UI
                # default 1.5.
                kernel_size = int(edge_smoothing * 2) + 1
                if kernel_size % 2 == 0:
                    kernel_size += 1
                current_alpha = cv2.GaussianBlur(current_alpha, (kernel_size, kernel_size), edge_smoothing)

            # Ensure a trailing channel axis for broadcasting against RGB.
            if current_alpha.ndim == 2:
                alpha_channel = np.expand_dims(current_alpha, axis=2)
            else:
                alpha_channel = current_alpha

            alpha_channel = np.clip(alpha_channel, 0, 1)
            foreground = frame_rgb.astype(np.float32)
            background = background_image.astype(np.float32)

            # Standard alpha compositing: fg*a + bg*(1-a).
            composite = foreground * alpha_channel + background * (1 - alpha_channel)
            composite = np.clip(composite, 0, 255).astype(np.uint8)

            composite_bgr = cv2.cvtColor(composite, cv2.COLOR_RGB2BGR)
            out.write(composite_bgr)

            frame_count += 1

            # Frame loop occupies the 0.5-0.9 progress band; guard against a
            # zero frame count from a pathological container.
            if frame_count % 5 == 0 and total_frames > 0:
                frame_progress = frame_count / total_frames
                overall_progress = 0.5 + (frame_progress * 0.4)
                progress(overall_progress, desc=f"SAM2 processing frame {frame_count}/{total_frames}")
    finally:
        # BUGFIX: release reader and writer even if segmentation raises.
        cap.release()
        out.release()

    return output_path
|
|
|
|
|
def composite_matanyone_result(foreground_path, alpha_path, background_image, fps):
    """Composite MatAnyone result with new background.

    NOTE(review): placeholder implementation — `alpha_path`,
    `background_image` and `fps` are currently ignored and the MatAnyone
    foreground video is returned unchanged. TODO: blend the foreground over
    `background_image` using the alpha video and re-encode at `fps`.
    """
    return foreground_path
|
|
|
|
|
|
|
|
def create_professional_interface():
    """Create the professional Gradio interface with SAM2 + MatAnyone.

    Builds the full Blocks layout (inputs, settings, outputs, info panels)
    and wires the process button to process_video_professional().
    NOTE(review): many label/markdown strings contain emoji that appear
    mojibake-encoded here — confirm the file's encoding against the original.

    Returns:
        gr.Blocks: the assembled (unlaunched) demo.
    """

    # Dropdown options: a "custom" sentinel plus one entry per preset.
    preset_choices = [("Custom (upload image)", "custom")]
    for key, (name, _, _) in get_background_presets().items():
        preset_choices.append((name, key))

    with gr.Blocks(
        title="BackgroundFX Pro - SAM2 + MatAnyone",
        theme=gr.themes.Soft(),
        # Custom CSS: widen the page and style the gradient header/badge.
        css="""
        .gradio-container {
            max-width: 1400px !important;
        }
        .main-header {
            text-align: center;
            background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            background-clip: text;
        }
        .professional-badge {
            background: linear-gradient(45deg, #FFD700, #FFA500);
            color: black;
            padding: 8px 16px;
            border-radius: 20px;
            font-weight: bold;
            display: inline-block;
            margin: 10px 0;
        }
        """
    ) as demo:

        # Header banner.
        gr.Markdown("""
        # π¬ BackgroundFX Pro - SAM2 + MatAnyone
        **Professional AI video background replacement with state-of-the-art alpha matting**

        <div class="professional-badge">π Powered by SAM2 + MatAnyone (CVPR 2025)</div>

        Upload your video and experience Hollywood-quality background replacement with cutting-edge AI segmentation and professional alpha matting.
        """, elem_classes=["main-header"])

        with gr.Row():
            # Left column: inputs and settings.
            with gr.Column(scale=1):
                gr.Markdown("### π€ Input Configuration")

                video_input = gr.Video(
                    label="Upload Video (MP4, AVI, MOV, MKV, WebM - max 5 min)",
                    height=300
                )

                with gr.Tab("π¨ Background"):
                    background_preset = gr.Dropdown(
                        choices=preset_choices,
                        value="gradient:ocean",
                        label="Background Preset - Choose preset or upload custom image"
                    )

                    # A custom upload overrides the preset at processing time.
                    background_input = gr.Image(
                        label="Custom Background (Upload image to override preset)",
                        type="pil",
                        height=200
                    )

                with gr.Accordion("π€ SAM2 Settings", open=True):
                    model_size = gr.Radio(
                        choices=[
                            ("Tiny (38MB) - Fastest", "tiny"),
                            ("Small (185MB) - Balanced β", "small"),
                            ("Base (320MB) - Best Quality", "base")
                        ],
                        value="small",
                        label="SAM2 Model Size - Larger models = better segmentation but slower"
                    )

                    edge_smoothing = gr.Slider(
                        minimum=0,
                        maximum=5,
                        value=1.5,
                        step=0.5,
                        label="Edge Smoothing - Softens edges around subject (0=sharp, 5=very soft)"
                    )

                with gr.Accordion("π MatAnyone Professional Settings", open=True):
                    use_matanyone = gr.Checkbox(
                        value=True,
                        label="Enable MatAnyone Professional Alpha Matting - CVPR 2025 best quality but slower"
                    )

                    gr.Markdown("""
                    **Quality Comparison:**
                    - β **MatAnyone ON**: Professional hair/edge detail, cinema-quality results
                    - β‘ **MatAnyone OFF**: Fast SAM2-only processing, good for previews
                    """)

                process_btn = gr.Button(
                    "π Create Professional Video",
                    variant="primary",
                    size="lg",
                    scale=2
                )

            # Right column: outputs and tips.
            with gr.Column(scale=1):
                gr.Markdown("### π₯ Professional Output")

                video_output = gr.Video(
                    label="Processed Video",
                    height=400,
                    show_download_button=True
                )

                status_output = gr.Textbox(
                    label="Processing Status",
                    lines=3,
                    max_lines=5
                )

                gr.Markdown("""
                ### π‘ Professional Tips
                - **Best results**: Clear subject separation from background
                - **Lighting**: Even lighting eliminates edge artifacts
                - **Movement**: Steady shots for consistent quality
                - **MatAnyone**: Use for final videos, disable for quick previews
                - **Processing**: 90-180s per minute with MatAnyone ON
                """)

        # Environment summary row (GPU vs CPU mode).
        with gr.Row():
            with gr.Column():
                if CUDA_AVAILABLE:
                    gr.Markdown(f"π **GPU Acceleration:** {GPU_NAME} ({GPU_MEMORY:.1f}GB) | Type: {GPU_TYPE}")
                else:
                    gr.Markdown("π» **CPU Mode** (GPU recommended for MatAnyone)")

            with gr.Column():
                gr.Markdown("π§ **AI Models:** SAM2 + MatAnyone | π¦ **Storage:** 0MB (True lazy loading)")

        # Wire the action button to the processing pipeline.
        process_btn.click(
            fn=process_video_professional,
            inputs=[
                video_input,
                background_input,
                background_preset,
                model_size,
                edge_smoothing,
                use_matanyone
            ],
            outputs=[video_output, status_output],
            show_progress=True
        )

        # Footer: use-case copy.
        with gr.Row():
            gr.Markdown("""
            ### π¬ Professional Use Cases
            - **π― Content Creation**: Remove distracting backgrounds for professional videos
            - **πΉ Virtual Production**: Custom backgrounds for video calls and streaming
            - **π Education**: Clean, professional backgrounds for instructional content
            - **π± Social Media**: Eye-catching backgrounds that increase engagement
            - **πͺ Entertainment**: Creative backgrounds for artistic and commercial projects
            """)

    return demo
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Startup banner: log the compute configuration before building the UI.
    logger.info("Starting BackgroundFX Pro with SAM2 + MatAnyone...")
    logger.info(f"Device: {DEVICE}")
    if CUDA_AVAILABLE:
        logger.info(f"GPU: {GPU_NAME} ({GPU_MEMORY:.1f}GB)")

    demo = create_professional_interface()

    # Bounded queue limits concurrent/pending jobs; bind to all interfaces so
    # the app is reachable from inside a container.
    demo.queue(
        max_size=5
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False
    )