import torch
import torch.nn.functional as F
from typing import Optional, Dict, Any, Tuple, Union
import numpy as np


class MatAnyOneWrapper:
    def __init__(self, core, device=None, config=None):
        """
        Initialize the MatAnyone wrapper with enhanced configuration.

        Args:
            core: MatAnyone InferenceCore instance
            device: torch device (auto-detected if None)
            config: Optional configuration dict for processing parameters
        """
        self.core = core
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = config or {}

        # Processing parameters
        self.threshold = self.config.get('threshold', 0.5)
        self.edge_refinement = self.config.get('edge_refinement', True)
        self.hair_refinement = self.config.get('hair_refinement', True)

        # Per-component blending weights used by _refine_with_components
        self.component_weights = self.config.get('component_weights', {
            'base': 1.0,
            'hair': 1.2,
            'edge': 1.5,
            'detail': 1.1
        })

        # Best-effort setup: move the underlying model to the target device and
        # switch it to eval mode; ignore failures for cores that manage this themselves.
        try:
            self.core.model.to(self.device)
        except Exception:
            pass

        try:
            self.core.model.eval()
        except Exception:
            pass
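
    # Example configuration (illustrative values mirroring the defaults above;
    # none of these keys are required):
    #
    #     wrapper = MatAnyOneWrapper(
    #         core,
    #         config={
    #             'threshold': 0.5,
    #             'edge_refinement': True,
    #             'hair_refinement': True,
    #             'component_weights': {'base': 1.0, 'hair': 1.2, 'edge': 1.5, 'detail': 1.1},
    #         },
    #     )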

    @torch.inference_mode()
    def step(self,
             image_tensor: torch.Tensor,
             mask_tensor: Optional[torch.Tensor] = None,
             objects: Optional[Dict] = None,
             first_frame_pred: bool = False,
             components: Optional[Dict[str, torch.Tensor]] = None,
             **kwargs) -> torch.Tensor:
        """
        Process a single frame with optional component masks.

        Args:
            image_tensor: (1,3,H,W) float32 [0..1] on self.device
            mask_tensor: (1,1,H,W) float32 [0..1] on self.device
            objects: Optional object tracking info
            first_frame_pred: Whether this is the first frame
            components: Optional dict with keys like 'hair', 'edge', 'detail';
                each value is a (1,1,H,W) tensor
            **kwargs: Additional arguments passed through to InferenceCore

        Returns:
            (1,1,H,W) float32 probabilities in [0..1]
        """
        # Move inputs to the target device
        image_tensor = image_tensor.to(self.device, non_blocking=True)
        if mask_tensor is not None:
            mask_tensor = mask_tensor.to(self.device, non_blocking=True)

        if components:
            components = {
                k: v.to(self.device, non_blocking=True)
                for k, v in components.items()
            }

        try:
            # Preferred keyword signature
            out = self.core.step(
                image_tensor=image_tensor,
                mask_tensor=mask_tensor,
                first_frame_pred=first_frame_pred,
                objects=objects,
                **kwargs
            )
        except TypeError:
            # Fall back to an alternative signature
            out = self.core.step(
                frame=image_tensor,
                mask=mask_tensor,
                **kwargs
            )

        # Normalize the core output to a (1,1,H,W) tensor
        out = self._normalize_output(out)

        # Blend in component masks (hair, edge, detail) if provided
        if components:
            out = self._refine_with_components(out, components)

        # Optional gradient-based edge refinement
        if self.edge_refinement and mask_tensor is not None:
            out = self._refine_edges(out, image_tensor, mask_tensor)

        return out.clamp_(0, 1)
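
    # Example single-frame call (shapes as documented in the docstring above;
    # the component masks are placeholders supplied by the caller):
    #
    #     alpha = wrapper.step(
    #         image,                                   # (1, 3, H, W) float32 in [0, 1]
    #         mask,                                    # (1, 1, H, W) float32 in [0, 1]
    #         first_frame_pred=True,
    #         components={'hair': hair_mask, 'edge': edge_mask},
    #     )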

    def _normalize_output(self, out: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Normalize output to a (1,1,H,W) float32 tensor."""
        if isinstance(out, torch.Tensor):
            if out.ndim == 3:
                out = out.unsqueeze(1)
            elif out.ndim == 2:
                out = out.unsqueeze(0).unsqueeze(0)
        else:
            out = torch.as_tensor(out, dtype=torch.float32, device=self.device)
            if out.ndim == 2:
                out = out.unsqueeze(0).unsqueeze(0)
            elif out.ndim == 3:
                out = out.unsqueeze(1)
        return out.float()

    def _refine_with_components(self,
                                base_mask: torch.Tensor,
                                components: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Refine the mask using component layers (hair, edge, etc.).

        Args:
            base_mask: (1,1,H,W) base alpha mask
            components: Dict of component masks

        Returns:
            Refined (1,1,H,W) mask
        """
        refined = base_mask.clone()

        # Hair: where the hair mask is active, keep the stronger of the base
        # alpha and the weighted hair alpha
        if 'hair' in components and self.hair_refinement:
            hair_mask = components['hair']
            weight = self.component_weights.get('hair', 1.0)
            refined = torch.where(
                hair_mask > 0.1,
                torch.maximum(refined, hair_mask * weight),
                refined
            )

        # Edge: boost alpha along the (dilated) edge band
        if 'edge' in components:
            edge_mask = components['edge']
            weight = self.component_weights.get('edge', 1.0)
            refined = self._apply_edge_enhancement(refined, edge_mask, weight)

        # Detail: blend toward the weighted detail layer where it is present
        if 'detail' in components:
            detail_mask = components['detail']
            weight = self.component_weights.get('detail', 1.0)
            refined = refined * (1 - detail_mask) + detail_mask * weight

        return refined.clamp_(0, 1)

    def _refine_edges(self,
                      mask: torch.Tensor,
                      image: torch.Tensor,
                      reference_mask: torch.Tensor) -> torch.Tensor:
        """
        Apply edge refinement using image gradients.

        Args:
            mask: (1,1,H,W) mask to refine
            image: (1,3,H,W) source image
            reference_mask: (1,1,H,W) reference mask (currently unused)

        Returns:
            Edge-refined mask
        """
        # Convert to grayscale for gradient computation
        gray = image.mean(dim=1, keepdim=True)

        # Sobel kernels for horizontal and vertical gradients
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                               dtype=torch.float32, device=self.device)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                               dtype=torch.float32, device=self.device)
        sobel_x = sobel_x.view(1, 1, 3, 3)
        sobel_y = sobel_y.view(1, 1, 3, 3)

        # Gradient magnitude
        edge_x = F.conv2d(gray, sobel_x, padding=1)
        edge_y = F.conv2d(gray, sobel_y, padding=1)
        edges = torch.sqrt(edge_x**2 + edge_y**2)

        # Normalize gradient magnitude to [0, 1]
        edges = edges / (edges.max() + 1e-7)

        # Smooth the mask with a small box filter
        kernel_size = 3
        refined = F.avg_pool2d(mask, kernel_size, stride=1, padding=1)

        # Blend toward the smoothed mask where image gradients are strong
        alpha = 1 - edges * 0.5
        refined = mask * alpha + refined * (1 - alpha)

        return refined.clamp_(0, 1)

    def _apply_edge_enhancement(self,
                                mask: torch.Tensor,
                                edge_mask: torch.Tensor,
                                weight: float) -> torch.Tensor:
        """Apply edge enhancement using the edge mask."""
        # Expand the edge band slightly with a 3x3 box filter
        kernel = torch.ones(1, 1, 3, 3, device=self.device) / 9
        dilated_edges = F.conv2d(edge_mask, kernel, padding=1)

        # Within the edge band, keep the stronger of the mask and the weighted edge response
        enhanced = torch.where(
            dilated_edges > 0.1,
            torch.maximum(mask, dilated_edges * weight),
            mask
        )

        return enhanced

    def process_batch(self,
                      images: torch.Tensor,
                      masks: Optional[torch.Tensor] = None,
                      components_batch: Optional[Dict[str, torch.Tensor]] = None,
                      **kwargs) -> torch.Tensor:
        """
        Process a batch of frames sequentially.

        Args:
            images: (B,3,H,W) batch of images
            masks: Optional (B,1,H,W) batch of masks
            components_batch: Optional dict of (B,1,H,W) component batches
            **kwargs: Additional arguments passed to step()

        Returns:
            (B,1,H,W) batch of refined masks
        """
        batch_size = images.shape[0]
        results = []

        for i in range(batch_size):
            image = images[i:i+1]
            mask = masks[i:i+1] if masks is not None else None

            # Slice out this frame's component masks, if any
            components = None
            if components_batch:
                components = {
                    k: v[i:i+1] for k, v in components_batch.items()
                }

            # Only frame 0 is treated as the first-frame prediction
            result = self.step(
                image,
                mask,
                components=components,
                first_frame_pred=(i == 0),
                **kwargs
            )
            results.append(result)

        return torch.cat(results, dim=0)

    def output_prob_to_mask(self, prob: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Convert probability map to binary mask."""
        if isinstance(prob, torch.Tensor):
            return (prob > self.threshold).float()
        t = torch.as_tensor(prob, device=self.device)
        return (t > self.threshold).float()

    def apply_morphology(self,
                         mask: torch.Tensor,
                         operation: str = 'close',
                         kernel_size: int = 5) -> torch.Tensor:
        """
        Apply morphological operations to clean up a binary mask.

        Args:
            mask: Binary (1,1,H,W) mask tensor
            operation: 'close', 'open', 'dilate', or 'erode'
            kernel_size: Size of the morphological kernel

        Returns:
            Processed mask
        """
        kernel = torch.ones(1, 1, kernel_size, kernel_size, device=self.device)

        if operation in ['close', 'dilate']:
            # Dilation: any active pixel in the neighborhood turns the pixel on
            mask = F.conv2d(mask, kernel, padding=kernel_size//2)
            mask = (mask > 0).float()

        if operation in ['close', 'erode']:
            # Erosion: the pixel stays on only if the whole neighborhood is on
            mask = F.conv2d(mask, kernel, padding=kernel_size//2)
            mask = (mask >= kernel_size**2).float()

        if operation == 'open':
            # Opening: erosion followed by dilation
            mask = F.conv2d(mask, kernel, padding=kernel_size//2)
            mask = (mask >= kernel_size**2).float()
            mask = F.conv2d(mask, kernel, padding=kernel_size//2)
            mask = (mask > 0).float()

        return mask
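
    # Example: binarize the matte, then fill small holes with a morphological close
    # (a sketch; `alpha` is assumed to come from a previous call to step()):
    #
    #     binary = wrapper.output_prob_to_mask(alpha)
    #     cleaned = wrapper.apply_morphology(binary, operation='close', kernel_size=5)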

    def get_alpha_matte(self,
                        image: torch.Tensor,
                        mask: torch.Tensor,
                        trimap: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Get an alpha matte with optional trimap refinement.

        Args:
            image: (1,3,H,W) RGB image
            mask: (1,1,H,W) initial mask
            trimap: Optional (1,1,H,W) trimap (0=bg, 0.5=unknown, 1=fg)

        Returns:
            (1,1,H,W) refined alpha matte
        """
        # Run the matting step on the single frame
        alpha = self.step(image, mask)

        # Clamp regions the trimap marks as known background or foreground
        if trimap is not None:
            alpha = torch.where(trimap == 0, torch.zeros_like(alpha), alpha)
            alpha = torch.where(trimap == 1, torch.ones_like(alpha), alpha)

        return alpha
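
    # Example: constrain the matte with a trimap, where 0 marks known background,
    # 1 known foreground, and 0.5 the unknown band (trimap construction is up to the caller):
    #
    #     alpha = wrapper.get_alpha_matte(image, mask, trimap=trimap)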

    def composite(self,
                  foreground: torch.Tensor,
                  background: torch.Tensor,
                  alpha: torch.Tensor) -> torch.Tensor:
        """
        Composite foreground over background using alpha.

        Args:
            foreground: (1,3,H,W) foreground image
            background: (1,3,H,W) background image
            alpha: (1,1,H,W) alpha matte

        Returns:
            (1,3,H,W) composited image
        """
        return foreground * alpha + background * (1 - alpha)
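

# A minimal end-to-end sketch. The MatAnyone import path and InferenceCore
# construction below are assumptions, and image/mask/background are placeholder
# tensors; adjust everything to match the installed package.
#
#     from matanyone.inference.inference_core import InferenceCore  # assumed path
#
#     core = InferenceCore(...)              # build per MatAnyone's own documentation
#     wrapper = MatAnyOneWrapper(core)
#
#     alpha = wrapper.step(image, mask, first_frame_pred=True)    # (1,1,H,W) matte
#     comp = wrapper.composite(image, background, alpha)           # (1,3,H,W) composite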