""" Model management and optimization for BackgroundFX Pro. Fixes MatAnyone quality issues and manages model loading. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, Any, Optional, Tuple, List from dataclasses import dataclass import numpy as np from pathlib import Path import logging import gc from functools import lru_cache import warnings logger = logging.getLogger(__name__) @dataclass class ModelConfig: """Configuration for model management.""" sam2_checkpoint: str = "checkpoints/sam2_hiera_large.pt" matanyone_checkpoint: str = "checkpoints/matanyone_v2.pth" device: str = "cuda" dtype: torch.dtype = torch.float16 optimize_memory: bool = True use_amp: bool = True cache_size: int = 5 enable_quality_fixes: bool = True matanyone_enhancement: bool = True use_tensorrt: bool = False batch_size: int = 1 class ModelCache: """Intelligent model caching system.""" def __init__(self, max_size: int = 5): self.cache = {} self.max_size = max_size self.access_count = {} self.memory_usage = {} def add(self, key: str, model: Any, memory_size: float): """Add model to cache with memory tracking.""" if len(self.cache) >= self.max_size: # Remove least recently used lru_key = min(self.access_count, key=self.access_count.get) self.remove(lru_key) self.cache[key] = model self.access_count[key] = 0 self.memory_usage[key] = memory_size def get(self, key: str) -> Optional[Any]: """Get model from cache.""" if key in self.cache: self.access_count[key] += 1 return self.cache[key] return None def remove(self, key: str): """Remove model from cache and free memory.""" if key in self.cache: model = self.cache[key] del self.cache[key] del self.access_count[key] del self.memory_usage[key] # Force cleanup del model gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def clear(self): """Clear entire cache.""" keys = list(self.cache.keys()) for key in keys: self.remove(key) class MatAnyoneModel(nn.Module): """Enhanced MatAnyone model with quality fixes.""" def __init__(self, config: ModelConfig): super().__init__() self.config = config self.base_model = None self.quality_enhancer = QualityEnhancer() if config.enable_quality_fixes else None self.loaded = False def load(self): """Load MatAnyone model with optimizations.""" if self.loaded: return try: # Load checkpoint checkpoint_path = Path(self.config.matanyone_checkpoint) if not checkpoint_path.exists(): logger.warning(f"MatAnyone checkpoint not found at {checkpoint_path}") return # Load model weights state_dict = torch.load( checkpoint_path, map_location=self.config.device ) # Initialize base model (placeholder - replace with actual MatAnyone architecture) self.base_model = self._build_matanyone_architecture() # Load weights with compatibility fixes self._load_weights_safe(state_dict) # Optimize model if self.config.optimize_memory: self._optimize_model() self.loaded = True logger.info("MatAnyone model loaded successfully") except Exception as e: logger.error(f"Failed to load MatAnyone model: {e}") self.loaded = False def _build_matanyone_architecture(self) -> nn.Module: """Build MatAnyone architecture.""" # This is a placeholder - replace with actual MatAnyone architecture class MatAnyoneBase(nn.Module): def __init__(self): super().__init__() self.encoder = nn.Sequential( nn.Conv2d(4, 64, 3, padding=1), nn.ReLU(), nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(), ) self.decoder = nn.Sequential( nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(), nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(), nn.Conv2d(64, 4, 3, padding=1), nn.Sigmoid() ) def forward(self, x): features = self.encoder(x) output = self.decoder(features) return output return MatAnyoneBase().to(self.config.device) def _load_weights_safe(self, state_dict: Dict): """Safely load weights with compatibility handling.""" model_dict = self.base_model.state_dict() # Filter compatible weights compatible_dict = {} for k, v in state_dict.items(): # Remove module prefix if present if k.startswith('module.'): k = k[7:] if k in model_dict and model_dict[k].shape == v.shape: compatible_dict[k] = v else: logger.warning(f"Skipping incompatible weight: {k}") # Load compatible weights model_dict.update(compatible_dict) self.base_model.load_state_dict(model_dict, strict=False) logger.info(f"Loaded {len(compatible_dict)}/{len(state_dict)} weights") def _optimize_model(self): """Optimize model for inference.""" if not self.base_model: return self.base_model.eval() # Convert to half precision if using GPU if self.config.dtype == torch.float16 and self.config.device != "cpu": self.base_model = self.base_model.half() # Disable gradient computation for param in self.base_model.parameters(): param.requires_grad = False # TensorRT optimization (if available) if self.config.use_tensorrt: try: self._optimize_with_tensorrt() except Exception as e: logger.warning(f"TensorRT optimization failed: {e}") def forward(self, image: torch.Tensor, mask: torch.Tensor) -> Dict[str, torch.Tensor]: """Enhanced forward pass with quality fixes.""" if not self.loaded: self.load() if not self.base_model: return {'alpha': mask, 'foreground': image} # Prepare input x = torch.cat([image, mask.unsqueeze(1)], dim=1) # Fix input quality issues if self.config.matanyone_enhancement: x = self._preprocess_input(x) # Forward pass with mixed precision with torch.cuda.amp.autocast(enabled=self.config.use_amp): output = self.base_model(x) # Parse output alpha = output[:, 3:4, :, :] foreground = output[:, :3, :, :] # Apply quality enhancement if self.quality_enhancer: alpha = self.quality_enhancer.enhance_alpha(alpha, mask) foreground = self.quality_enhancer.enhance_foreground(foreground, image) # Post-process to fix common MatAnyone issues alpha = self._fix_matanyone_artifacts(alpha, mask) return { 'alpha': alpha, 'foreground': foreground, 'confidence': self._compute_confidence(alpha, mask) } def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor: """Preprocess input to improve MatAnyone quality.""" # Denoise input if x.shape[2] > 64: # Only for reasonable resolutions x = self._bilateral_filter_torch(x) # Normalize properly x = torch.clamp(x, 0, 1) # Enhance edges in mask channel mask_channel = x[:, 3:4, :, :] mask_enhanced = self._enhance_mask_edges(mask_channel) x = torch.cat([x[:, :3, :, :], mask_enhanced], dim=1) return x def _fix_matanyone_artifacts(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor: """Fix common MatAnyone artifacts.""" # Fix edge bleeding alpha = self._fix_edge_bleeding(alpha, original_mask) # Fix transparency issues alpha = self._fix_transparency_issues(alpha) # Ensure consistency with original mask alpha = self._ensure_mask_consistency(alpha, original_mask) return alpha def _fix_edge_bleeding(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor: """Fix edge bleeding artifacts.""" # Detect edges edges = self._detect_edges_torch(original_mask) # Create edge mask edge_mask = F.max_pool2d(edges, kernel_size=5, stride=1, padding=2) # Refine alpha near edges alpha_refined = alpha.clone() edge_region = edge_mask > 0.1 # Apply guided filter near edges if edge_region.any(): alpha_refined[edge_region] = ( 0.7 * alpha[edge_region] + 0.3 * original_mask.unsqueeze(1).expand_as(alpha)[edge_region] ) return alpha_refined def _fix_transparency_issues(self, alpha: torch.Tensor) -> torch.Tensor: """Fix transparency artifacts.""" # Identify problematic transparency values mid_range = (alpha > 0.2) & (alpha < 0.8) # Push mid-range values toward 0 or 1 alpha_fixed = alpha.clone() alpha_fixed[mid_range] = torch.where( alpha[mid_range] > 0.5, torch.clamp(alpha[mid_range] * 1.2, max=1.0), torch.clamp(alpha[mid_range] * 0.8, min=0.0) ) # Smooth transitions alpha_fixed = F.gaussian_blur(alpha_fixed, kernel_size=(3, 3)) return alpha_fixed def _ensure_mask_consistency(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor: """Ensure consistency with original mask.""" # Expand mask dimensions if needed if original_mask.dim() == 2: original_mask = original_mask.unsqueeze(0).unsqueeze(0) elif original_mask.dim() == 3: original_mask = original_mask.unsqueeze(1) # Where original mask is 0, alpha should also be 0 alpha = torch.where(original_mask < 0.1, torch.zeros_like(alpha), alpha) # Where original mask is 1, alpha should be close to 1 alpha = torch.where(original_mask > 0.9, torch.ones_like(alpha) * 0.95, alpha) return alpha def _compute_confidence(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor: """Compute confidence score for the output.""" # Expand dimensions if needed if original_mask.dim() < alpha.dim(): original_mask = original_mask.unsqueeze(1).expand_as(alpha) # Compute similarity diff = torch.abs(alpha - original_mask) confidence = 1.0 - torch.mean(diff, dim=(1, 2, 3)) return confidence def _bilateral_filter_torch(self, x: torch.Tensor) -> torch.Tensor: """Apply bilateral filter in PyTorch.""" # Simple approximation using Gaussian blur # For true bilateral filtering, would need custom CUDA kernel return F.gaussian_blur(x, kernel_size=(5, 5)) def _enhance_mask_edges(self, mask: torch.Tensor) -> torch.Tensor: """Enhance edges in mask channel.""" # Detect edges edges = self._detect_edges_torch(mask) # Enhance mask with edges enhanced = mask + 0.3 * edges enhanced = torch.clamp(enhanced, 0, 1) return enhanced def _detect_edges_torch(self, x: torch.Tensor) -> torch.Tensor: """Detect edges using Sobel filters.""" # Sobel kernels sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=x.dtype, device=x.device).view(1, 1, 3, 3) sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=x.dtype, device=x.device).view(1, 1, 3, 3) # Apply Sobel filters edges_x = F.conv2d(x, sobel_x, padding=1) edges_y = F.conv2d(x, sobel_y, padding=1) # Compute edge magnitude edges = torch.sqrt(edges_x ** 2 + edges_y ** 2) return edges class SAM2Model: """SAM2 model wrapper with optimizations.""" def __init__(self, config: ModelConfig): self.config = config self.model = None self.predictor = None self.loaded = False def load(self): """Load SAM2 model.""" if self.loaded: return try: # Import SAM2 (assuming it's installed) from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor # Build model self.model = build_sam2( config_file="sam2_hiera_l.yaml", ckpt_path=self.config.sam2_checkpoint, device=self.config.device ) # Create predictor self.predictor = SAM2ImagePredictor(self.model) self.loaded = True logger.info("SAM2 model loaded successfully") except Exception as e: logger.error(f"Failed to load SAM2 model: {e}") self.loaded = False def predict(self, image: np.ndarray, prompts: Optional[Dict] = None) -> np.ndarray: """Generate segmentation mask.""" if not self.loaded: self.load() if not self.predictor: return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) # Set image self.predictor.set_image(image) # Use prompts if provided, otherwise use automatic segmentation if prompts: masks, scores, _ = self.predictor.predict( point_coords=prompts.get('points'), point_labels=prompts.get('labels'), box=prompts.get('box'), multimask_output=True ) # Select best mask mask = masks[np.argmax(scores)] else: # Automatic segmentation masks = self.predictor.generate_auto_masks(image) mask = masks[0] if len(masks) > 0 else np.zeros_like(image[:, :, 0]) return mask class QualityEnhancer(nn.Module): """Neural quality enhancement module.""" def __init__(self): super().__init__() self.alpha_refiner = nn.Sequential( nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 1, 3, padding=1), nn.Sigmoid() ) self.foreground_enhancer = nn.Sequential( nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(), nn.Conv2d(32, 3, 3, padding=1), nn.Tanh() ) def enhance_alpha(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor: """Enhance alpha channel quality.""" # Refine with neural network refined = self.alpha_refiner(alpha) # Blend with original for stability enhanced = 0.7 * refined + 0.3 * alpha return torch.clamp(enhanced, 0, 1) def enhance_foreground(self, foreground: torch.Tensor, original_image: torch.Tensor) -> torch.Tensor: """Enhance foreground quality.""" # Compute residual residual = self.foreground_enhancer(foreground) # Add residual enhanced = foreground + 0.1 * residual return torch.clamp(enhanced, 0, 1) class ModelManager: """Central model management system.""" def __init__(self, config: Optional[ModelConfig] = None): self.config = config or ModelConfig() self.cache = ModelCache(max_size=self.config.cache_size) self.models = {} # Initialize models self.sam2 = SAM2Model(self.config) self.matanyone = MatAnyoneModel(self.config) def load_all(self): """Load all models.""" logger.info("Loading all models...") self.sam2.load() self.matanyone.load() logger.info("All models loaded") def get_sam2(self) -> SAM2Model: """Get SAM2 model.""" if not self.sam2.loaded: self.sam2.load() return self.sam2 def get_matanyone(self) -> MatAnyoneModel: """Get MatAnyone model.""" if not self.matanyone.loaded: self.matanyone.load() return self.matanyone def process_frame(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Dict[str, Any]: """Process single frame through pipeline.""" # Convert to tensor image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0 image_tensor = image_tensor.to(self.config.device) # Get or generate mask if mask is None: mask = self.sam2.predict(image) mask_tensor = torch.from_numpy(mask).float().to(self.config.device) # Process with MatAnyone result = self.matanyone(image_tensor, mask_tensor) # Convert back to numpy output = { 'alpha': result['alpha'].squeeze().cpu().numpy(), 'foreground': result['foreground'].squeeze().permute(1, 2, 0).cpu().numpy() * 255, 'confidence': result['confidence'].cpu().numpy() } return output def cleanup(self): """Cleanup models and free memory.""" self.cache.clear() gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # Export classes __all__ = [ 'ModelManager', 'SAM2Model', 'MatAnyoneModel', 'ModelConfig', 'ModelCache', 'QualityEnhancer' ]