MogensR commited on
Commit
47f3540
·
1 Parent(s): 874e937

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +3 -1
models/loaders/matanyone_loader.py CHANGED
@@ -217,6 +217,7 @@ def _to_alpha(self, out_prob):
217
  if t.ndim == 3:
218
  return t[0] if t.shape[0] >= 1 else t.mean(0)
219
  return t
 
220
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
221
  """
222
  Returns a 2-D float32 alpha [H,W]. On first call, provide a coarse mask.
@@ -306,7 +307,6 @@ def __exit__(self, *args): return False
306
  if msk_b1hw is not None:
307
  return _to_2d_alpha_numpy(msk_b1hw)
308
  return np.full((H, W), 0.5, dtype=np.float32)
309
-
310
  # -------------------------------- Loader ---------------------------------- #
311
 
312
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
@@ -431,6 +431,7 @@ def get_info(self) -> Dict[str, Any]:
431
 
432
  def debug_shapes(self, image, mask, tag: str = ""):
433
  debug_shapes(tag, image, mask)
 
434
  # -------------------------- Optional: Module-level symbols --------------------------
435
 
436
  __all__ = [
@@ -449,6 +450,7 @@ def debug_shapes(self, image, mask, tag: str = ""):
449
 
450
  if __name__ == "__main__":
451
  import sys
 
452
 
453
  logging.basicConfig(level=logging.INFO)
454
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
217
  if t.ndim == 3:
218
  return t[0] if t.shape[0] >= 1 else t.mean(0)
219
  return t
220
+
221
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
222
  """
223
  Returns a 2-D float32 alpha [H,W]. On first call, provide a coarse mask.
 
307
  if msk_b1hw is not None:
308
  return _to_2d_alpha_numpy(msk_b1hw)
309
  return np.full((H, W), 0.5, dtype=np.float32)
 
310
  # -------------------------------- Loader ---------------------------------- #
311
 
312
  def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
 
431
 
432
  def debug_shapes(self, image, mask, tag: str = ""):
433
  debug_shapes(tag, image, mask)
434
+
435
  # -------------------------- Optional: Module-level symbols --------------------------
436
 
437
  __all__ = [
 
450
 
451
  if __name__ == "__main__":
452
  import sys
453
+ import cv2
454
 
455
  logging.basicConfig(level=logging.INFO)
456
  device = "cuda" if torch.cuda.is_available() else "cpu"