MogensR committed on
Commit
bb9d73e
·
1 Parent(s): 50668fe

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +89 -81
models/loaders/matanyone_loader.py CHANGED
@@ -1,12 +1,12 @@
1
  #!/usr/bin/env python3
2
  """
3
- MatAnyone Model Loader (Hardened)
4
- - Prevents 5D (B,T,C,H,W) tensors from ever reaching conv2d.
5
  - Normalizes images to BCHW [B,C,H,W] and masks to B1HW [B,1,H,W].
6
- - If idx_mask=True, converts masks to integer labels (long) safely.
7
- - Tries unbatched then batched calls for maximum compatibility.
8
- - Resizes masks with 'nearest' to preserve label integrity.
9
- - Includes a debug_shapes() helper for quick diagnostics.
10
  """
11
 
12
  import os
@@ -24,9 +24,6 @@
24
  # ------------------------------- Utilities -------------------------------- #
25
 
26
  def _select_device(pref: str) -> str:
27
- """
28
- Resolve a safe device string. If CUDA not available, fall back to CPU.
29
- """
30
  pref = (pref or "").lower()
31
  if pref.startswith("cuda"):
32
  return "cuda" if torch.cuda.is_available() else "cpu"
@@ -36,7 +33,6 @@ def _select_device(pref: str) -> str:
36
 
37
 
38
  def _as_tensor_on_device(x, device: str) -> torch.Tensor:
39
- """Convert ndarray or Tensor to torch.Tensor on device."""
40
  if isinstance(x, torch.Tensor):
41
  return x.to(device)
42
  return torch.from_numpy(np.asarray(x)).to(device)
@@ -45,11 +41,7 @@ def _as_tensor_on_device(x, device: str) -> torch.Tensor:
45
  def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
46
  """
47
  Normalize input to BCHW (image) or B1HW (mask).
48
-
49
  Accepts: HWC, CHW, BCHW, BHWC, BTCHW, BTHWC, TCHW, THWC, HW.
50
- - Collapses any time/clip dimension T if present (takes t=0 if T>1).
51
- - Images returned as float32 in [0,1], shape [B,C,H,W] (C=3 or 4; C=1 expanded to 3).
52
- - Masks returned as float32 in [0,1], shape [B,1,H,W].
53
  """
54
  x = _as_tensor_on_device(x, device)
55
 
@@ -59,43 +51,39 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
59
  elif x.dtype in (torch.int16, torch.int32, torch.int64):
60
  x = x.float()
61
 
62
- # 5D: [B,T,C,H,W] or [B,T,H,W,C] -> take first frame
63
  if x.ndim == 5:
64
  B, T = x.shape[0], x.shape[1]
65
- x = x[:, 0] if T > 0 else x.squeeze(1) # -> [B,C,H,W] or [B,H,W,C]
66
 
67
  # 4D
68
  if x.ndim == 4:
69
- # If BHWC, permute to BCHW
70
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
71
- x = x.permute(0, 3, 1, 2).contiguous()
72
 
73
  # 3D
74
  elif x.ndim == 3:
75
- # HWC -> CHW
76
- if x.shape[-1] in (1, 3, 4):
77
  x = x.permute(2, 0, 1).contiguous()
78
- x = x.unsqueeze(0) # -> BCHW
79
 
80
  # 2D
81
  elif x.ndim == 2:
82
  if is_mask:
83
- x = x.unsqueeze(0).unsqueeze(0) # -> B1HW
84
  else:
85
- x = x.unsqueeze(0).unsqueeze(0) # 1,1,H,W
86
- x = x.repeat(1, 3, 1, 1) # 1,3,H,W
87
 
88
  else:
89
  raise ValueError(f"Unsupported tensor ndim={x.ndim} for normalization")
90
 
91
- # Now x should be BCHW
92
  if is_mask:
93
- # Ensure single-channel
94
  if x.shape[1] > 1:
95
  x = x[:, :1]
96
  x = x.clamp_(0.0, 1.0).to(torch.float32)
97
  else:
98
- # Ensure reasonable channels
99
  C = x.shape[1]
100
  if C == 1:
101
  x = x.repeat(1, 3, 1, 1)
@@ -107,9 +95,7 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
107
 
108
 
109
  def _resize_mask_to(img_bchw: torch.Tensor, mask_b1hw: torch.Tensor) -> torch.Tensor:
110
- """
111
- Ensure mask spatial dims match image. Use NEAREST to keep labels crisp.
112
- """
113
  if img_bchw.shape[-2:] == mask_b1hw.shape[-2:]:
114
  return mask_b1hw
115
  import torch.nn.functional as F
@@ -117,9 +103,7 @@ def _resize_mask_to(img_bchw: torch.Tensor, mask_b1hw: torch.Tensor) -> torch.Te
117
 
118
 
119
  def debug_shapes(tag: str, image, mask) -> None:
120
- """
121
- Quick diagnostics: logs shape/dtype/min/max for image/mask.
122
- """
123
  def _info(name, t):
124
  try:
125
  tt = torch.as_tensor(t)
@@ -129,17 +113,69 @@ def _info(name, t):
129
  f"min={mn:.4f} max={mx:.4f}")
130
  except Exception as e:
131
  logger.info(f"[{tag}:{name}] type={type(t)} err={e}")
132
-
133
  _info("image", image)
134
  _info("mask", mask)
135
 
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  # --------------------------- Boundary Wrapper ------------------------------ #
138
 
139
  class _MatAnyoneWrapper:
140
  """
141
  Thin, defensive wrapper around the MatAnyone InferenceCore.
142
- Normalizes inputs at the boundary so the core never sees >4D tensors.
143
  """
144
 
145
  def __init__(self, core: Any, device: str):
@@ -153,12 +189,6 @@ def __init__(self, core: Any, device: str):
153
  except Exception as e:
154
  logger.debug(f"MatAnyone core .to({self.device}) not applied: {e}")
155
 
156
- @staticmethod
157
- def _to_numpy(x):
158
- if isinstance(x, torch.Tensor):
159
- return x.detach().cpu().numpy()
160
- return np.asarray(x)
161
-
162
  def _normalize_pair(
163
  self, image, mask, idx_mask: bool
164
  ) -> Tuple[torch.Tensor, torch.Tensor, bool]:
@@ -169,23 +199,21 @@ def _normalize_pair(
169
 
170
  def __call__(self, image, mask, idx_mask: bool = False, **kwargs):
171
  """
172
- Preferred entry: handles normalization and robust call patterns.
173
  """
174
  img_bchw, msk_b1hw, idx_mask = self._normalize_pair(image, mask, idx_mask)
175
 
176
- # Special handling for idx_mask: convert to integer label map.
177
  if idx_mask:
178
- # Threshold -> {0,1} long; squeeze channel
179
  m_bhw = (msk_b1hw > 0.5).long()[:, 0] # [B,H,W]
180
- # Try unbatched first if B==1
181
  if img_bchw.shape[0] == 1:
182
  img_chw = img_bchw[0] # [C,H,W]
183
  m_hw = m_bhw[0] # [H,W]
184
- # Prefer step(image, mask, idx_mask=True)
185
  try:
186
  if hasattr(self.core, "step"):
187
  out = self.core.step(image=img_chw, mask=m_hw, idx_mask=True, **kwargs)
188
- return self._to_numpy(out)
189
  except Exception as e_unbatched_idx:
190
  logger.debug(f"MatAnyone unbatched idx_mask step() failed: {e_unbatched_idx}")
191
  # Batched fallback
@@ -194,21 +222,20 @@ def __call__(self, image, mask, idx_mask: bool = False, **kwargs):
194
  if hasattr(self.core, method_name):
195
  method = getattr(self.core, method_name)
196
  out = method(image=img_bchw, mask=m_bhw, idx_mask=True, **kwargs)
197
- return self._to_numpy(out)
198
  except Exception as e_batched_idx:
199
  logger.debug(f"MatAnyone {method_name} idx_mask batched call failed: {e_batched_idx}")
200
 
201
  logger.warning("MatAnyone idx_mask calls failed; returning integer mask as fallback.")
202
- return self._to_numpy(m_bhw if m_bhw.shape[0] > 1 else m_bhw[0])
203
 
204
- # Non-index soft/binary mask path
205
  try:
206
- # Try unbatched first (common CHW / 1HW)
207
  if hasattr(self.core, "step") and img_bchw.shape[0] == 1:
208
  img_chw = img_bchw[0] # [C,H,W]
209
  m_1hw = msk_b1hw[0] # [1,H,W]
210
  out = self.core.step(image=img_chw, mask=m_1hw, idx_mask=False, **kwargs)
211
- return self._to_numpy(out)
212
  except Exception as e_unbatched:
213
  logger.debug(f"MatAnyone unbatched step() failed: {e_unbatched}")
214
 
@@ -218,13 +245,13 @@ def __call__(self, image, mask, idx_mask: bool = False, **kwargs):
218
  if hasattr(self.core, method_name):
219
  method = getattr(self.core, method_name)
220
  out = method(image=img_bchw, mask=msk_b1hw, idx_mask=False, **kwargs)
221
- return self._to_numpy(out)
222
  except Exception as e_batched:
223
  logger.debug(f"MatAnyone {method_name} batched call failed: {e_batched}")
224
 
225
  logger.warning("MatAnyone calls failed; returning input mask as fallback.")
226
- return self._to_numpy(msk_b1hw.squeeze(1)) # [B,H,W] or [H,W] if squeezed
227
-
228
 
229
  # ------------------------------- Loader ----------------------------------- #
230
 
@@ -243,8 +270,7 @@ def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyo
243
  def load(self) -> Optional[Any]:
244
  """
245
  Load MatAnyone model and return a callable wrapper.
246
- Returns:
247
- _MatAnyoneWrapper or None
248
  """
249
  logger.info(f"Loading MatAnyone model: {self.model_id} (device={self.device})")
250
 
@@ -272,9 +298,7 @@ def load(self) -> Optional[Any]:
272
  return None
273
 
274
  def _load_official(self) -> Optional[Any]:
275
- """
276
- Load using the official MatAnyone API and wrap with boundary normalizer.
277
- """
278
  try:
279
  from matanyone import InferenceCore # type: ignore
280
  except Exception as e:
@@ -287,28 +311,15 @@ def _load_official(self) -> Optional[Any]:
287
 
288
  def _load_fallback(self) -> Optional[Any]:
289
  """Create a minimal fallback that smooths/returns the mask."""
290
-
291
  class _FallbackCore:
292
  def step(self, image, mask, idx_mask: bool = False, **kwargs):
293
- # Convert mask to numpy
294
- if isinstance(mask, torch.Tensor):
295
- mask_np = mask.detach().cpu().numpy()
296
- else:
297
- mask_np = np.asarray(mask)
298
  try:
299
  import cv2
300
- if mask_np.ndim == 2:
301
- return cv2.GaussianBlur(mask_np, (5, 5), 1.0)
302
- if mask_np.ndim == 3:
303
- # Handle CHW-style smoothing (per-channel)
304
- if mask_np.shape[0] in (1, 3, 4):
305
- sm = np.empty_like(mask_np)
306
- for i in range(mask_np.shape[0]):
307
- sm[i] = cv2.GaussianBlur(mask_np[i], (5, 5), 1.0)
308
- return sm
309
- return mask_np
310
  except Exception:
311
- return mask_np
312
 
313
  def process(self, image, mask, **kwargs):
314
  return self.step(image, mask, **kwargs)
@@ -320,7 +331,6 @@ def process(self, image, mask, **kwargs):
320
  # --------------------------- Housekeeping --------------------------- #
321
 
322
  def cleanup(self):
323
- """Clean up resources."""
324
  if self.model:
325
  try:
326
  del self.model
@@ -331,7 +341,6 @@ def cleanup(self):
331
  torch.cuda.empty_cache()
332
 
333
  def get_info(self) -> Dict[str, Any]:
334
- """Get loader information."""
335
  return {
336
  "loaded": self.model is not None,
337
  "model_id": self.model_id,
@@ -340,7 +349,6 @@ def get_info(self) -> Dict[str, Any]:
340
  "model_type": type(self.model).__name__ if self.model else None,
341
  }
342
 
343
- # Optional: instance-level shape debugging hook
344
  def debug_shapes(self, image, mask, tag: str = ""):
345
  debug_shapes(tag, image, mask)
346
-
 
1
  #!/usr/bin/env python3
2
  """
3
+ MatAnyone Model Loader (Hardened v2)
4
+ - Prevents 5D (B,T,C,H,W) tensors from reaching conv2d.
5
  - Normalizes images to BCHW [B,C,H,W] and masks to B1HW [B,1,H,W].
6
+ - idx_mask=True -> integer label map, but final output still a 2-D [H,W] mask for OpenCV.
7
+ - ALWAYS returns a 2-D, contiguous, float32 mask [H,W] to downstream code.
8
+ - Tries unbatched then batched calls; resizes masks with NEAREST to preserve labels.
9
+ - Includes debug_shapes() for quick diagnostics.
10
  """
11
 
12
  import os
 
24
  # ------------------------------- Utilities -------------------------------- #
25
 
26
  def _select_device(pref: str) -> str:
 
 
 
27
  pref = (pref or "").lower()
28
  if pref.startswith("cuda"):
29
  return "cuda" if torch.cuda.is_available() else "cpu"
 
33
 
34
 
35
  def _as_tensor_on_device(x, device: str) -> torch.Tensor:
 
36
  if isinstance(x, torch.Tensor):
37
  return x.to(device)
38
  return torch.from_numpy(np.asarray(x)).to(device)
 
41
  def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
42
  """
43
  Normalize input to BCHW (image) or B1HW (mask).
 
44
  Accepts: HWC, CHW, BCHW, BHWC, BTCHW, BTHWC, TCHW, THWC, HW.
 
 
 
45
  """
46
  x = _as_tensor_on_device(x, device)
47
 
 
51
  elif x.dtype in (torch.int16, torch.int32, torch.int64):
52
  x = x.float()
53
 
54
+ # 5D: [B,T,C,H,W] or [B,T,H,W,C] -> take first frame
55
  if x.ndim == 5:
56
  B, T = x.shape[0], x.shape[1]
57
+ x = x[:, 0] if T > 0 else x.squeeze(1)
58
 
59
  # 4D
60
  if x.ndim == 4:
 
61
  if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
62
+ x = x.permute(0, 3, 1, 2).contiguous() # BHWC -> BCHW
63
 
64
  # 3D
65
  elif x.ndim == 3:
66
+ if x.shape[-1] in (1, 3, 4): # HWC -> CHW
 
67
  x = x.permute(2, 0, 1).contiguous()
68
+ x = x.unsqueeze(0) # -> BCHW
69
 
70
  # 2D
71
  elif x.ndim == 2:
72
  if is_mask:
73
+ x = x.unsqueeze(0).unsqueeze(0) # -> B1HW
74
  else:
75
+ x = x.unsqueeze(0).unsqueeze(0) # 1,1,H,W
76
+ x = x.repeat(1, 3, 1, 1) # 1,3,H,W
77
 
78
  else:
79
  raise ValueError(f"Unsupported tensor ndim={x.ndim} for normalization")
80
 
81
+ # Finalize channels / clamp
82
  if is_mask:
 
83
  if x.shape[1] > 1:
84
  x = x[:, :1]
85
  x = x.clamp_(0.0, 1.0).to(torch.float32)
86
  else:
 
87
  C = x.shape[1]
88
  if C == 1:
89
  x = x.repeat(1, 3, 1, 1)
 
95
 
96
 
97
  def _resize_mask_to(img_bchw: torch.Tensor, mask_b1hw: torch.Tensor) -> torch.Tensor:
98
+ """Ensure mask spatial dims match image. Use NEAREST to keep labels crisp."""
 
 
99
  if img_bchw.shape[-2:] == mask_b1hw.shape[-2:]:
100
  return mask_b1hw
101
  import torch.nn.functional as F
 
103
 
104
 
105
  def debug_shapes(tag: str, image, mask) -> None:
106
+ """Quick diagnostics: logs shape/dtype/min/max for image/mask."""
 
 
107
  def _info(name, t):
108
  try:
109
  tt = torch.as_tensor(t)
 
113
  f"min={mn:.4f} max={mx:.4f}")
114
  except Exception as e:
115
  logger.info(f"[{tag}:{name}] type={type(t)} err={e}")
 
116
  _info("image", image)
117
  _info("mask", mask)
118
 
119
 
120
+ def _to_2d_numpy_mask(x) -> np.ndarray:
121
+ """
122
+ Convert any tensor/ndarray mask to a 2-D contiguous float32 array [H,W] in [0,1].
123
+ Handles inputs like: B1HW, BCHW, 1HW, CHW, HWC, HW, etc.
124
+ """
125
+ if isinstance(x, torch.Tensor):
126
+ t = x.detach()
127
+ else:
128
+ t = torch.as_tensor(x)
129
+
130
+ # Bring to float in [0,1] if likely 0..255
131
+ if t.dtype == torch.uint8:
132
+ t = t.float().div_(255.0)
133
+ elif t.dtype in (torch.int16, torch.int32, torch.int64):
134
+ t = t.float()
135
+ else:
136
+ t = t.float()
137
+
138
+ # Reduce dimensions to [H,W]
139
+ if t.ndim == 4: # e.g., [B, C, H, W]
140
+ if t.shape[0] > 1:
141
+ t = t[0]
142
+ # now [C,H,W]
143
+ if t.shape[0] > 1: # multiple channels -> take first (or could mean)
144
+ t = t[0]
145
+ else:
146
+ t = t[0] # squeeze channel -> [H,W]
147
+ elif t.ndim == 3:
148
+ # Could be [1,H,W], [C,H,W], or [H,W,1]
149
+ if t.shape[0] in (1, 3, 4): # CHW/1HW
150
+ t = t[0] # -> [H,W] (first channel)
151
+ elif t.shape[-1] == 1: # HWC with single channel
152
+ t = t[..., 0] # -> [H,W]
153
+ else:
154
+ # Unknown 3D -> take first slice
155
+ t = t[0]
156
+ elif t.ndim == 2:
157
+ pass # already [H,W]
158
+ else:
159
+ # Any other: try to squeeze to 2-D
160
+ t = t.squeeze()
161
+ if t.ndim != 2:
162
+ # fallback to a tiny neutral mask
163
+ h = int(t.shape[-2]) if t.ndim >= 2 else 512
164
+ w = int(t.shape[-1]) if t.ndim >= 2 else 512
165
+ t = torch.full((h, w), 0.5, dtype=torch.float32)
166
+
167
+ # Clamp and convert to contiguous numpy
168
+ t = t.clamp_(0.0, 1.0)
169
+ m = t.cpu().numpy().astype(np.float32)
170
+ return np.ascontiguousarray(m)
171
+
172
+
173
  # --------------------------- Boundary Wrapper ------------------------------ #
174
 
175
  class _MatAnyoneWrapper:
176
  """
177
  Thin, defensive wrapper around the MatAnyone InferenceCore.
178
+ Normalizes inputs at the boundary and always outputs a 2-D mask for OpenCV.
179
  """
180
 
181
  def __init__(self, core: Any, device: str):
 
189
  except Exception as e:
190
  logger.debug(f"MatAnyone core .to({self.device}) not applied: {e}")
191
 
 
 
 
 
 
 
192
  def _normalize_pair(
193
  self, image, mask, idx_mask: bool
194
  ) -> Tuple[torch.Tensor, torch.Tensor, bool]:
 
199
 
200
  def __call__(self, image, mask, idx_mask: bool = False, **kwargs):
201
  """
202
+ Entry point: returns a 2-D float32 mask [H,W] for downstream OpenCV.
203
  """
204
  img_bchw, msk_b1hw, idx_mask = self._normalize_pair(image, mask, idx_mask)
205
 
206
+ # idx_mask path -> integer labels; still output 2-D for downstream
207
  if idx_mask:
 
208
  m_bhw = (msk_b1hw > 0.5).long()[:, 0] # [B,H,W]
209
+ # Try unbatched if B==1
210
  if img_bchw.shape[0] == 1:
211
  img_chw = img_bchw[0] # [C,H,W]
212
  m_hw = m_bhw[0] # [H,W]
 
213
  try:
214
  if hasattr(self.core, "step"):
215
  out = self.core.step(image=img_chw, mask=m_hw, idx_mask=True, **kwargs)
216
+ return _to_2d_numpy_mask(out)
217
  except Exception as e_unbatched_idx:
218
  logger.debug(f"MatAnyone unbatched idx_mask step() failed: {e_unbatched_idx}")
219
  # Batched fallback
 
222
  if hasattr(self.core, method_name):
223
  method = getattr(self.core, method_name)
224
  out = method(image=img_bchw, mask=m_bhw, idx_mask=True, **kwargs)
225
+ return _to_2d_numpy_mask(out)
226
  except Exception as e_batched_idx:
227
  logger.debug(f"MatAnyone {method_name} idx_mask batched call failed: {e_batched_idx}")
228
 
229
  logger.warning("MatAnyone idx_mask calls failed; returning integer mask as fallback.")
230
+ return _to_2d_numpy_mask(m_bhw)
231
 
232
+ # Non-index mask path (soft/binary)
233
  try:
 
234
  if hasattr(self.core, "step") and img_bchw.shape[0] == 1:
235
  img_chw = img_bchw[0] # [C,H,W]
236
  m_1hw = msk_b1hw[0] # [1,H,W]
237
  out = self.core.step(image=img_chw, mask=m_1hw, idx_mask=False, **kwargs)
238
+ return _to_2d_numpy_mask(out)
239
  except Exception as e_unbatched:
240
  logger.debug(f"MatAnyone unbatched step() failed: {e_unbatched}")
241
 
 
245
  if hasattr(self.core, method_name):
246
  method = getattr(self.core, method_name)
247
  out = method(image=img_bchw, mask=msk_b1hw, idx_mask=False, **kwargs)
248
+ return _to_2d_numpy_mask(out)
249
  except Exception as e_batched:
250
  logger.debug(f"MatAnyone {method_name} batched call failed: {e_batched}")
251
 
252
  logger.warning("MatAnyone calls failed; returning input mask as fallback.")
253
+ # Return a valid 2-D mask even on total failure
254
+ return _to_2d_numpy_mask(msk_b1hw)
255
 
256
  # ------------------------------- Loader ----------------------------------- #
257
 
 
270
  def load(self) -> Optional[Any]:
271
  """
272
  Load MatAnyone model and return a callable wrapper.
273
+ Returns: _MatAnyoneWrapper or None
 
274
  """
275
  logger.info(f"Loading MatAnyone model: {self.model_id} (device={self.device})")
276
 
 
298
  return None
299
 
300
  def _load_official(self) -> Optional[Any]:
301
+ """Load using the official MatAnyone API and wrap with boundary normalizer."""
 
 
302
  try:
303
  from matanyone import InferenceCore # type: ignore
304
  except Exception as e:
 
311
 
312
  def _load_fallback(self) -> Optional[Any]:
313
  """Create a minimal fallback that smooths/returns the mask."""
 
314
  class _FallbackCore:
315
  def step(self, image, mask, idx_mask: bool = False, **kwargs):
316
+ # Convert to 2-D numpy mask as final step
317
+ m2d = _to_2d_numpy_mask(mask)
 
 
 
318
  try:
319
  import cv2
320
+ return cv2.GaussianBlur(m2d, (5, 5), 1.0)
 
 
 
 
 
 
 
 
 
321
  except Exception:
322
+ return m2d
323
 
324
  def process(self, image, mask, **kwargs):
325
  return self.step(image, mask, **kwargs)
 
331
  # --------------------------- Housekeeping --------------------------- #
332
 
333
  def cleanup(self):
 
334
  if self.model:
335
  try:
336
  del self.model
 
341
  torch.cuda.empty_cache()
342
 
343
  def get_info(self) -> Dict[str, Any]:
 
344
  return {
345
  "loaded": self.model is not None,
346
  "model_id": self.model_id,
 
349
  "model_type": type(self.model).__name__ if self.model else None,
350
  }
351
 
352
+ # Optional: instance-level shape debugging
353
    def debug_shapes(self, image, mask, tag: str = ""):
        # Instance-level convenience hook. Inside the method body the name
        # `debug_shapes` resolves to the module-level helper (class scope is
        # not in the lookup chain), so this delegates rather than recursing.
        debug_shapes(tag, image, mask)