"""
Fixed MatAnyone Inference Core
Removes tensor-to-numpy conversion bugs that cause F.pad() errors
"""

import warnings
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F


def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """Pad the last two dims of ``in_tensor`` up to the next multiple of ``d``.

    The padding is split as evenly as possible between the two sides of each
    dimension; any odd extra pixel goes to the bottom / right side.

    Args:
        in_tensor: Tensor whose last two dims are (H, W). ``F.pad`` with
            reflect/replicate mode and a 4-tuple pad expects a 3D or 4D input,
            e.g. (C, H, W) or (N, C, H, W).
        d: Positive divisor; the padded H and W are multiples of it.

    Returns:
        ``(padded, (left, right, top, bottom))``. The pad tuple is in the
        order ``F.pad`` expects and can be handed to ``unpad_tensor``.

    Raises:
        TypeError: if ``in_tensor`` is not a ``torch.Tensor``.
        ValueError: if ``d`` is not positive.
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")
    if d <= 0:
        raise ValueError(f"Divisor d must be positive, got {d}")

    h, w = in_tensor.shape[-2:]

    # Calculate padding needed
    new_h = (h + d - 1) // d * d
    new_w = (w + d - 1) // d * d
    lh, uh = (new_h - h) // 2, (new_h - h) // 2 + (new_h - h) % 2
    lw, uw = (new_w - w) // 2, (new_w - w) // 2 + (new_w - w) % 2
    pad_array = (lw, uw, lh, uh)

    # Reflect padding requires each pad to be strictly smaller than the
    # dimension it pads; very small inputs would make F.pad raise, so they
    # fall back to edge replication instead.
    mode = 'reflect' if max(lh, uh) < h and max(lw, uw) < w else 'replicate'
    out = F.pad(in_tensor, pad_array, mode=mode)
    return out, pad_array


def unpad_tensor(in_tensor: torch.Tensor, pad: Tuple[int, int, int, int]) -> torch.Tensor:
    """Crop the padding produced by ``pad_divide_by`` from the last two dims.

    Args:
        in_tensor: Padded tensor whose last two dims are (H_pad, W_pad).
        pad: ``(left, right, top, bottom)`` as returned by ``pad_divide_by``.

    Returns:
        A view of ``in_tensor`` with the padding sliced off.

    Raises:
        TypeError: if ``in_tensor`` is not a ``torch.Tensor``.
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")

    lw, uw, lh, uh = pad
    h, w = in_tensor.shape[-2:]
    # A zero pad on the far side maps to stop index h (or w), i.e. no crop.
    return in_tensor[..., lh:h - uh, lw:w - uw]


class InferenceCore:
    """
    FIXED MatAnyone Inference Core
    Handles video matting with proper tensor operations
    """

    # Spatial dims of the model input are padded to multiples of this.
    _PAD_DIVISOR = 16

    def __init__(self, model: torch.nn.Module):
        """Wrap ``model`` (switched to eval mode) for per-frame inference."""
        self.model = model
        self.model.eval()
        # Inputs are moved to the device of the model's first parameter. A
        # parameter-free model has nothing to inspect, so default to CPU
        # rather than letting next() raise StopIteration.
        first_param = next(model.parameters(), None)
        self.device = first_param.device if first_param is not None else torch.device('cpu')
        # Pad tuple used for the most recent frame (set by step()).
        self.pad = None

        # Memory storage for temporal consistency
        self.image_feature_store = {}
        self.frame_count = 0

    def _ensure_tensor_format(self, image: Union[torch.Tensor, np.ndarray],
                              prob: Optional[Union[torch.Tensor, np.ndarray]] = None
                              ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Normalize inputs to float tensors on ``self.device``.

        Args:
            image: (3, H, W) or (1, 3, H, W) tensor, or an (H, W, 3) / (3, H, W)
                numpy array. Values are not rescaled.
            prob: Optional mask; leading singleton dims are dropped so it must
                reduce to (H, W).

        Returns:
            ``(image, prob)`` with image shaped (3, H, W) and prob (H, W) or None.

        Raises:
            TypeError: if an input is neither a tensor nor a numpy array.
            ValueError: on unsupported shapes.
        """
        # Convert image to tensor if needed. np.ascontiguousarray also handles
        # negative-stride views (e.g. a BGR->RGB ``[..., ::-1]``), which
        # torch.from_numpy refuses.
        if isinstance(image, np.ndarray):
            if image.ndim == 3 and image.shape[-1] == 3:  # HWC format
                image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))).float()
            elif image.ndim == 3 and image.shape[0] == 3:  # CHW format
                image = torch.from_numpy(np.ascontiguousarray(image)).float()
            else:
                raise ValueError(f"Unexpected image shape: {image.shape}")

        if not isinstance(image, torch.Tensor):
            raise TypeError(f"Image must be tensor after conversion, got {type(image)}")

        image = image.float().to(self.device)

        # Ensure CHW format (3, H, W)
        if image.ndim == 3 and image.shape[0] == 3:
            pass  # Already correct
        elif image.ndim == 4 and image.shape[0] == 1 and image.shape[1] == 3:
            image = image.squeeze(0)  # Remove batch dimension
        else:
            raise ValueError(f"Image must be (3,H,W) or (1,3,H,W), got {image.shape}")

        if prob is not None:
            if isinstance(prob, np.ndarray):
                prob = torch.from_numpy(np.ascontiguousarray(prob)).float()

            if not isinstance(prob, torch.Tensor):
                raise TypeError(f"Prob must be tensor after conversion, got {type(prob)}")

            prob = prob.float().to(self.device)

            # Drop leading singleton dims only: squeeze(0) is a no-op on a
            # non-singleton dim, so an unconditional loop never terminated
            # for shapes like (2, H, W).
            while prob.ndim > 2 and prob.shape[0] == 1:
                prob = prob.squeeze(0)

            if prob.ndim != 2:
                raise ValueError(f"Prob must be (H,W) after processing, got {prob.shape}")

        return image, prob

    def _prepare_prob(self, prob: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
        """Align an (H', W') mask with the padded image; returns (1, 1, H_pad, W_pad).

        The mask is first resized to the unpadded image size (only if it
        differs) and then padded exactly like the image, so mask pixels stay
        registered with image pixels.
        """
        h, w = int(size[0]), int(size[1])
        prob = prob.unsqueeze(0)  # (1, H', W')
        if tuple(prob.shape[-2:]) != (h, w):
            prob = F.interpolate(
                prob.unsqueeze(0),  # (1, 1, H', W')
                size=(h, w),
                mode='bilinear',
                align_corners=False,
            ).squeeze(0)  # (1, H, W)
        prob_padded, _ = pad_divide_by(prob, self._PAD_DIVISOR)  # (1, H_pad, W_pad)
        return prob_padded.unsqueeze(0)

    def _forward_with_prob(self, image_batch: torch.Tensor, prob_batch: torch.Tensor):
        """Run the model with mask guidance, falling back to image-only inference."""
        try:
            if hasattr(self.model, 'forward_with_prob'):
                return self.model.forward_with_prob(image_batch, prob_batch)
            # Fallback: concatenate prob as additional channel (1, 4, H_pad, W_pad)
            return self.model(torch.cat([image_batch, prob_batch], dim=1))
        except Exception as err:
            # Deliberate best-effort fallback, but report it instead of
            # swallowing the error silently.
            warnings.warn(
                f"Probability-guided forward pass failed ({err!r}); "
                f"falling back to image-only inference.",
                RuntimeWarning,
                stacklevel=3,
            )
            return self.model(image_batch)

    @staticmethod
    def _extract_alpha(output: torch.Tensor) -> torch.Tensor:
        """Return the last channel of a (1, C, H_pad, W_pad) output as (1, H_pad, W_pad)."""
        if output.ndim != 4 or output.shape[1] < 1:
            raise ValueError(f"Unexpected model output shape: {output.shape}")
        # Integer indexing drops the channel dim for both C == 1 and C > 1.
        return output[:, -1]

    def step(self, image: Union[torch.Tensor, np.ndarray],
             prob: Optional[Union[torch.Tensor, np.ndarray]] = None,
             **kwargs) -> torch.Tensor:
        """Matte one frame and return an (H, W) alpha clamped to [0, 1].

        Args:
            image: Frame accepted by ``_ensure_tensor_format``.
            prob: Optional guidance mask; resized to the frame size if needed.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Alpha matte of shape (H, W) on ``self.device``.
        """
        image, prob = self._ensure_tensor_format(image, prob)

        with torch.no_grad():
            image_padded, self.pad = pad_divide_by(image, self._PAD_DIVISOR)
            image_batch = image_padded.unsqueeze(0)  # (1, 3, H_pad, W_pad)

            if prob is not None:
                prob_batch = self._prepare_prob(prob, image.shape[-2:])
                output = self._forward_with_prob(image_batch, prob_batch)
            else:
                output = self.model(image_batch)

            alpha = self._extract_alpha(output)  # (1, H_pad, W_pad)
            alpha_final = unpad_tensor(alpha, self.pad).squeeze(0)  # (H, W)
            alpha_final = torch.clamp(alpha_final, 0.0, 1.0)

        self.frame_count += 1
        return alpha_final

    def clear_memory(self):
        """Clear stored features for memory management"""
        self.image_feature_store.clear()
        self.frame_count = 0