"""
Advanced matting algorithms for BackgroundFX Pro.
Implements multiple matting techniques with automatic fallback.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from typing import Dict, Tuple, Optional, List
from dataclasses import dataclass
import logging

from ..utils.logger import setup_logger
from ..utils.device import DeviceManager
from ..utils.config import ConfigManager
from ..core.models import ModelFactory, ModelType
from ..core.quality import QualityAnalyzer
from ..core.edge import EdgeRefinement

logger = setup_logger(__name__)


@dataclass
class MattingConfig:
    """Configuration for matting operations."""

    # Mask values (after scaling to [0, 1]) strictly above this are treated as
    # foreground when a trimap is built.
    alpha_threshold: float = 0.5
    # Iterations of the morphological CLOSE in `morphological_refinement`.
    erode_iterations: int = 2
    # Iterations of the morphological OPEN in `morphological_refinement`.
    dilate_iterations: int = 2
    # Gaussian blur kernel side is (2 * blur_radius + 1); 0 disables the blur.
    blur_radius: int = 3
    # Side length of the elliptical structuring element that carves the
    # trimap's unknown band.
    trimap_size: int = 30
    # Not referenced by the matting code shown here.
    confidence_threshold: float = 0.7
    # Whether guided filtering is applied in the trimap and guided paths.
    use_guided_filter: bool = True
    # Box-filter window size passed to `guided_filter`.
    guided_filter_radius: int = 8
    # Regulariser added to the guide variance in `guided_filter`.
    guided_filter_eps: float = 1e-6
    # Not referenced by the matting code shown here.
    use_temporal_smoothing: bool = False
    temporal_window: int = 5


class AlphaMatting:
    """Advanced alpha matting using multiple techniques."""

    def __init__(self, config: Optional[MattingConfig] = None):
        self.config = config or MattingConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        self.edge_refinement = EdgeRefinement()

    @staticmethod
    def _to_unit_float(mask: np.ndarray) -> np.ndarray:
        """
        Return ``mask`` as a float32 array scaled to [0, 1].

        Bool masks map to 0/1. Integer masks are assumed to be on a 0-255
        scale. Float masks whose maximum exceeds 1 are also treated as
        0-255; float masks already within [0, 1] are kept as-is.
        """
        mask = np.asarray(mask)
        if mask.dtype == np.bool_:
            return mask.astype(np.float32)
        out = mask.astype(np.float32)
        if np.issubdtype(mask.dtype, np.integer) or (out.size and out.max() > 1.0):
            out = out / 255.0
        return np.clip(out, 0.0, 1.0)

    def create_trimap(self, mask: np.ndarray, dilation_size: int = None) -> np.ndarray:
        """
        Create trimap from a (possibly soft) mask.

        Args:
            mask: Mask (H, W); bool, 0-255 integer, or float (see
                ``_to_unit_float``). It is binarised with
                ``config.alpha_threshold`` before the trimap is built.
            dilation_size: Size of the structuring element that defines the
                uncertain band; falls back to ``config.trimap_size`` when
                falsy.

        Returns:
            uint8 trimap with 0 (background), 128 (unknown), 255 (foreground)
        """
        dilation_size = dilation_size or self.config.trimap_size

        # Binarise so every pixel ends up in exactly one of the three classes;
        # a soft or 0-255 float mask would otherwise leak through unchanged.
        is_fg = self._to_unit_float(mask) > self.config.alpha_threshold
        binary = np.where(is_fg, 255, 0).astype(np.uint8)

        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (dilation_size, dilation_size)
        )
        dilated = cv2.dilate(binary, kernel, iterations=1)
        eroded = cv2.erode(binary, kernel, iterations=1)

        # Start from background, mark the dilated region unknown, then mark
        # the eroded core as definite foreground.
        trimap = np.zeros_like(binary)
        trimap[dilated == 255] = 128
        trimap[eroded == 255] = 255
        return trimap

    def guided_filter(self, image: np.ndarray, guide: np.ndarray,
                      radius: int = None, eps: float = None) -> np.ndarray:
        """
        Apply guided filter for edge-preserving smoothing.

        Args:
            image: Input image to filter (values on a 0-255 scale)
            guide: Guide image; 3-channel guides are converted to gray with
                COLOR_BGR2GRAY (NOTE(review): callers describe the image as
                RGB — confirm the intended channel order)
            radius: Box-filter window size (falls back to config when falsy)
            eps: Regularization parameter (falls back to config when falsy)

        Returns:
            Filtered image as uint8
        """
        radius = radius or self.config.guided_filter_radius
        eps = eps or self.config.guided_filter_eps

        if len(guide.shape) == 3:
            guide = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY)

        # Work in [0, 1] floats.
        guide = guide.astype(np.float32) / 255.0
        image = image.astype(np.float32) / 255.0

        # NOTE: the box window is radius x radius pixels, not the
        # (2 * radius + 1) window of the usual formulation; kept as-is so
        # existing outputs do not change.
        def box_filter(img, r):
            return cv2.boxFilter(img, -1, (r, r))

        mean_I = box_filter(guide, radius)
        mean_p = box_filter(image, radius)
        mean_Ip = box_filter(guide * image, radius)
        cov_Ip = mean_Ip - mean_I * mean_p

        mean_II = box_filter(guide * guide, radius)
        var_I = mean_II - mean_I * mean_I

        # Per-window linear model p ~= a * I + b.
        a = cov_Ip / (var_I + eps)
        b = mean_p - a * mean_I

        mean_a = box_filter(a, radius)
        mean_b = box_filter(b, radius)

        output = mean_a * guide + mean_b
        return np.clip(output * 255, 0, 255).astype(np.uint8)

    def closed_form_matting(self, image: np.ndarray, trimap: np.ndarray) -> np.ndarray:
        """
        Approximate matting by distance-based propagation into the unknown band.

        Despite the name this does not build a matting Laplacian: each unknown
        pixel gets ``d_bg / (d_fg + d_bg)``, where ``d_fg`` / ``d_bg`` are its
        Euclidean distances to the nearest known foreground / background
        pixel, so pixels nearer the foreground get higher alpha. The result
        is optionally smoothed with ``guided_filter``.

        Args:
            image: Image used as the guided-filter guide
            trimap: Trimap with 0 (bg), 128 (unknown), 255 (fg)

        Returns:
            Alpha matte in [0, 1]
        """
        # Known pixels start at 0 / 1; unknown pixels start at 128/255.
        alpha = trimap.astype(np.float32) / 255.0

        is_fg = trimap == 255
        is_bg = trimap == 0
        is_unknown = trimap == 128

        if not np.any(is_unknown):
            return alpha

        has_fg = bool(np.any(is_fg))
        has_bg = bool(np.any(is_bg))
        if has_fg and has_bg:
            # distanceTransform measures each non-zero pixel's distance to the
            # nearest zero pixel, so zero out the region we measure *to*.
            dist_to_fg = cv2.distanceTransform(
                (~is_fg).astype(np.uint8), cv2.DIST_L2, 5
            )
            dist_to_bg = cv2.distanceTransform(
                (~is_bg).astype(np.uint8), cv2.DIST_L2, 5
            )
            alpha_unknown = dist_to_bg / (dist_to_fg + dist_to_bg + 1e-10)
            alpha[is_unknown] = alpha_unknown[is_unknown]
        elif has_fg:
            # Only foreground is known: the unknown band can only be foreground.
            alpha[is_unknown] = 1.0
        elif has_bg:
            alpha[is_unknown] = 0.0
        # else: nothing is known; keep the neutral 128/255 prior.

        if self.config.use_guided_filter:
            alpha = self.guided_filter(
                (np.clip(alpha, 0, 1) * 255).astype(np.uint8), image
            ) / 255.0

        return np.clip(alpha, 0, 1)

    def deep_matting(self, image: np.ndarray, mask: np.ndarray,
                     model: Optional[nn.Module] = None) -> np.ndarray:
        """
        Apply deep learning-based matting refinement.

        Both inputs are resized to 512x512, run through ``model`` (called as
        ``model(image_tensor, mask_tensor)``) or, when no model is given,
        through ``_simple_refine_network``; the result is resized back.

        Args:
            image: 3-channel uint8 image (H, W, 3)
            mask: Initial mask (see ``_to_unit_float`` for accepted scales)
            model: Optional pre-trained model

        Returns:
            Refined alpha matte (H, W) in [0, 1]
        """
        device = self.device_manager.get_device()

        h, w = image.shape[:2]
        input_size = (512, 512)
        image_resized = cv2.resize(image, input_size)
        # Normalising first also makes bool masks resizable.
        mask_resized = cv2.resize(self._to_unit_float(mask), input_size)

        # (1, 3, 512, 512) image and (1, 1, 512, 512) mask, both in [0, 1].
        image_tensor = torch.from_numpy(
            image_resized.transpose(2, 0, 1)
        ).float().unsqueeze(0) / 255.0
        mask_tensor = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0)

        image_tensor = image_tensor.to(device)
        mask_tensor = mask_tensor.to(device)

        with torch.no_grad():
            if model is None:
                x = torch.cat([image_tensor, mask_tensor], dim=1)
                refined = self._simple_refine_network(x)
            else:
                refined = model(image_tensor, mask_tensor)

        # float32 so cv2.resize accepts the array regardless of model dtype.
        alpha = refined.squeeze().float().cpu().numpy()
        alpha = cv2.resize(alpha, (w, h))
        return np.clip(alpha, 0, 1)

    def _simple_refine_network(self, x: torch.Tensor) -> torch.Tensor:
        """Simple refinement network for demonstration.

        Takes channel 3 of ``x`` (the mask) and applies three rounds of 3x3
        average pooling followed by a steep sigmoid centred at 0.5.
        """
        mask = x[:, 3:4, :, :]
        refined = mask
        for _ in range(3):
            refined = F.avg_pool2d(refined, 3, stride=1, padding=1)
            refined = torch.sigmoid((refined - 0.5) * 10)
        return refined

    def morphological_refinement(self, alpha: np.ndarray) -> np.ndarray:
        """
        Apply morphological operations for refinement.

        Runs a CLOSE (``config.erode_iterations`` times), an OPEN
        (``config.dilate_iterations`` times) with a 5x5 ellipse, then an
        optional Gaussian blur.

        Args:
            alpha: Alpha matte in [0, 1]

        Returns:
            Refined alpha matte as float32 in [0, 1]
        """
        # Clip first so out-of-range floats cannot wrap around in uint8.
        alpha_uint8 = (np.clip(alpha, 0, 1) * 255).astype(np.uint8)

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

        # Remove small holes
        alpha_uint8 = cv2.morphologyEx(
            alpha_uint8, cv2.MORPH_CLOSE, kernel,
            iterations=self.config.erode_iterations
        )
        # Remove small components
        alpha_uint8 = cv2.morphologyEx(
            alpha_uint8, cv2.MORPH_OPEN, kernel,
            iterations=self.config.dilate_iterations
        )

        if self.config.blur_radius > 0:
            ksize = self.config.blur_radius * 2 + 1
            alpha_uint8 = cv2.GaussianBlur(alpha_uint8, (ksize, ksize), 0)

        return alpha_uint8.astype(np.float32) / 255.0

    def process(self, image: np.ndarray, mask: np.ndarray,
                method: str = 'auto') -> Dict[str, np.ndarray]:
        """
        Process image with selected matting method.

        Args:
            image: Input image
            mask: Initial segmentation mask (see ``_to_unit_float``)
            method: Matting method ('auto', 'trimap', 'deep', 'guided');
                any other value uses the mask as-is before refinement

        Returns:
            Dictionary with 'alpha', 'confidence', 'method_used' and
            'quality_metrics'. On any exception the normalised mask is
            returned with confidence 0.0, 'method_used' == 'fallback' and an
            'error' string.
        """
        try:
            quality_metrics = self.quality_analyzer.analyze_frame(image)

            if method == 'auto':
                if quality_metrics['blur_score'] > 50:
                    method = 'guided'
                elif quality_metrics['edge_clarity'] > 0.7:
                    method = 'trimap'
                else:
                    method = 'deep'

            logger.info(f"Using matting method: {method}")

            if method == 'trimap':
                trimap = self.create_trimap(mask)
                alpha = self.closed_form_matting(image, trimap)
            elif method == 'deep':
                alpha = self.deep_matting(image, mask)
            elif method == 'guided':
                alpha = self._to_unit_float(mask)
                if self.config.use_guided_filter:
                    alpha = self.guided_filter(
                        (alpha * 255).astype(np.uint8), image
                    ) / 255.0
            else:
                alpha = self._to_unit_float(mask)

            alpha = self.morphological_refinement(alpha)

            alpha = self.edge_refinement.refine_edges(
                image, (alpha * 255).astype(np.uint8)
            ) / 255.0

            confidence = self._calculate_confidence(alpha, quality_metrics)

            return {
                'alpha': alpha,
                'confidence': confidence,
                'method_used': method,
                'quality_metrics': quality_metrics
            }

        except Exception as e:
            # Top-level boundary: log and degrade to the input mask.
            logger.error(f"Matting processing failed: {e}")
            return {
                'alpha': self._to_unit_float(mask),
                'confidence': 0.0,
                'method_used': 'fallback',
                'error': str(e)
            }

    def _calculate_confidence(self, alpha: np.ndarray, quality_metrics: Dict) -> float:
        """Calculate confidence score for the matting result.

        Starts from ``quality_metrics['overall_quality']`` (default 0.5),
        boosts it by 1.2x for a balanced, high-variance alpha and by 1.1x
        when fewer than 10% of pixels are Canny edges, then clips to [0, 1].
        """
        confidence = quality_metrics.get('overall_quality', 0.5)

        alpha_mean = np.mean(alpha)
        alpha_std = np.std(alpha)

        # Good matting should have clear separation
        if 0.3 < alpha_mean < 0.7 and alpha_std > 0.3:
            confidence *= 1.2

        edges = cv2.Canny((np.clip(alpha, 0, 1) * 255).astype(np.uint8), 50, 150)
        edge_ratio = np.sum(edges > 0) / edges.size
        if edge_ratio < 0.1:  # Clear boundaries
            confidence *= 1.1

        return float(np.clip(confidence, 0.0, 1.0))


class CompositingEngine:
    """Handle alpha compositing and blending."""

    def __init__(self):
        self.logger = setup_logger(f"{__name__}.CompositingEngine")

    def composite(self, foreground: np.ndarray, background: np.ndarray,
                  alpha: np.ndarray) -> np.ndarray:
        """
        Composite foreground over background using alpha.

        Args:
            foreground: Foreground image (H, W) or (H, W, C), 0-255 scale
            background: Background image with the same shape as foreground
            alpha: Alpha matte (H, W), (H, W, 1) or (H, W, C); values above
                1 are taken to be on a 0-255 scale

        Returns:
            Composited uint8 image
        """
        fg = foreground.astype(np.float32) / 255.0
        bg = background.astype(np.float32) / 255.0
        a = alpha.astype(np.float32)

        # Guard the empty case: max() of an empty array raises.
        if a.size and a.max() > 1.0:
            a = a / 255.0

        # Add a channel axis only for multi-channel images; broadcasting then
        # applies the same alpha to every channel.
        if a.ndim == 2 and fg.ndim == 3:
            a = a[..., np.newaxis]

        result = fg * a + bg * (1 - a)
        return np.clip(result * 255, 0, 255).astype(np.uint8)

    def premultiply_alpha(self, image: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """Premultiply image by alpha channel.

        ``alpha`` values above 1 are taken to be on a 0-255 scale. A 2-D
        alpha is broadcast across channels only when ``image`` is 3-D, so
        grayscale images are multiplied element-wise.
        """
        a = alpha.astype(np.float32)
        if a.ndim == 2 and image.ndim == 3:
            a = a[..., np.newaxis]

        result = image.astype(np.float32) * a
        if a.size and a.max() > 1.0:
            result = result / 255.0

        return np.clip(result, 0, 255).astype(np.uint8)