| | """ |
| | Fixed MatAnyone Inference Core |
| | Removes tensor-to-numpy conversion bugs that cause F.pad() errors |
| | """ |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from typing import Optional, Union, Tuple |
| |
|
| |
|
def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """
    Pad *in_tensor* so its last two dims become multiples of *d*.

    Padding is centered (the odd pixel goes to the bottom/right) and uses
    reflect mode when legal, falling back to replicate mode for inputs that
    are too small for reflection (reflect requires pad < dim size).

    Args:
        in_tensor: tensor of shape (..., H, W); must already be a tensor.
        d: divisor the padded height/width must be a multiple of.

    Returns:
        (padded tensor, (left, right, top, bottom)) — the pad tuple is what
        unpad_tensor expects.

    Raises:
        TypeError: if in_tensor is not a torch.Tensor.
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")

    h, w = in_tensor.shape[-2:]

    # Round H and W up to the next multiple of d.
    new_h = (h + d - 1) // d * d
    new_w = (w + d - 1) // d * d

    # Split each pad evenly; any odd pixel goes to the upper/right side.
    lh, uh = (new_h - h) // 2, (new_h - h) // 2 + (new_h - h) % 2
    lw, uw = (new_w - w) // 2, (new_w - w) // 2 + (new_w - w) % 2

    pad_array = (lw, uw, lh, uh)

    # FIX: reflect padding requires every pad amount to be strictly smaller
    # than the corresponding input dimension, otherwise F.pad raises
    # (e.g. a 5x7 frame padded to 16x16). Fall back to replicate there.
    if max(lh, uh) < h and max(lw, uw) < w:
        out = F.pad(in_tensor, pad_array, mode='reflect')
    else:
        out = F.pad(in_tensor, pad_array, mode='replicate')

    return out, pad_array
| |
|
| |
|
def unpad_tensor(in_tensor: torch.Tensor, pad: Tuple[int, int, int, int]) -> torch.Tensor:
    """Strip the (left, right, top, bottom) border added by pad_divide_by.

    Args:
        in_tensor: padded tensor of shape (..., H, W).
        pad: (lw, uw, lh, uh) pad amounts previously applied to width/height.

    Returns:
        The tensor with the padded border sliced away; a zero pad on a side
        leaves that side untouched.

    Raises:
        TypeError: if in_tensor is not a torch.Tensor.
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")

    lw, uw, lh, uh = pad
    h, w = in_tensor.shape[-2:]

    # Upper bounds are None when nothing was padded on that side, so the
    # full extent is kept (an explicit h - 0 would be fine, but None keeps
    # the zero-pad case obvious).
    row_stop = h - uh if uh > 0 else None
    col_stop = w - uw if uw > 0 else None

    return in_tensor[..., lh:row_stop, lw:col_stop]
| |
|
| |
|
class InferenceCore:
    """
    MatAnyone inference wrapper.

    Normalizes image/mask inputs to float tensors on the model's device,
    pads frames to a multiple of 16, runs the model, and returns an alpha
    matte of shape (H, W) clamped to [0, 1].
    """

    def __init__(self, model: torch.nn.Module):
        """
        Args:
            model: matting network; switched to eval mode. The device of its
                first parameter determines where inputs are placed.
        """
        self.model = model
        self.model.eval()
        # FIX: next(model.parameters()) raises StopIteration for a
        # parameter-free model; default to CPU in that case.
        first_param = next(model.parameters(), None)
        self.device = first_param.device if first_param is not None else torch.device('cpu')
        # (lw, uw, lh, uh) pad tuple from the most recent step(), or None.
        self.pad: Optional[Tuple[int, int, int, int]] = None

        # Per-run bookkeeping; cleared by clear_memory().
        self.image_feature_store = {}
        self.frame_count = 0

    def _ensure_tensor_format(self,
                              image: Union[torch.Tensor, np.ndarray],
                              prob: Optional[Union[torch.Tensor, np.ndarray]] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Normalize inputs to float tensors on self.device.

        Args:
            image: (H, W, 3) or (3, H, W) ndarray, or a (3, H, W) /
                (1, 3, H, W) tensor.
            prob: optional mask; leading size-1 dims are stripped to (H, W).

        Returns:
            (image as (3, H, W) float tensor, prob as (H, W) float tensor
            or None).

        Raises:
            TypeError: if a converted input is still not a tensor.
            ValueError: on unsupported shapes.
        """
        if isinstance(image, np.ndarray):
            # NOTE: channels-last is checked first, so an ambiguous
            # (3, H, 3) array is treated as HWC — same as the original.
            if image.ndim == 3 and image.shape[-1] == 3:
                image = torch.from_numpy(image.transpose(2, 0, 1)).float()
            elif image.ndim == 3 and image.shape[0] == 3:
                image = torch.from_numpy(image).float()
            else:
                raise ValueError(f"Unexpected image shape: {image.shape}")

        if not isinstance(image, torch.Tensor):
            raise TypeError(f"Image must be tensor after conversion, got {type(image)}")

        image = image.float().to(self.device)

        if image.ndim == 3 and image.shape[0] == 3:
            pass  # already (3, H, W)
        elif image.ndim == 4 and image.shape[0] == 1 and image.shape[1] == 3:
            image = image.squeeze(0)  # drop the singleton batch dim
        else:
            raise ValueError(f"Image must be (3,H,W) or (1,3,H,W), got {image.shape}")

        if prob is not None:
            if isinstance(prob, np.ndarray):
                prob = torch.from_numpy(prob).float()

            if not isinstance(prob, torch.Tensor):
                raise TypeError(f"Prob must be tensor after conversion, got {type(prob)}")

            prob = prob.float().to(self.device)

            # FIX: only strip leading dims of size 1. The old unconditional
            # `while prob.ndim > 2: prob = prob.squeeze(0)` never terminated
            # when dim 0 had size > 1 (squeeze(0) is a no-op there), e.g. a
            # (2, H, W) multi-object mask. Such shapes now raise below.
            while prob.ndim > 2 and prob.shape[0] == 1:
                prob = prob.squeeze(0)

            if prob.ndim != 2:
                raise ValueError(f"Prob must be (H,W) after processing, got {prob.shape}")

        return image, prob

    def step(self,
             image: Union[torch.Tensor, np.ndarray],
             prob: Optional[Union[torch.Tensor, np.ndarray]] = None,
             **kwargs) -> torch.Tensor:
        """
        Run one inference step on a single frame.

        Args:
            image: frame as (3, H, W) / (1, 3, H, W) tensor or HWC/CHW ndarray.
            prob: optional guidance mask, normalized to (H', W') and resized
                to the padded frame resolution before being fed to the model.
            **kwargs: accepted for interface compatibility; currently unused.

        Returns:
            Alpha matte tensor of shape (H, W), clamped to [0, 1].

        Raises:
            ValueError: if the model output has no channel dim to take
                alpha from.
        """
        image, prob = self._ensure_tensor_format(image, prob)

        with torch.no_grad():
            # Pad so H and W are multiples of 16 (network stride requirement).
            image_padded, self.pad = pad_divide_by(image, 16)
            image_batch = image_padded.unsqueeze(0)  # (1, 3, H_pad, W_pad)

            if prob is not None:
                # Resize the mask to the padded resolution so it stays
                # aligned with the image fed to the model.
                h_pad, w_pad = image_padded.shape[-2:]
                # FIX: targeted squeezes — a bare .squeeze() could also drop
                # spatial dims of size 1.
                prob_batch = F.interpolate(
                    prob.unsqueeze(0).unsqueeze(0),
                    size=(h_pad, w_pad),
                    mode='bilinear',
                    align_corners=False,
                )  # (1, 1, H_pad, W_pad)

                # Best-effort dispatch: prefer a mask-aware entry point, then
                # a concatenated 4-channel input, finally image-only.
                try:
                    if hasattr(self.model, 'forward_with_prob'):
                        output = self.model.forward_with_prob(image_batch, prob_batch)
                    else:
                        input_tensor = torch.cat([image_batch, prob_batch], dim=1)
                        output = self.model(input_tensor)
                except Exception:
                    # Deliberate broad fallback: if the model rejects the
                    # mask, run on the image alone rather than failing.
                    output = self.model(image_batch)
            else:
                output = self.model(image_batch)

            # FIX: the two old branches produced different ranks
            # ((B,H,W) via squeeze(1) vs (B,1,H,W) via [:, -1:, :, :]),
            # so the final squeeze(0) returned (H,W) or (1,H,W) depending
            # on channel count. Take the last channel uniformly.
            if output.ndim != 4 or output.shape[1] < 1:
                raise ValueError(f"Unexpected model output shape: {output.shape}")
            alpha = output[:, -1, :, :]  # (1, H_pad, W_pad); last channel = alpha

            # Remove the padding added above, drop the batch dim, and clamp
            # to a valid alpha range.
            alpha_unpadded = unpad_tensor(alpha, self.pad)
            alpha_final = alpha_unpadded.squeeze(0)  # (H, W)
            alpha_final = torch.clamp(alpha_final, 0.0, 1.0)

        self.frame_count += 1

        return alpha_final

    def clear_memory(self):
        """Clear stored features and reset the frame counter."""
        self.image_feature_store.clear()
        self.frame_count = 0