Update models/loaders/matanyone_loader.py
Browse files
models/loaders/matanyone_loader.py
CHANGED
|
@@ -107,13 +107,50 @@ def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask=Fa
|
|
| 107 |
mode = "nearest" if is_mask else "bilinear"
|
| 108 |
return F.interpolate(x, size=size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
def _to_2d_alpha_numpy(x) -> np.ndarray:
|
| 111 |
t = torch.as_tensor(x).float()
|
| 112 |
while t.ndim > 2:
|
| 113 |
-
if t.ndim ==
|
| 114 |
-
t = t[0
|
|
|
|
|
|
|
| 115 |
else:
|
| 116 |
-
t = t.squeeze()
|
| 117 |
t = t.clamp_(0.0, 1.0)
|
| 118 |
out = t.detach().cpu().numpy().astype(np.float32)
|
| 119 |
return np.ascontiguousarray(out)
|
|
@@ -188,17 +225,22 @@ def _compute_scaled_size(self, h: int, w: int) -> Tuple[int, int, float]:
|
|
| 188 |
return nh, nw, s
|
| 189 |
|
| 190 |
def _to_alpha(self, out_prob):
|
|
|
|
| 191 |
if self._has_prob_to_mask:
|
| 192 |
try:
|
| 193 |
return self.core.output_prob_to_mask(out_prob, matting=True)
|
| 194 |
except Exception:
|
| 195 |
pass
|
| 196 |
t = torch.as_tensor(out_prob).float()
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
# ---- main call ----
|
| 204 |
def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
|
@@ -217,12 +259,20 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
|
| 217 |
# dtype alignment for activations
|
| 218 |
img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
|
| 219 |
|
| 220 |
-
#
|
| 221 |
nh, nw, s = self._compute_scaled_size(H, W)
|
| 222 |
scales = [(nh, nw)]
|
|
|
|
| 223 |
if s < 1.0:
|
| 224 |
-
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
last_exc = None
|
| 228 |
|
|
@@ -232,11 +282,9 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
|
| 232 |
img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
|
| 233 |
msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
|
| 234 |
|
| 235 |
-
# ---- IMPORTANT SHAPE CHANGES (only edit) ----
|
| 236 |
img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
|
| 237 |
m_1hw = _to_1hw_mask(msk_in) if msk_in is not None else None # [1,H,W] or None
|
| 238 |
-
mask_2d = m_1hw[0].contiguous() if m_1hw is not None else None
|
| 239 |
-
# ------------------------------------------------
|
| 240 |
|
| 241 |
# inference with autocast + inference_mode
|
| 242 |
with torch.inference_mode():
|
|
@@ -268,11 +316,12 @@ def __exit__(self, *args): return False
|
|
| 268 |
out_prob = self.core.step(image=img_chw)
|
| 269 |
alpha = self._to_alpha(out_prob)
|
| 270 |
|
| 271 |
-
#
|
| 272 |
if (th, tw) != (H, W):
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
alpha
|
|
|
|
| 276 |
|
| 277 |
return _to_2d_alpha_numpy(alpha)
|
| 278 |
|
|
|
|
| 107 |
mode = "nearest" if is_mask else "bilinear"
|
| 108 |
return F.interpolate(x, size=size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
|
| 109 |
|
| 110 |
+
def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
|
| 111 |
+
"""
|
| 112 |
+
Convert any plausible alpha/prob output into [1,1,H,W] float in [0,1].
|
| 113 |
+
Prevents 5D/6D mishaps when upsampling.
|
| 114 |
+
"""
|
| 115 |
+
t = torch.as_tensor(alpha, device=device).float()
|
| 116 |
+
if t.ndim == 2:
|
| 117 |
+
t = t.unsqueeze(0).unsqueeze(0) # -> [1,1,H,W]
|
| 118 |
+
elif t.ndim == 3:
|
| 119 |
+
# CHW or 1HW
|
| 120 |
+
if t.shape[0] in (1, 3, 4):
|
| 121 |
+
if t.shape[0] != 1:
|
| 122 |
+
t = t[:1] # keep first channel
|
| 123 |
+
t = t.unsqueeze(0) # -> [1,1,H,W]
|
| 124 |
+
elif t.shape[-1] in (1, 3, 4): # HWC (unexpected, but handle)
|
| 125 |
+
t = t[..., :1].permute(2, 0, 1).unsqueeze(0)
|
| 126 |
+
else:
|
| 127 |
+
# assume [H,W,C?] incompatible → fallback to first dim semantics
|
| 128 |
+
t = t[:1].unsqueeze(0)
|
| 129 |
+
elif t.ndim == 4:
|
| 130 |
+
# [B,C,H,W] → ensure C=1 and B=1
|
| 131 |
+
if t.shape[1] != 1:
|
| 132 |
+
t = t[:, :1]
|
| 133 |
+
if t.shape[0] != 1:
|
| 134 |
+
t = t[:1]
|
| 135 |
+
else:
|
| 136 |
+
# squeeze weird shapes down to [1,1,H,W] best-effort
|
| 137 |
+
while t.ndim > 4:
|
| 138 |
+
t = t.squeeze(0)
|
| 139 |
+
while t.ndim < 4:
|
| 140 |
+
t = t.unsqueeze(0)
|
| 141 |
+
if t.shape[1] != 1:
|
| 142 |
+
t = t[:, :1]
|
| 143 |
+
return t.clamp_(0.0, 1.0).contiguous()
|
| 144 |
+
|
| 145 |
def _to_2d_alpha_numpy(x) -> np.ndarray:
|
| 146 |
t = torch.as_tensor(x).float()
|
| 147 |
while t.ndim > 2:
|
| 148 |
+
if t.ndim == 4 and t.shape[0] == 1 and t.shape[1] == 1:
|
| 149 |
+
t = t[0, 0]
|
| 150 |
+
elif t.ndim == 3 and t.shape[0] == 1:
|
| 151 |
+
t = t[0]
|
| 152 |
else:
|
| 153 |
+
t = t.squeeze(0)
|
| 154 |
t = t.clamp_(0.0, 1.0)
|
| 155 |
out = t.detach().cpu().numpy().astype(np.float32)
|
| 156 |
return np.ascontiguousarray(out)
|
|
|
|
| 225 |
return nh, nw, s
|
| 226 |
|
| 227 |
def _to_alpha(self, out_prob):
|
| 228 |
+
# Prefer library conversion if available
|
| 229 |
if self._has_prob_to_mask:
|
| 230 |
try:
|
| 231 |
return self.core.output_prob_to_mask(out_prob, matting=True)
|
| 232 |
except Exception:
|
| 233 |
pass
|
| 234 |
t = torch.as_tensor(out_prob).float()
|
| 235 |
+
# Normalize common cases to 2-D alpha
|
| 236 |
+
if t.ndim == 4: # [B,C,H,W]
|
| 237 |
+
c = 0 if t.shape[1] > 0 else None
|
| 238 |
+
b = 0 if t.shape[0] > 0 else None
|
| 239 |
+
if b is not None and c is not None:
|
| 240 |
+
return t[b, c]
|
| 241 |
+
if t.ndim == 3: # [C,H,W]
|
| 242 |
+
return t[0] if t.shape[0] >= 1 else t.mean(0)
|
| 243 |
+
return t # already 2-D or degenerate -> let caller sanitize
|
| 244 |
|
| 245 |
# ---- main call ----
|
| 246 |
def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
|
|
|
| 259 |
# dtype alignment for activations
|
| 260 |
img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
|
| 261 |
|
| 262 |
+
# build a deeper downscale ladder to survive tight VRAM
|
| 263 |
nh, nw, s = self._compute_scaled_size(H, W)
|
| 264 |
scales = [(nh, nw)]
|
| 265 |
+
# add progressive reductions until fairly small, but not tiny
|
| 266 |
if s < 1.0:
|
| 267 |
+
f = 0.85
|
| 268 |
+
cur_h, cur_w = nh, nw
|
| 269 |
+
for _ in range(6): # up to 8 attempts total
|
| 270 |
+
cur_h = max(128, int(cur_h * f))
|
| 271 |
+
cur_w = max(128, int(cur_w * f))
|
| 272 |
+
if (cur_h, cur_w) != scales[-1]:
|
| 273 |
+
scales.append((cur_h, cur_w))
|
| 274 |
+
if max(cur_h, cur_w) <= 192 or (cur_h * cur_w) <= 150_000:
|
| 275 |
+
break
|
| 276 |
|
| 277 |
last_exc = None
|
| 278 |
|
|
|
|
| 282 |
img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
|
| 283 |
msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
|
| 284 |
|
|
|
|
| 285 |
img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
|
| 286 |
m_1hw = _to_1hw_mask(msk_in) if msk_in is not None else None # [1,H,W] or None
|
| 287 |
+
mask_2d = m_1hw[0].contiguous() if m_1hw is not None else None# [H,W] or None
|
|
|
|
| 288 |
|
| 289 |
# inference with autocast + inference_mode
|
| 290 |
with torch.inference_mode():
|
|
|
|
| 316 |
out_prob = self.core.step(image=img_chw)
|
| 317 |
alpha = self._to_alpha(out_prob)
|
| 318 |
|
| 319 |
+
# ---- SAFE UPSAMPLE PATH (always 4D -> 2D) ----
|
| 320 |
if (th, tw) != (H, W):
|
| 321 |
+
a_b1hw = _to_b1hw_alpha(alpha, device=img_chw.device) # [1,1,th,tw]
|
| 322 |
+
a_b1hw = F.interpolate(a_b1hw, size=(H, W), mode="bilinear", align_corners=False) # [1,1,H,W]
|
| 323 |
+
alpha = a_b1hw[0, 0] # -> [H,W]
|
| 324 |
+
# ------------------------------------------------
|
| 325 |
|
| 326 |
return _to_2d_alpha_numpy(alpha)
|
| 327 |
|