MogensR committed on
Commit
e2ca8f7
·
1 Parent(s): b8dd531

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +67 -18
models/loaders/matanyone_loader.py CHANGED
@@ -107,13 +107,50 @@ def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask=Fa
107
  mode = "nearest" if is_mask else "bilinear"
108
  return F.interpolate(x, size=size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def _to_2d_alpha_numpy(x) -> np.ndarray:
111
  t = torch.as_tensor(x).float()
112
  while t.ndim > 2:
113
- if t.ndim == 3:
114
- t = t[0] if t.shape[0] >= 1 else t.squeeze(0)
 
 
115
  else:
116
- t = t.squeeze()
117
  t = t.clamp_(0.0, 1.0)
118
  out = t.detach().cpu().numpy().astype(np.float32)
119
  return np.ascontiguousarray(out)
@@ -188,17 +225,22 @@ def _compute_scaled_size(self, h: int, w: int) -> Tuple[int, int, float]:
188
  return nh, nw, s
189
 
190
  def _to_alpha(self, out_prob):
 
191
  if self._has_prob_to_mask:
192
  try:
193
  return self.core.output_prob_to_mask(out_prob, matting=True)
194
  except Exception:
195
  pass
196
  t = torch.as_tensor(out_prob).float()
197
- if t.ndim == 3 and t.shape[0] >= 1:
198
- return t[0]
199
- if t.ndim >= 2:
200
- return t
201
- return torch.full((1, 1), 0.5, dtype=torch.float32, device=t.device if t.is_cuda else "cpu")
 
 
 
 
202
 
203
  # ---- main call ----
204
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
@@ -217,12 +259,20 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
217
  # dtype alignment for activations
218
  img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
219
 
220
- # initial scale + fallbacks
221
  nh, nw, s = self._compute_scaled_size(H, W)
222
  scales = [(nh, nw)]
 
223
  if s < 1.0:
224
- scales.append((max(1, int(nh * 0.85)), max(1, int(nw * 0.85))))
225
- scales.append((max(1, int(nh * 0.70)), max(1, int(nw * 0.70))))
 
 
 
 
 
 
 
226
 
227
  last_exc = None
228
 
@@ -232,11 +282,9 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
232
  img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
233
  msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
234
 
235
- # ---- IMPORTANT SHAPE CHANGES (only edit) ----
236
  img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
237
  m_1hw = _to_1hw_mask(msk_in) if msk_in is not None else None # [1,H,W] or None
238
- mask_2d = m_1hw[0].contiguous() if m_1hw is not None else None # [H,W] or None
239
- # ------------------------------------------------
240
 
241
  # inference with autocast + inference_mode
242
  with torch.inference_mode():
@@ -268,11 +316,12 @@ def __exit__(self, *args): return False
268
  out_prob = self.core.step(image=img_chw)
269
  alpha = self._to_alpha(out_prob)
270
 
271
- # upsample back to original resolution if scaled
272
  if (th, tw) != (H, W):
273
- alpha = torch.as_tensor(alpha).unsqueeze(0).unsqueeze(0).float()
274
- alpha = F.interpolate(alpha, size=(H, W), mode="bilinear", align_corners=False)
275
- alpha = alpha.squeeze(0).squeeze(0)
 
276
 
277
  return _to_2d_alpha_numpy(alpha)
278
 
 
107
  mode = "nearest" if is_mask else "bilinear"
108
  return F.interpolate(x, size=size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
109
 
110
+ def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
111
+ """
112
+ Convert any plausible alpha/prob output into [1,1,H,W] float in [0,1].
113
+ Prevents 5D/6D mishaps when upsampling.
114
+ """
115
+ t = torch.as_tensor(alpha, device=device).float()
116
+ if t.ndim == 2:
117
+ t = t.unsqueeze(0).unsqueeze(0) # -> [1,1,H,W]
118
+ elif t.ndim == 3:
119
+ # CHW or 1HW
120
+ if t.shape[0] in (1, 3, 4):
121
+ if t.shape[0] != 1:
122
+ t = t[:1] # keep first channel
123
+ t = t.unsqueeze(0) # -> [1,1,H,W]
124
+ elif t.shape[-1] in (1, 3, 4): # HWC (unexpected, but handle)
125
+ t = t[..., :1].permute(2, 0, 1).unsqueeze(0)
126
+ else:
127
+ # assume [H,W,C?] incompatible → fallback to first dim semantics
128
+ t = t[:1].unsqueeze(0)
129
+ elif t.ndim == 4:
130
+ # [B,C,H,W] → ensure C=1 and B=1
131
+ if t.shape[1] != 1:
132
+ t = t[:, :1]
133
+ if t.shape[0] != 1:
134
+ t = t[:1]
135
+ else:
136
+ # squeeze weird shapes down to [1,1,H,W] best-effort
137
+ while t.ndim > 4:
138
+ t = t.squeeze(0)
139
+ while t.ndim < 4:
140
+ t = t.unsqueeze(0)
141
+ if t.shape[1] != 1:
142
+ t = t[:, :1]
143
+ return t.clamp_(0.0, 1.0).contiguous()
144
+
145
  def _to_2d_alpha_numpy(x) -> np.ndarray:
146
  t = torch.as_tensor(x).float()
147
  while t.ndim > 2:
148
+ if t.ndim == 4 and t.shape[0] == 1 and t.shape[1] == 1:
149
+ t = t[0, 0]
150
+ elif t.ndim == 3 and t.shape[0] == 1:
151
+ t = t[0]
152
  else:
153
+ t = t.squeeze(0)
154
  t = t.clamp_(0.0, 1.0)
155
  out = t.detach().cpu().numpy().astype(np.float32)
156
  return np.ascontiguousarray(out)
 
225
  return nh, nw, s
226
 
227
  def _to_alpha(self, out_prob):
228
+ # Prefer library conversion if available
229
  if self._has_prob_to_mask:
230
  try:
231
  return self.core.output_prob_to_mask(out_prob, matting=True)
232
  except Exception:
233
  pass
234
  t = torch.as_tensor(out_prob).float()
235
+ # Normalize common cases to 2-D alpha
236
+ if t.ndim == 4: # [B,C,H,W]
237
+ c = 0 if t.shape[1] > 0 else None
238
+ b = 0 if t.shape[0] > 0 else None
239
+ if b is not None and c is not None:
240
+ return t[b, c]
241
+ if t.ndim == 3: # [C,H,W]
242
+ return t[0] if t.shape[0] >= 1 else t.mean(0)
243
+ return t # already 2-D or degenerate -> let caller sanitize
244
 
245
  # ---- main call ----
246
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
 
259
  # dtype alignment for activations
260
  img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
261
 
262
+ # build a deeper downscale ladder to survive tight VRAM
263
  nh, nw, s = self._compute_scaled_size(H, W)
264
  scales = [(nh, nw)]
265
+ # add progressive reductions until fairly small, but not tiny
266
  if s < 1.0:
267
+ f = 0.85
268
+ cur_h, cur_w = nh, nw
269
+ for _ in range(6): # up to 8 attempts total
270
+ cur_h = max(128, int(cur_h * f))
271
+ cur_w = max(128, int(cur_w * f))
272
+ if (cur_h, cur_w) != scales[-1]:
273
+ scales.append((cur_h, cur_w))
274
+ if max(cur_h, cur_w) <= 192 or (cur_h * cur_w) <= 150_000:
275
+ break
276
 
277
  last_exc = None
278
 
 
282
  img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
283
  msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
284
 
 
285
  img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
286
  m_1hw = _to_1hw_mask(msk_in) if msk_in is not None else None # [1,H,W] or None
287
+ mask_2d = m_1hw[0].contiguous() if m_1hw is not None else None# [H,W] or None
 
288
 
289
  # inference with autocast + inference_mode
290
  with torch.inference_mode():
 
316
  out_prob = self.core.step(image=img_chw)
317
  alpha = self._to_alpha(out_prob)
318
 
319
+ # ---- SAFE UPSAMPLE PATH (always 4D -> 2D) ----
320
  if (th, tw) != (H, W):
321
+ a_b1hw = _to_b1hw_alpha(alpha, device=img_chw.device) # [1,1,th,tw]
322
+ a_b1hw = F.interpolate(a_b1hw, size=(H, W), mode="bilinear", align_corners=False) # [1,1,H,W]
323
+ alpha = a_b1hw[0, 0] # -> [H,W]
324
+ # ------------------------------------------------
325
 
326
  return _to_2d_alpha_numpy(alpha)
327