"""
Fixed MatAnyone Inference Core
Removes tensor-to-numpy conversion bugs that cause F.pad() errors
"""
import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, Union, Tuple
def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """
    Reflect-pad *in_tensor* so its last two dimensions are multiples of *d*.

    The extra rows/columns are split as evenly as possible between the two
    sides, with the odd remainder going to the bottom/right side.

    Args:
        in_tensor: tensor whose trailing two dims are (H, W).
        d: divisor the padded H and W must be multiples of.

    Returns:
        Tuple of the padded tensor and the applied padding ``(lw, uw, lh, uh)``
        in ``F.pad`` order (left, right, top, bottom).

    Raises:
        TypeError: if the input is not a ``torch.Tensor``.
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")
    h, w = in_tensor.shape[-2:]
    # Amount needed to reach the next multiple of d (0 when already aligned).
    extra_h = (-h) % d
    extra_w = (-w) % d
    # Smaller half on top/left, remainder on bottom/right.
    lh = extra_h // 2
    uh = extra_h - lh
    lw = extra_w // 2
    uw = extra_w - lw
    pad_array = (lw, uw, lh, uh)
    padded = F.pad(in_tensor, pad_array, mode='reflect')
    return padded, pad_array
def unpad_tensor(in_tensor: torch.Tensor, pad: Tuple[int, int, int, int]) -> torch.Tensor:
    """
    Crop the padding produced by ``pad_divide_by`` off *in_tensor*.

    Args:
        in_tensor: tensor whose trailing two dims are the padded (H, W).
        pad: ``(lw, uw, lh, uh)`` amounts in ``F.pad`` order
            (left, right, top, bottom), as returned by ``pad_divide_by``.

    Returns:
        View of *in_tensor* with the padded border removed.

    Raises:
        TypeError: if the input is not a ``torch.Tensor``.
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")
    lw, uw, lh, uh = pad
    h, w = in_tensor.shape[-2:]
    # Single positive-index slice replaces four conditional negative slices;
    # it is also correct when any pad amount is zero (h - 0 == h).
    return in_tensor[..., lh:h - uh, lw:w - uw]
class InferenceCore:
    """
    MatAnyone inference core with consistent tensor handling.

    Wraps a matting model and runs per-frame inference: inputs are
    normalized to device tensors, reflect-padded to a multiple of 16,
    passed through the model (optionally guided by a probability mask),
    and the resulting alpha matte is cropped back to the original
    resolution as a 2-D (H, W) tensor in [0, 1].
    """

    def __init__(self, model: torch.nn.Module):
        """Hold *model* in eval mode and mirror its parameter device."""
        self.model = model
        self.model.eval()
        self.device = next(model.parameters()).device
        # Padding (lw, uw, lh, uh) applied to the most recent frame.
        self.pad = None
        # Memory storage for temporal consistency across frames.
        self.image_feature_store = {}
        self.frame_count = 0

    def _ensure_tensor_format(self,
                              image: Union[torch.Tensor, np.ndarray],
                              prob: Optional[Union[torch.Tensor, np.ndarray]] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Normalize *image* to a float (3, H, W) tensor and *prob* (when
        given) to a float (H, W) tensor, both moved to ``self.device``.

        Raises:
            TypeError: if an input is neither ndarray nor tensor.
            ValueError: if a shape cannot be coerced to the expected layout.
        """
        if isinstance(image, np.ndarray):
            if image.ndim == 3 and image.shape[-1] == 3:  # HWC -> CHW
                image = torch.from_numpy(image.transpose(2, 0, 1)).float()
            elif image.ndim == 3 and image.shape[0] == 3:  # already CHW
                image = torch.from_numpy(image).float()
            else:
                raise ValueError(f"Unexpected image shape: {image.shape}")
        if not isinstance(image, torch.Tensor):
            raise TypeError(f"Image must be tensor after conversion, got {type(image)}")
        image = image.float().to(self.device)
        # Accept (3, H, W) directly or (1, 3, H, W) with a dropped batch dim.
        if image.ndim == 3 and image.shape[0] == 3:
            pass  # already correct
        elif image.ndim == 4 and image.shape[0] == 1 and image.shape[1] == 3:
            image = image.squeeze(0)
        else:
            raise ValueError(f"Image must be (3,H,W) or (1,3,H,W), got {image.shape}")
        if prob is not None:
            if isinstance(prob, np.ndarray):
                prob = torch.from_numpy(prob).float()
            if not isinstance(prob, torch.Tensor):
                raise TypeError(f"Prob must be tensor after conversion, got {type(prob)}")
            prob = prob.float().to(self.device)
            # Strip leading singleton dims until the mask is 2-D (H, W).
            while prob.ndim > 2:
                prob = prob.squeeze(0)
            if prob.ndim != 2:
                raise ValueError(f"Prob must be (H,W) after processing, got {prob.shape}")
        return image, prob

    def step(self,
             image: Union[torch.Tensor, np.ndarray],
             prob: Optional[Union[torch.Tensor, np.ndarray]] = None,
             **kwargs) -> torch.Tensor:
        """
        Run one frame of inference and return a 2-D alpha matte.

        Args:
            image: frame as (3, H, W) / (1, 3, H, W) tensor or HWC/CHW ndarray.
            prob: optional probability mask, coerced to (H, W).
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            (H, W) float tensor clamped to [0, 1].

        Raises:
            ValueError: if the model output has no channels.
        """
        image, prob = self._ensure_tensor_format(image, prob)
        with torch.no_grad():
            # Pad so the backbone's stride-16 downsampling divides evenly.
            image_padded, self.pad = pad_divide_by(image, 16)
            image_batch = image_padded.unsqueeze(0)  # (1, 3, H_pad, W_pad)
            if prob is not None:
                h_pad, w_pad = image_padded.shape[-2:]
                # FIX: keep the interpolated mask 4-D instead of a full
                # squeeze()/unsqueeze round-trip — a bare squeeze() would
                # also drop spatial dims of size 1.
                prob_batch = F.interpolate(
                    prob.unsqueeze(0).unsqueeze(0),  # (1, 1, H, W)
                    size=(h_pad, w_pad),
                    mode='bilinear',
                    align_corners=False
                )  # (1, 1, H_pad, W_pad)
                try:
                    if hasattr(self.model, 'forward_with_prob'):
                        output = self.model.forward_with_prob(image_batch, prob_batch)
                    else:
                        # Fallback: concatenate prob as an extra channel.
                        input_tensor = torch.cat([image_batch, prob_batch], dim=1)  # (1, 4, H_pad, W_pad)
                        output = self.model(input_tensor)
                except Exception:
                    # Best-effort fallback: run on the image alone rather
                    # than aborting the whole sequence on one bad frame.
                    output = self.model(image_batch)
            else:
                output = self.model(image_batch)
            if output.shape[1] < 1:
                raise ValueError(f"Unexpected model output shape: {output.shape}")
            # FIX: index (not slice) the last channel so single- and
            # multi-channel outputs both yield (1, H_pad, W_pad); the old
            # `output[:, -1:, :, :]` branch stayed 4-D and made step()
            # return (1, H, W) instead of (H, W) for multi-channel models.
            alpha = output[:, -1]  # (1, H_pad, W_pad); alpha assumed to be the last channel
            alpha_unpadded = unpad_tensor(alpha, self.pad)
            alpha_final = alpha_unpadded.squeeze(0)  # (H, W)
            alpha_final = torch.clamp(alpha_final, 0.0, 1.0)
        self.frame_count += 1
        return alpha_final

    def clear_memory(self):
        """Clear stored features and reset the frame counter."""
        self.image_feature_store.clear()
        self.frame_count = 0