Update models/loaders/matanyone_loader.py
Browse files
models/loaders/matanyone_loader.py
CHANGED
|
@@ -217,6 +217,7 @@ def _to_alpha(self, out_prob):
|
|
| 217 |
if t.ndim == 3:
|
| 218 |
return t[0] if t.shape[0] >= 1 else t.mean(0)
|
| 219 |
return t
|
|
|
|
| 220 |
def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
| 221 |
"""
|
| 222 |
Returns a 2-D float32 alpha [H,W]. On first call, provide a coarse mask.
|
|
@@ -306,7 +307,6 @@ def __exit__(self, *args): return False
|
|
| 306 |
if msk_b1hw is not None:
|
| 307 |
return _to_2d_alpha_numpy(msk_b1hw)
|
| 308 |
return np.full((H, W), 0.5, dtype=np.float32)
|
| 309 |
-
|
| 310 |
# -------------------------------- Loader ---------------------------------- #
|
| 311 |
|
| 312 |
def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
|
|
@@ -431,6 +431,7 @@ def get_info(self) -> Dict[str, Any]:
|
|
| 431 |
|
| 432 |
def debug_shapes(self, image, mask, tag: str = ""):
|
| 433 |
debug_shapes(tag, image, mask)
|
|
|
|
| 434 |
# -------------------------- Optional: Module-level symbols --------------------------
|
| 435 |
|
| 436 |
__all__ = [
|
|
@@ -449,6 +450,7 @@ def debug_shapes(self, image, mask, tag: str = ""):
|
|
| 449 |
|
| 450 |
if __name__ == "__main__":
|
| 451 |
import sys
|
|
|
|
| 452 |
|
| 453 |
logging.basicConfig(level=logging.INFO)
|
| 454 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 217 |
if t.ndim == 3:
|
| 218 |
return t[0] if t.shape[0] >= 1 else t.mean(0)
|
| 219 |
return t
|
| 220 |
+
|
| 221 |
def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
| 222 |
"""
|
| 223 |
Returns a 2-D float32 alpha [H,W]. On first call, provide a coarse mask.
|
|
|
|
| 307 |
if msk_b1hw is not None:
|
| 308 |
return _to_2d_alpha_numpy(msk_b1hw)
|
| 309 |
return np.full((H, W), 0.5, dtype=np.float32)
|
|
|
|
| 310 |
# -------------------------------- Loader ---------------------------------- #
|
| 311 |
|
| 312 |
def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
|
|
|
|
| 431 |
|
| 432 |
def debug_shapes(self, image, mask, tag: str = ""):
|
| 433 |
debug_shapes(tag, image, mask)
|
| 434 |
+
|
| 435 |
# -------------------------- Optional: Module-level symbols --------------------------
|
| 436 |
|
| 437 |
__all__ = [
|
|
|
|
| 450 |
|
| 451 |
if __name__ == "__main__":
|
| 452 |
import sys
|
| 453 |
+
import cv2
|
| 454 |
|
| 455 |
logging.basicConfig(level=logging.INFO)
|
| 456 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|