| | """ |
| | Model management and optimization for BackgroundFX Pro. |
| | Fixes MatAnyone quality issues and manages model loading. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Dict, Any, Optional, Tuple, List |
| | from dataclasses import dataclass |
| | import numpy as np |
| | from pathlib import Path |
| | import logging |
| | import gc |
| | from functools import lru_cache |
| | import warnings |
| |
|
# Module-wide logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
| |
|
| |
|
@dataclass
class ModelConfig:
    """Configuration for model management.

    If a CUDA device is requested but CUDA is not available, ``device`` is
    switched to ``"cpu"`` (with a ``RuntimeWarning``) so the loaders and
    ``ModelManager.process_frame`` do not fail on ``.to("cuda")`` /
    ``map_location="cuda"``.
    """
    sam2_checkpoint: str = "checkpoints/sam2_hiera_large.pt"
    matanyone_checkpoint: str = "checkpoints/matanyone_v2.pth"
    device: str = "cuda"
    # Only applied to the MatAnyone network, and only on non-CPU devices.
    dtype: torch.dtype = torch.float16
    optimize_memory: bool = True
    use_amp: bool = True
    cache_size: int = 5
    enable_quality_fixes: bool = True
    matanyone_enhancement: bool = True
    use_tensorrt: bool = False
    batch_size: int = 1

    def __post_init__(self):
        """Fall back to CPU when a CUDA device is configured but unavailable."""
        if str(self.device).startswith("cuda") and not torch.cuda.is_available():
            warnings.warn(
                f"CUDA device {self.device!r} requested but CUDA is not available; "
                "falling back to 'cpu'.",
                RuntimeWarning,
                stacklevel=3,
            )
            self.device = "cpu"
| |
|
| |
|
class ModelCache:
    """Bounded in-memory model cache.

    Eviction picks the entry with the lowest access count (least-frequently
    used); among equal counts the entry inserted first goes first.
    ``memory_usage`` stores the size reported by the caller for each entry
    but does not influence eviction.
    """

    def __init__(self, max_size: int = 5):
        self.cache = {}
        self.max_size = max_size
        # key -> number of successful get() calls since the entry was added.
        self.access_count = {}
        # key -> caller-supplied memory size (unit is up to the caller).
        self.memory_usage = {}

    def add(self, key: str, model: Any, memory_size: float):
        """Add *model* under *key*, evicting least-accessed entries if full.

        Re-adding an existing key replaces that entry (resetting its access
        count) without evicting any other entry. A non-positive ``max_size``
        disables caching: nothing is stored.
        """
        if self.max_size <= 0:
            return

        # Replacing an entry must not push out an unrelated one.
        if key in self.cache:
            self.remove(key)

        # `while` also copes with max_size having been lowered after creation.
        while self.cache and len(self.cache) >= self.max_size:
            lfu_key = min(self.access_count, key=self.access_count.get)
            self.remove(lfu_key)

        self.cache[key] = model
        self.access_count[key] = 0
        self.memory_usage[key] = memory_size

    def get(self, key: str) -> Optional[Any]:
        """Return the cached model for *key* (counting the access), or None."""
        if key in self.cache:
            self.access_count[key] += 1
            return self.cache[key]
        return None

    def remove(self, key: str):
        """Drop *key* from the cache and run garbage collection / CUDA cache release."""
        if key in self.cache:
            model = self.cache.pop(key)
            del self.access_count[key]
            del self.memory_usage[key]

            # Only this local reference is dropped; callers still holding the
            # model keep it alive.
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def clear(self):
        """Remove every entry from the cache."""
        for key in list(self.cache):
            self.remove(key)
| |
|
| |
|
class MatAnyoneModel(nn.Module):
    """Enhanced MatAnyone model with quality fixes.

    Wraps a small encoder/decoder network whose 4-channel output is split
    into foreground (channels 0-2) and alpha (channel 3). Alpha is then
    passed through an optional ``QualityEnhancer`` and a chain of heuristic
    artifact fixes. When no network could be loaded, ``forward`` returns the
    input mask as alpha and the input image as foreground.
    """

    def __init__(self, config: "ModelConfig"):
        super().__init__()
        self.config = config
        # Set by load(); reset to None whenever loading fails.
        self.base_model = None
        # NOTE(review): nothing in this class loads trained weights into the
        # enhancer, so it runs with its random initialisation — confirm intended.
        self.quality_enhancer = QualityEnhancer() if config.enable_quality_fixes else None
        self.loaded = False

    def load(self):
        """Load the MatAnyone checkpoint and prepare the network for inference.

        Failures are logged, not raised: ``base_model`` is reset to ``None``
        and ``loaded`` stays ``False`` so ``forward`` uses its pass-through
        fallback instead of a partially initialised network.
        """
        if self.loaded:
            return

        try:
            checkpoint_path = Path(self.config.matanyone_checkpoint)
            if not checkpoint_path.exists():
                logger.warning(f"MatAnyone checkpoint not found at {checkpoint_path}")
                return

            # SECURITY: torch.load unpickles the file — only load trusted checkpoints.
            checkpoint = torch.load(
                checkpoint_path,
                map_location=self.config.device
            )
            state_dict = self._extract_state_dict(checkpoint)

            self.base_model = self._build_matanyone_architecture()
            self._load_weights_safe(state_dict)

            # Always freeze: otherwise outputs require grad and .numpy() fails.
            self._prepare_for_inference()
            if self.config.optimize_memory:
                self._optimize_model()

            self.loaded = True
            logger.info("MatAnyone model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load MatAnyone model: {e}")
            # nn.Module instances are truthy, so a half-built network left here
            # would be used by forward() with random weights.
            self.base_model = None
            self.loaded = False

    @staticmethod
    def _extract_state_dict(checkpoint: Any) -> Dict:
        """Return the parameter mapping from a raw checkpoint object.

        A dict that nests its weights under ``'state_dict'`` or ``'model'``
        is unwrapped; anything else is returned unchanged.
        """
        if isinstance(checkpoint, dict):
            for wrapper_key in ("state_dict", "model"):
                nested = checkpoint.get(wrapper_key)
                if isinstance(nested, dict):
                    return nested
        return checkpoint

    def _build_matanyone_architecture(self) -> nn.Module:
        """Build the encoder/decoder network on ``config.device``.

        4 input channels -> two stride-2 convs -> two stride-2 transposed
        convs -> 4 sigmoid output channels.
        """

        class MatAnyoneBase(nn.Module):
            def __init__(self):
                super().__init__()
                self.encoder = nn.Sequential(
                    nn.Conv2d(4, 64, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(64, 128, 3, stride=2, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(128, 256, 3, stride=2, padding=1),
                    nn.ReLU(),
                )
                self.decoder = nn.Sequential(
                    nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
                    nn.ReLU(),
                    nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(64, 4, 3, padding=1),
                    nn.Sigmoid()
                )

            def forward(self, x):
                features = self.encoder(x)
                output = self.decoder(features)
                return output

        return MatAnyoneBase().to(self.config.device)

    def _load_weights_safe(self, state_dict: Dict):
        """Load only the weights whose name and shape match the network.

        A leading ``'module.'`` prefix (as written by DataParallel wrappers)
        is stripped. Mismatching entries are logged and skipped.
        """
        model_dict = self.base_model.state_dict()

        compatible_dict = {}
        for k, v in state_dict.items():
            if k.startswith('module.'):
                k = k[7:]

            if k in model_dict and model_dict[k].shape == v.shape:
                compatible_dict[k] = v
            else:
                logger.warning(f"Skipping incompatible weight: {k}")

        model_dict.update(compatible_dict)
        self.base_model.load_state_dict(model_dict, strict=False)

        logger.info(f"Loaded {len(compatible_dict)}/{len(state_dict)} weights")

    def _prepare_for_inference(self):
        """Move the networks to ``config.device``, set eval mode, freeze weights."""
        for module in (self.base_model, self.quality_enhancer):
            if module is None:
                continue
            module.to(self.config.device)
            module.eval()
            for param in module.parameters():
                param.requires_grad_(False)

    def _optimize_model(self):
        """Optimize the base network for inference.

        Sets eval mode, converts to fp16 when configured (never on CPU),
        freezes weights and, if enabled, attempts TensorRT (best-effort).
        """
        if self.base_model is None:
            return

        self.base_model.eval()

        if self.config.dtype == torch.float16 and self.config.device != "cpu":
            self.base_model = self.base_model.half()

        for param in self.base_model.parameters():
            param.requires_grad = False

        if self.config.use_tensorrt:
            try:
                self._optimize_with_tensorrt()
            except Exception as e:
                logger.warning(f"TensorRT optimization failed: {e}")

    def _optimize_with_tensorrt(self):
        """Placeholder: TensorRT conversion is not implemented here."""
        raise NotImplementedError("TensorRT optimization is not implemented")

    def forward(self, image: torch.Tensor, mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Predict alpha and foreground for *image* guided by *mask*.

        Args:
            image: (B, 3, H, W) tensor (presumably in [0, 1] — the manager
                divides uint8 frames by 255).
            mask: (H, W), (B, H, W) or (B, 1, H, W) guidance mask.

        Returns:
            Dict with ``'alpha'`` (B, 1, H, W), ``'foreground'`` (B, 3, H, W)
            and ``'confidence'`` (B,). Without a loaded network the mask is
            returned as alpha, the image as foreground and confidence is 1.
        """
        if not self.loaded:
            self.load()

        mask = self._normalize_mask(mask, image)

        if self.base_model is None:
            confidence = torch.ones(mask.shape[0], dtype=mask.dtype, device=mask.device)
            return {'alpha': mask, 'foreground': image, 'confidence': confidence}

        x = torch.cat([image, mask], dim=1)

        if self.config.matanyone_enhancement:
            x = self._preprocess_input(x)

        # Match the network's precision (it may have been converted to fp16).
        param = next(self.base_model.parameters(), None)
        if param is not None:
            x = x.to(dtype=param.dtype)

        if self.config.use_amp and x.is_cuda:
            with torch.autocast(device_type="cuda"):
                output = self.base_model(x)
        else:
            output = self.base_model(x)

        # Post-process in fp32 so fp16 network output meets fp32 enhancer weights.
        output = output.float()
        # The stride-2 encoder/decoder only round-trips sizes divisible by 4.
        if output.shape[-2:] != image.shape[-2:]:
            output = F.interpolate(output, size=image.shape[-2:],
                                   mode="bilinear", align_corners=False)

        alpha = output[:, 3:4, :, :]
        foreground = output[:, :3, :, :]

        if self.quality_enhancer is not None:
            alpha = self.quality_enhancer.enhance_alpha(alpha, mask)
            foreground = self.quality_enhancer.enhance_foreground(foreground, image)

        alpha = self._fix_matanyone_artifacts(alpha, mask)

        return {
            'alpha': alpha,
            'foreground': foreground,
            'confidence': self._compute_confidence(alpha, mask)
        }

    @staticmethod
    def _normalize_mask(mask: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Return *mask* as (B, 1, H, W) with *like*'s dtype, device and batch size.

        2-D masks gain batch and channel axes, 3-D masks gain a channel axis,
        4-D masks are kept. A batch of 1 is broadcast to *like*'s batch.
        """
        if mask.dim() == 2:
            mask = mask[None, None]
        elif mask.dim() == 3:
            mask = mask.unsqueeze(1)
        mask = mask.to(device=like.device, dtype=like.dtype)
        if mask.shape[0] == 1 and like.shape[0] > 1:
            mask = mask.expand(like.shape[0], -1, -1, -1)
        return mask

    @staticmethod
    def _gaussian_blur(x: torch.Tensor, kernel_size: int = 3,
                       sigma: Optional[float] = None) -> torch.Tensor:
        """Per-channel Gaussian blur of a (B, C, H, W) tensor; output keeps the shape.

        *kernel_size* should be odd. When *sigma* is None it follows the same
        default rule torchvision uses: ``0.3 * ((k - 1) * 0.5 - 1) + 0.8``.
        Borders are replicate-padded so tiny inputs are handled.
        """
        if sigma is None:
            sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8
        half = kernel_size // 2
        coords = torch.arange(kernel_size, dtype=x.dtype, device=x.device) - half
        kernel_1d = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        kernel_1d = kernel_1d / kernel_1d.sum()

        channels = x.shape[1]
        # Depthwise (groups=channels) separable convolution: horizontal then vertical.
        kernel_x = kernel_1d.view(1, 1, 1, kernel_size).repeat(channels, 1, 1, 1)
        kernel_y = kernel_1d.view(1, 1, kernel_size, 1).repeat(channels, 1, 1, 1)

        padded = F.pad(x, (half, half, half, half), mode="replicate")
        blurred = F.conv2d(padded, kernel_x, groups=channels)
        return F.conv2d(blurred, kernel_y, groups=channels)

    def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """Smooth (when taller than 64 px), clamp to [0, 1] and boost mask edges."""
        if x.shape[2] > 64:
            x = self._bilateral_filter_torch(x)

        x = torch.clamp(x, 0, 1)

        mask_channel = x[:, 3:4, :, :]
        mask_enhanced = self._enhance_mask_edges(mask_channel)
        x = torch.cat([x[:, :3, :, :], mask_enhanced], dim=1)

        return x

    def _fix_matanyone_artifacts(self, alpha: torch.Tensor,
                                 original_mask: torch.Tensor) -> torch.Tensor:
        """Apply edge-bleeding, transparency and mask-consistency fixes in order."""
        alpha = self._fix_edge_bleeding(alpha, original_mask)
        alpha = self._fix_transparency_issues(alpha)
        alpha = self._ensure_mask_consistency(alpha, original_mask)
        return alpha

    def _fix_edge_bleeding(self, alpha: torch.Tensor,
                           original_mask: torch.Tensor) -> torch.Tensor:
        """Near mask boundaries, pull alpha 30% toward the input mask."""
        mask = self._normalize_mask(original_mask, alpha)
        edges = self._detect_edges_torch(mask)

        # Dilate the edge map by 2 px so the blend covers a band around the edge.
        edge_mask = F.max_pool2d(edges, kernel_size=5, stride=1, padding=2)
        edge_region = edge_mask > 0.1

        blended = 0.7 * alpha + 0.3 * mask
        return torch.where(edge_region, blended, alpha)

    def _fix_transparency_issues(self, alpha: torch.Tensor) -> torch.Tensor:
        """Push mid-range alpha (0.2-0.8) away from 0.5, then lightly blur."""
        mid_range = (alpha > 0.2) & (alpha < 0.8)

        pushed = torch.where(
            alpha > 0.5,
            torch.clamp(alpha * 1.2, max=1.0),
            torch.clamp(alpha * 0.8, min=0.0)
        )
        alpha_fixed = torch.where(mid_range, pushed, alpha)

        return self._gaussian_blur(alpha_fixed, kernel_size=3)

    def _ensure_mask_consistency(self, alpha: torch.Tensor,
                                 original_mask: torch.Tensor) -> torch.Tensor:
        """Zero alpha where the mask is < 0.1; set it to 0.95 where mask > 0.9."""
        mask = self._normalize_mask(original_mask, alpha)

        alpha = torch.where(mask < 0.1, torch.zeros_like(alpha), alpha)

        # NOTE(review): confident foreground is pinned to exactly 0.95 rather
        # than clamped to >= 0.95, leaving the subject slightly transparent.
        # Confirm whether torch.clamp(alpha, min=0.95) was intended.
        alpha = torch.where(mask > 0.9, torch.ones_like(alpha) * 0.95, alpha)

        return alpha

    def _compute_confidence(self, alpha: torch.Tensor,
                            original_mask: torch.Tensor) -> torch.Tensor:
        """Return per-sample confidence ``1 - mean|alpha - mask|`` with shape (B,)."""
        mask = self._normalize_mask(original_mask, alpha)
        diff = torch.abs(alpha - mask)
        return 1.0 - torch.mean(diff, dim=(1, 2, 3))

    def _bilateral_filter_torch(self, x: torch.Tensor) -> torch.Tensor:
        """Edge-unaware stand-in for a bilateral filter: a 5x5 Gaussian blur."""
        return self._gaussian_blur(x, kernel_size=5)

    def _enhance_mask_edges(self, mask: torch.Tensor) -> torch.Tensor:
        """Add 0.3 x Sobel edge magnitude to the mask and clamp to [0, 1]."""
        edges = self._detect_edges_torch(mask)
        enhanced = mask + 0.3 * edges
        return torch.clamp(enhanced, 0, 1)

    def _detect_edges_torch(self, x: torch.Tensor) -> torch.Tensor:
        """Sobel gradient magnitude of a single-channel (B, 1, H, W) tensor."""
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                               dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                               dtype=x.dtype, device=x.device).view(1, 1, 3, 3)

        edges_x = F.conv2d(x, sobel_x, padding=1)
        edges_y = F.conv2d(x, sobel_y, padding=1)

        return torch.sqrt(edges_x ** 2 + edges_y ** 2)
| |
|
| |
|
class SAM2Model:
    """SAM2 model wrapper with lazy loading and a safe empty-mask fallback."""

    def __init__(self, config: "ModelConfig"):
        self.config = config
        self.model = None
        self.predictor = None
        # Built on first prompt-less predict(); see _generate_auto_masks.
        self.mask_generator = None
        self.loaded = False

    def load(self):
        """Build the SAM2 network and image predictor from ``config``.

        Failures are logged, not raised; partial state is cleared so
        ``predict`` falls back to an empty mask.
        """
        if self.loaded:
            return

        try:
            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor

            self.model = build_sam2(
                config_file="sam2_hiera_l.yaml",
                ckpt_path=self.config.sam2_checkpoint,
                device=self.config.device
            )

            self.predictor = SAM2ImagePredictor(self.model)

            self.loaded = True
            logger.info("SAM2 model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load SAM2 model: {e}")
            self.model = None
            self.predictor = None
            self.loaded = False

    def _generate_auto_masks(self, image: np.ndarray) -> list:
        """Run SAM2 automatic mask generation and return its list of mask records.

        SAM2ImagePredictor has no automatic mode, so a
        SAM2AutomaticMaskGenerator is created (once) around the loaded model.
        """
        if self.mask_generator is None:
            from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
            self.mask_generator = SAM2AutomaticMaskGenerator(self.model)
        return self.mask_generator.generate(image)

    def predict(self, image: np.ndarray, prompts: Optional[Dict] = None) -> np.ndarray:
        """Generate a single (H, W) segmentation mask for an (H, W, 3) image.

        With *prompts* (``'points'``, ``'labels'``, ``'box'``) the highest
        scoring of SAM2's multimask outputs is returned. Without prompts the
        automatically generated mask with the highest ``predicted_iou`` is
        returned. An all-zero uint8 mask is returned when the model is not
        available or no mask was generated.
        """
        if not self.loaded:
            self.load()

        if self.predictor is None:
            return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

        if prompts:
            self.predictor.set_image(image)
            masks, scores, _ = self.predictor.predict(
                point_coords=prompts.get('points'),
                point_labels=prompts.get('labels'),
                box=prompts.get('box'),
                multimask_output=True
            )
            return masks[np.argmax(scores)]

        records = self._generate_auto_masks(image)
        if not records:
            return np.zeros_like(image[:, :, 0])
        # Mirror the prompted branch: keep the candidate SAM2 scores highest.
        best = max(records, key=lambda record: record.get('predicted_iou', 0.0))
        return best['segmentation']
| |
|
| |
|
class QualityEnhancer(nn.Module):
    """Neural quality enhancement module.

    Two small convolutional heads:
    ``alpha_refiner`` (1 -> 1 channel, sigmoid output) and
    ``foreground_enhancer`` (3 -> 3 channels, tanh residual).

    NOTE(review): nothing in this class loads trained weights, so both heads
    run with their random initialisation unless weights are loaded
    elsewhere — confirm.
    """

    def __init__(self):
        super().__init__()
        self.alpha_refiner = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 1, 3, padding=1),
            nn.Sigmoid()
        )

        self.foreground_enhancer = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Tanh()
        )

    @staticmethod
    def _run(net: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Apply *net* in its own dtype/device and return the result in *x*'s.

        Prevents dtype/device mismatch errors when, e.g., fp16 input from a
        half-precision model meets these fp32 weights.
        """
        ref = next(net.parameters())
        out = net(x.to(device=ref.device, dtype=ref.dtype))
        return out.to(device=x.device, dtype=x.dtype)

    def enhance_alpha(self, alpha: torch.Tensor,
                      original_mask: torch.Tensor) -> torch.Tensor:
        """Blend 70% refined / 30% input alpha (B, 1, H, W), clamped to [0, 1].

        *original_mask* is accepted for interface compatibility but unused.
        """
        refined = self._run(self.alpha_refiner, alpha)
        enhanced = 0.7 * refined + 0.3 * alpha
        return torch.clamp(enhanced, 0, 1)

    def enhance_foreground(self, foreground: torch.Tensor,
                           original_image: torch.Tensor) -> torch.Tensor:
        """Add a 0.1-scaled learned residual to foreground (B, 3, H, W), clamped to [0, 1].

        *original_image* is accepted for interface compatibility but unused.
        """
        residual = self._run(self.foreground_enhancer, foreground)
        enhanced = foreground + 0.1 * residual
        return torch.clamp(enhanced, 0, 1)
| |
|
| |
|
class ModelManager:
    """Central model management system: owns the SAM2 and MatAnyone wrappers."""

    def __init__(self, config: Optional["ModelConfig"] = None):
        self.config = config or ModelConfig()
        self.cache = ModelCache(max_size=self.config.cache_size)
        self.models = {}

        self.sam2 = SAM2Model(self.config)
        self.matanyone = MatAnyoneModel(self.config)

    def load_all(self):
        """Eagerly load both models."""
        logger.info("Loading all models...")
        self.sam2.load()
        self.matanyone.load()
        logger.info("All models loaded")

    def get_sam2(self) -> "SAM2Model":
        """Return the SAM2 wrapper, loading it first if needed."""
        if not self.sam2.loaded:
            self.sam2.load()
        return self.sam2

    def get_matanyone(self) -> "MatAnyoneModel":
        """Return the MatAnyone wrapper, loading it first if needed."""
        if not self.matanyone.loaded:
            self.matanyone.load()
        return self.matanyone

    def process_frame(self, image: np.ndarray,
                      mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """Run one frame through segmentation (if no mask) and matting.

        Args:
            image: (H, W, 3) uint8-range frame; non-contiguous views are accepted.
            mask: optional (H, W) mask in 0/1 or 0/255 range; generated with
                SAM2 when omitted.

        Returns:
            ``'alpha'`` (H, W) float array, ``'foreground'`` (H, W, 3) float
            array scaled to 0-255, and ``'confidence'`` per-sample array.
        """
        # torch.from_numpy rejects negative strides (e.g. image[..., ::-1] views).
        image = np.ascontiguousarray(image)
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        image_tensor = image_tensor.to(self.config.device)

        if mask is None:
            mask = self.sam2.predict(image)

        mask_tensor = torch.from_numpy(np.ascontiguousarray(mask)).float().to(self.config.device)
        # Accept 0/255 masks as well as 0/1 masks.
        if mask_tensor.numel() > 0 and mask_tensor.max() > 1.0:
            mask_tensor = mask_tensor / 255.0

        # Inference only: tensors that require grad cannot be turned into numpy.
        with torch.no_grad():
            result = self.matanyone(image_tensor, mask_tensor)

        confidence = result.get('confidence')
        if confidence is None:
            confidence_np = np.ones((image_tensor.shape[0],), dtype=np.float32)
        else:
            confidence_np = confidence.detach().cpu().numpy()

        return {
            'alpha': result['alpha'].detach().squeeze().cpu().numpy(),
            'foreground': result['foreground'].detach().squeeze().permute(1, 2, 0).cpu().numpy() * 255,
            'confidence': confidence_np
        }

    def cleanup(self):
        """Release cached and loaded models, then free Python and CUDA memory.

        The SAM2 and MatAnyone wrappers are replaced with fresh, unloaded
        instances so their networks can be garbage-collected; they load again
        lazily on next use.
        """
        self.cache.clear()
        self.sam2 = SAM2Model(self.config)
        self.matanyone = MatAnyoneModel(self.config)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
| |
|
| |
|
| | |
# Public API exported by ``from <this module> import *``.
__all__ = [
    'ModelManager', 'SAM2Model', 'MatAnyoneModel',
    'ModelConfig', 'ModelCache', 'QualityEnhancer',
]