MogensR's picture
Create matanyone_fixed/inference/inference_core.py
f5fcafb verified
"""
Fixed MatAnyone Inference Core
Removes tensor-to-numpy conversion bugs that cause F.pad() errors
"""
import warnings
from typing import Optional, Union, Tuple

import numpy as np
import torch
import torch.nn.functional as F
def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """
    Pad the last two dimensions of a tensor up to the next multiple of ``d``.

    The extra rows/columns are split as evenly as possible between the two
    sides; when the amount is odd, the extra one goes to the bottom/right.

    Args:
        in_tensor: Tensor whose last two dimensions are (H, W). Any number of
            leading dimensions (including none) is accepted.
        d: Positive divisor that the padded height and width must be a
            multiple of.

    Returns:
        A tuple ``(padded, pad_array)`` where ``pad_array`` is
        ``(left, right, top, bottom)`` in the order expected by ``F.pad``.

    Raises:
        TypeError: If ``in_tensor`` is not a ``torch.Tensor``.
        ValueError: If ``d`` is not positive or the tensor has fewer than two
            dimensions.
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")
    if d <= 0:
        raise ValueError(f"Divisor must be a positive integer, got {d}")
    if in_tensor.ndim < 2:
        raise ValueError(
            f"Expected a tensor with at least 2 dimensions (..., H, W), got shape {tuple(in_tensor.shape)}"
        )

    h, w = in_tensor.shape[-2:]

    # Round each spatial size up to the next multiple of d.
    new_h = (h + d - 1) // d * d
    new_w = (w + d - 1) // d * d
    extra_h, extra_w = new_h - h, new_w - w
    lh, uh = extra_h // 2, extra_h - extra_h // 2
    lw, uw = extra_w // 2, extra_w - extra_w // 2
    pad_array = (lw, uw, lh, uh)

    # Reflection padding needs every pad amount to be strictly smaller than
    # the dimension it mirrors; for very small inputs replicate the border
    # instead of raising. (uh >= lh and uw >= lw, so checking those suffices.)
    mode = 'reflect' if (uh < h and uw < w) else 'replicate'

    # A 4-element non-constant pad is only defined for 3D/4D inputs, so fold
    # all leading dimensions into one batch axis and restore them afterwards.
    # Padding acts independently on every (H, W) plane, so values are unchanged.
    lead_shape = tuple(in_tensor.shape[:-2])
    planes = in_tensor.reshape(-1, 1, h, w)
    padded = F.pad(planes, pad_array, mode=mode)
    out = padded.reshape(*lead_shape, new_h, new_w)
    return out, pad_array
def unpad_tensor(in_tensor: torch.Tensor, pad: Tuple[int, int, int, int]) -> torch.Tensor:
    """
    Crop a previously added border off the last two dimensions of a tensor.

    Args:
        in_tensor: Tensor whose last two dimensions are (H, W).
        pad: ``(left, right, top, bottom)`` amounts to strip. Amounts that
            are not greater than zero are skipped.

    Returns:
        The cropped tensor; the input object itself when nothing is cropped.

    Raises:
        TypeError: If ``in_tensor`` is not a ``torch.Tensor``.
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")

    left, right, top, bottom = pad
    # Unpacking two sizes also rejects tensors with fewer than two dimensions.
    _rows, _cols = in_tensor.shape[-2:]

    everything = slice(None)
    # (amount, index) pairs applied in order: top, bottom, left, right.
    crops = (
        (top, (Ellipsis, slice(top, None), everything)),
        (bottom, (Ellipsis, slice(None, -bottom), everything)),
        (left, (Ellipsis, everything, slice(left, None))),
        (right, (Ellipsis, everything, slice(None, -right))),
    )

    cropped = in_tensor
    for amount, index in crops:
        if amount > 0:
            cropped = cropped[index]
    return cropped
class InferenceCore:
    """
    FIXED MatAnyone Inference Core

    Wraps a matting model and runs it one frame at a time: inputs are
    normalised to float tensors on the model's device, padded so the spatial
    size is a multiple of 16, passed through the model, and the resulting
    alpha map is cropped back to the input size and clamped to [0, 1].
    """

    # Spatial sizes are padded up to a multiple of this value before inference.
    _PAD_DIVISOR = 16

    def __init__(self, model: torch.nn.Module):
        """
        Put ``model`` in eval mode and record the device it lives on.

        Args:
            model: The network to run. It is called as ``model(x)`` and, when
                it defines ``forward_with_prob``, as
                ``model.forward_with_prob(image, prob)``.
        """
        self.model = model
        self.model.eval()

        # Modules without parameters would make a bare next() raise
        # StopIteration; fall back to buffers, then to CPU.
        anchor = next(model.parameters(), None)
        if anchor is None:
            anchor = next(model.buffers(), None)
        self.device = anchor.device if anchor is not None else torch.device("cpu")

        # Padding (left, right, top, bottom) applied to the most recent frame.
        self.pad = None
        # Feature cache; emptied by clear_memory(). step() does not use it.
        self.image_feature_store = {}
        # Number of frames processed since construction / last clear_memory().
        self.frame_count = 0

    @staticmethod
    def _to_float_tensor(array: np.ndarray) -> torch.Tensor:
        """Convert a NumPy array to a float32 tensor.

        The array is copied into contiguous float32 storage first because
        ``torch.from_numpy`` rejects negatively strided views (for example a
        channel flip done with ``[..., ::-1]``).
        """
        return torch.from_numpy(np.ascontiguousarray(array, dtype=np.float32))

    def _ensure_tensor_format(self,
                              image: Union[torch.Tensor, np.ndarray],
                              prob: Optional[Union[torch.Tensor, np.ndarray]] = None
                              ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Normalise the inputs to float tensors on ``self.device``.

        Args:
            image: ``(3, H, W)`` or ``(1, 3, H, W)`` tensor, or a NumPy array
                in ``(H, W, 3)`` or ``(3, H, W)`` layout.
            prob: Optional mask as tensor or NumPy array. Leading singleton
                dimensions and a trailing singleton channel are removed.

        Returns:
            ``(image, prob)`` with ``image`` shaped ``(3, H, W)`` and ``prob``
            shaped ``(H, W)`` (or ``None`` when no mask was given).

        Raises:
            TypeError: If an input is neither a tensor nor a NumPy array.
            ValueError: If a shape cannot be brought into the layouts above.
        """
        if isinstance(image, np.ndarray):
            if image.ndim == 3 and image.shape[-1] == 3:  # HWC format
                image = self._to_float_tensor(image.transpose(2, 0, 1))  # Convert to CHW
            elif image.ndim == 3 and image.shape[0] == 3:  # CHW format
                image = self._to_float_tensor(image)
            else:
                raise ValueError(f"Unexpected image shape: {image.shape}")

        if not isinstance(image, torch.Tensor):
            raise TypeError(f"Image must be tensor after conversion, got {type(image)}")
        image = image.float().to(self.device)

        # Ensure CHW format (3, H, W)
        if image.ndim == 4 and image.shape[0] == 1 and image.shape[1] == 3:
            image = image.squeeze(0)  # Remove batch dimension
        elif not (image.ndim == 3 and image.shape[0] == 3):
            raise ValueError(f"Image must be (3,H,W) or (1,3,H,W), got {image.shape}")

        if prob is not None:
            if isinstance(prob, np.ndarray):
                prob = self._to_float_tensor(prob)
            if not isinstance(prob, torch.Tensor):
                raise TypeError(f"Prob must be tensor after conversion, got {type(prob)}")
            prob = prob.float().to(self.device)

            # squeeze(0) is a no-op on a non-singleton dimension, so the size
            # is checked explicitly to guarantee the loop terminates.
            while prob.ndim > 2 and prob.shape[0] == 1:
                prob = prob.squeeze(0)
            # Accept channel-last single-channel masks, i.e. (H, W, 1).
            if prob.ndim == 3 and prob.shape[-1] == 1:
                prob = prob.squeeze(-1)
            if prob.ndim != 2:
                raise ValueError(f"Prob must be (H,W) after processing, got {prob.shape}")

        return image, prob

    def _prepare_prob(self, prob: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """Return ``prob`` as a ``(1, 1, H_pad, W_pad)`` batch aligned with the padded image.

        The mask is first resampled to the unpadded image size if it differs,
        then padded the same way as the image, so that mask pixels stay on
        top of the image pixels they describe.
        """
        target = (int(image_size[0]), int(image_size[1]))
        prob_batch = prob.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
        if tuple(prob.shape) != target:
            prob_batch = F.interpolate(
                prob_batch,
                size=target,
                mode='bilinear',
                align_corners=False
            )
        prob_batch, _ = pad_divide_by(prob_batch, self._PAD_DIVISOR)
        return prob_batch

    def _forward_guided(self, image_batch: torch.Tensor, prob_batch: torch.Tensor) -> torch.Tensor:
        """Run the model with mask guidance, falling back to image-only on failure."""
        try:
            if hasattr(self.model, 'forward_with_prob'):
                return self.model.forward_with_prob(image_batch, prob_batch)
            # Fallback: concatenate prob as additional channel -> (1, 4, H_pad, W_pad)
            return self.model(torch.cat([image_batch, prob_batch], dim=1))
        except Exception as exc:
            # Best-effort by design: keep going without guidance, but make
            # the downgrade visible instead of swallowing the error.
            warnings.warn(
                f"Probability-guided forward pass failed ({exc!r}); "
                "falling back to image-only inference",
                RuntimeWarning,
                stacklevel=3,
            )
            return self.model(image_batch)

    def step(self,
             image: Union[torch.Tensor, np.ndarray],
             prob: Optional[Union[torch.Tensor, np.ndarray]] = None,
             **kwargs) -> torch.Tensor:
        """
        Process one frame and return its alpha matte.

        Args:
            image: Frame in any layout accepted by ``_ensure_tensor_format``.
            prob: Optional guidance mask for this frame.
            **kwargs: Accepted for call-site compatibility; not used.

        Returns:
            Alpha tensor of shape ``(H, W)`` with values clamped to [0, 1].

        Raises:
            ValueError: If the model output is not a 4D tensor with at least
                one channel, or the inputs have unsupported shapes.
        """
        image, prob = self._ensure_tensor_format(image, prob)

        with torch.no_grad():
            image_padded, self.pad = pad_divide_by(image, self._PAD_DIVISOR)
            image_batch = image_padded.unsqueeze(0)  # (1, 3, H_pad, W_pad)

            if prob is None:
                output = self.model(image_batch)
            else:
                prob_batch = self._prepare_prob(prob, image.shape[-2:])
                output = self._forward_guided(image_batch, prob_batch)

            if output.ndim != 4 or output.shape[1] < 1:
                raise ValueError(f"Unexpected model output shape: {output.shape}")
            # Alpha is taken from the last (or only) channel. Integer indexing
            # drops the channel axis so both cases yield (1, H_pad, W_pad).
            alpha = output[:, -1]

            alpha_unpadded = unpad_tensor(alpha, self.pad)
            alpha_final = alpha_unpadded.squeeze(0)  # (H, W)
            alpha_final = torch.clamp(alpha_final, 0.0, 1.0)

        self.frame_count += 1
        return alpha_final

    def clear_memory(self):
        """Empty the feature store and reset the frame counter to zero."""
        self.image_feature_store.clear()
        self.frame_count = 0