Update models/loaders/matanyone_loader.py
Browse files- models/loaders/matanyone_loader.py +187 -126
models/loaders/matanyone_loader.py
CHANGED
|
@@ -1,14 +1,18 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
# -*- coding: utf-8 -*-
|
| 3 |
"""
|
| 4 |
-
MatAnyone Loader - Stable Callable Wrapper for InferenceCore
|
| 5 |
-
===========================================================
|
| 6 |
-
|
| 7 |
-
-
|
| 8 |
-
-
|
| 9 |
-
-
|
| 10 |
-
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import os
|
|
@@ -32,92 +36,170 @@ def _to_float01_np(arr: np.ndarray) -> np.ndarray:
|
|
| 32 |
if arr.dtype == np.uint8:
|
| 33 |
arr = arr.astype(np.float32) / 255.0
|
| 34 |
else:
|
| 35 |
-
arr = arr.astype(np.float32)
|
| 36 |
-
# Clamp for safety
|
| 37 |
np.clip(arr, 0.0, 1.0, out=arr)
|
| 38 |
return arr
|
| 39 |
|
| 40 |
|
| 41 |
-
def
|
| 42 |
"""
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
| 45 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
if torch.is_tensor(image):
|
| 47 |
t = image
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
t = t.unsqueeze(0)
|
| 54 |
else:
|
| 55 |
-
raise ValueError(f"
|
|
|
|
| 56 |
t = t.to(dtype=torch.float32)
|
| 57 |
-
# If likely 0-255, scale; otherwise clamp to [0,1]
|
| 58 |
if torch.max(t) > 1.5:
|
| 59 |
t = t / 255.0
|
| 60 |
t = torch.clamp(t, 0.0, 1.0)
|
|
|
|
| 61 |
return t
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
| 69 |
pass
|
|
|
|
|
|
|
| 70 |
else:
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
-
def _ensure_1hw_float01(mask: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
|
| 77 |
"""
|
| 78 |
-
Convert mask to torch.FloatTensor 1HW in [0,1].
|
| 79 |
-
Accepts
|
| 80 |
"""
|
|
|
|
|
|
|
|
|
|
| 81 |
if torch.is_tensor(mask):
|
| 82 |
m = mask
|
| 83 |
-
if m.ndim ==
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
m = m.permute(2, 0, 1)
|
| 90 |
else:
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
else:
|
| 93 |
-
raise ValueError(f"
|
|
|
|
| 94 |
m = m.to(dtype=torch.float32)
|
| 95 |
if torch.max(m) > 1.5:
|
| 96 |
m = m / 255.0
|
| 97 |
m = torch.clamp(m, 0.0, 1.0)
|
|
|
|
| 98 |
return m
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
arr = arr.transpose(2, 0, 1)
|
| 108 |
-
else:
|
| 109 |
-
raise ValueError(f"Mask has too many channels: {arr.shape}")
|
| 110 |
else:
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
|
| 117 |
-
"""
|
| 118 |
-
Extract a 2D alpha (H,W) float32 [0,1] from a variety of possible outputs.
|
| 119 |
-
Accepts numpy/tensor with shapes: HW, 1HW, CHW(C>=1), BHWC, BCHW, etc.
|
| 120 |
-
"""
|
| 121 |
if result is None:
|
| 122 |
return np.full((512, 512), 0.5, dtype=np.float32)
|
| 123 |
|
|
@@ -125,27 +207,23 @@ def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
|
|
| 125 |
result = result.detach().float().cpu()
|
| 126 |
|
| 127 |
arr = np.asarray(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
if arr.ndim == 2:
|
| 129 |
alpha = arr
|
| 130 |
elif arr.ndim == 3:
|
| 131 |
-
|
| 132 |
-
if arr.shape[0] in (1, 3, 4): # CHW
|
| 133 |
alpha = arr[0]
|
| 134 |
-
elif arr.shape[-1] in (1, 3, 4):
|
| 135 |
alpha = arr[..., 0]
|
| 136 |
else:
|
| 137 |
-
|
| 138 |
-
alpha = arr[0]
|
| 139 |
-
elif arr.ndim == 4:
|
| 140 |
-
# Batch first: BxCxHxW or BxHxWxC
|
| 141 |
-
if arr.shape[1] in (1, 3, 4): # BCHW
|
| 142 |
-
alpha = arr[0, 0]
|
| 143 |
-
elif arr.shape[-1] in (1, 3, 4): # BHWC
|
| 144 |
-
alpha = arr[0, ..., 0]
|
| 145 |
-
else:
|
| 146 |
-
alpha = arr[0, 0]
|
| 147 |
else:
|
| 148 |
-
#
|
| 149 |
alpha = np.full((512, 512), 0.5, dtype=np.float32)
|
| 150 |
|
| 151 |
alpha = alpha.astype(np.float32, copy=False)
|
|
@@ -154,41 +232,29 @@ def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
|
|
| 154 |
|
| 155 |
|
| 156 |
def _hw_from_image_like(x: Union[np.ndarray, torch.Tensor]) -> Tuple[int, int]:
|
| 157 |
-
"""Best-effort
|
| 158 |
if torch.is_tensor(x):
|
| 159 |
shape = tuple(x.shape)
|
| 160 |
-
# Handle CHW / HWC / BCHW / BHWC / HW
|
| 161 |
-
if len(shape) == 2: # HW
|
| 162 |
-
return shape[0], shape[1]
|
| 163 |
-
if len(shape) == 3:
|
| 164 |
-
if shape[0] in (1, 3, 4): # CHW
|
| 165 |
-
return shape[1], shape[2]
|
| 166 |
-
if shape[-1] in (1, 3, 4): # HWC
|
| 167 |
-
return shape[0], shape[1]
|
| 168 |
-
if len(shape) == 4:
|
| 169 |
-
# Assume batch first
|
| 170 |
-
b, c_or_h, h_or_w, maybe_w = shape
|
| 171 |
-
# Try BCHW
|
| 172 |
-
if shape[1] in (1, 3, 4):
|
| 173 |
-
return shape[2], shape[3]
|
| 174 |
-
# Try BHWC
|
| 175 |
-
return shape[1], shape[2]
|
| 176 |
-
return 512, 512
|
| 177 |
else:
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
if
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
|
| 194 |
# --------------------------- Callable Wrapper ---------------------------
|
|
@@ -201,6 +267,7 @@ class MatAnyoneCallableWrapper:
|
|
| 201 |
- First call SHOULD include a mask (1HW). If not, returns neutral 0.5 alpha.
|
| 202 |
- Subsequent calls do not require mask.
|
| 203 |
- Returns 2D alpha (H,W) float32 in [0,1].
|
|
|
|
| 204 |
"""
|
| 205 |
|
| 206 |
def __init__(self, inference_core, device: str = "cuda", mixed_precision: Optional[str] = "fp16"):
|
|
@@ -213,7 +280,7 @@ def _maybe_autocast(self):
|
|
| 213 |
if self.device == "cuda" and self.mixed_precision in ("fp16", "bf16"):
|
| 214 |
dtype = torch.float16 if self.mixed_precision == "fp16" else torch.bfloat16
|
| 215 |
return torch.autocast(device_type="cuda", dtype=dtype)
|
| 216 |
-
# no-op
|
| 217 |
class _NullCtx:
|
| 218 |
def __enter__(self): return None
|
| 219 |
def __exit__(self, *exc): return False
|
|
@@ -221,9 +288,8 @@ def __exit__(self, *exc): return False
|
|
| 221 |
|
| 222 |
def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
| 223 |
try:
|
| 224 |
-
# Preprocess
|
| 225 |
-
img_chw = _ensure_chw_float01(image).to(self.device, non_blocking=True)
|
| 226 |
-
img_bchw = img_chw.unsqueeze(0) # B=1
|
| 227 |
|
| 228 |
if not self.initialized:
|
| 229 |
if mask is None:
|
|
@@ -231,15 +297,14 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
|
| 231 |
logger.warning("MatAnyone first frame called without mask; returning neutral alpha.")
|
| 232 |
return np.full((h, w), 0.5, dtype=np.float32)
|
| 233 |
|
| 234 |
-
m_1hw = _ensure_1hw_float01(mask).to(self.device, non_blocking=True)
|
| 235 |
-
m_b1hw = m_1hw.unsqueeze(0) # B=1
|
| 236 |
|
| 237 |
with torch.inference_mode():
|
| 238 |
with self._maybe_autocast():
|
| 239 |
if hasattr(self.core, "step"):
|
| 240 |
-
result = self.core.step(image=
|
| 241 |
elif hasattr(self.core, "process_frame"):
|
| 242 |
-
result = self.core.process_frame(
|
| 243 |
else:
|
| 244 |
logger.warning("InferenceCore has no recognized frame API; echoing input mask.")
|
| 245 |
return _alpha_from_result(mask)
|
|
@@ -251,9 +316,9 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
|
| 251 |
with torch.inference_mode():
|
| 252 |
with self._maybe_autocast():
|
| 253 |
if hasattr(self.core, "step"):
|
| 254 |
-
result = self.core.step(image=
|
| 255 |
elif hasattr(self.core, "process_frame"):
|
| 256 |
-
result = self.core.process_frame(
|
| 257 |
else:
|
| 258 |
h, w = _hw_from_image_like(image)
|
| 259 |
logger.warning("InferenceCore has no recognized frame API on subsequent call; returning neutral alpha.")
|
|
@@ -297,7 +362,7 @@ class MatAnyoneLoader:
|
|
| 297 |
Usage:
|
| 298 |
loader = MatAnyoneLoader(device="cuda")
|
| 299 |
session = loader.load() # callable
|
| 300 |
-
alpha = session(frame, first_frame_mask) #
|
| 301 |
"""
|
| 302 |
|
| 303 |
def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache",
|
|
@@ -346,13 +411,9 @@ def _try_build_core(self):
|
|
| 346 |
logger.debug(f"ctor(model_id, device, cache_dir) failed: {e}")
|
| 347 |
|
| 348 |
# 3) Minimal ctor
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
return core
|
| 353 |
-
except Exception as e:
|
| 354 |
-
logger.debug(f"ctor(model_id) failed: {e}")
|
| 355 |
-
raise # Propagate last error
|
| 356 |
|
| 357 |
def load(self) -> Optional[MatAnyoneCallableWrapper]:
|
| 358 |
"""Load MatAnyone and return the callable wrapper."""
|
|
@@ -364,7 +425,7 @@ def load(self) -> Optional[MatAnyoneCallableWrapper]:
|
|
| 364 |
|
| 365 |
try:
|
| 366 |
self.processor = self._try_build_core()
|
| 367 |
-
#
|
| 368 |
try:
|
| 369 |
if hasattr(self.processor, "to"):
|
| 370 |
self.processor.to(self.device)
|
|
@@ -445,7 +506,7 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
|
| 445 |
return self.wrapper(image, mask, **kwargs)
|
| 446 |
|
| 447 |
|
| 448 |
-
# Backwards compatibility alias
|
| 449 |
_MatAnyoneSession = MatAnyoneCallableWrapper
|
| 450 |
|
| 451 |
__all__ = ["MatAnyoneLoader", "_MatAnyoneSession", "MatAnyoneCallableWrapper"]
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
# -*- coding: utf-8 -*-
|
| 3 |
"""
|
| 4 |
+
MatAnyone Loader - Stable Callable Wrapper for InferenceCore (extra-dim stripping)
|
| 5 |
+
=================================================================================
|
| 6 |
+
|
| 7 |
+
- Always call InferenceCore UNBATCHED:
|
| 8 |
+
image -> CHW float32 [0,1]
|
| 9 |
+
mask -> 1HW float32 [0,1]
|
| 10 |
+
- Aggressively strip extra dims:
|
| 11 |
+
e.g. [B,T,C,H,W] -> [C,H,W] (use first slice when B/T > 1 with a warning)
|
| 12 |
+
e.g. [B,C,H,W] -> [C,H,W]
|
| 13 |
+
e.g. [H,W,C,1] -> [H,W,C]
|
| 14 |
+
- Optional CUDA mixed precision (fp16/bf16)
|
| 15 |
+
- Robust alpha extraction -> (H,W) float32 [0,1]
|
| 16 |
"""
|
| 17 |
|
| 18 |
import os
|
|
|
|
| 36 |
if arr.dtype == np.uint8:
|
| 37 |
arr = arr.astype(np.float32) / 255.0
|
| 38 |
else:
|
| 39 |
+
arr = arr.astype(np.float32, copy=False)
|
|
|
|
| 40 |
np.clip(arr, 0.0, 1.0, out=arr)
|
| 41 |
return arr
|
| 42 |
|
| 43 |
|
| 44 |
+
def _strip_leading_extras_to_ndim(x: Union[np.ndarray, torch.Tensor], target_ndim: int) -> Union[np.ndarray, torch.Tensor]:
|
| 45 |
"""
|
| 46 |
+
Reduce x to at most target_ndim by removing leading dims.
|
| 47 |
+
- If a leading dim == 1, squeeze it.
|
| 48 |
+
- If a leading dim > 1, take the first slice and log a warning.
|
| 49 |
+
Repeat until ndim <= target_ndim.
|
| 50 |
"""
|
| 51 |
+
is_tensor = torch.is_tensor(x)
|
| 52 |
+
get_shape = (lambda t: tuple(t.shape)) if is_tensor else (lambda a: a.shape)
|
| 53 |
+
index_first = (lambda t: t[0]) if is_tensor else (lambda a: a[0])
|
| 54 |
+
squeeze_first = (lambda t: t.squeeze(0)) if is_tensor else (lambda a: np.squeeze(a, axis=0))
|
| 55 |
+
|
| 56 |
+
while len(get_shape(x)) > target_ndim:
|
| 57 |
+
dim0 = get_shape(x)[0]
|
| 58 |
+
if dim0 == 1:
|
| 59 |
+
x = squeeze_first(x)
|
| 60 |
+
else:
|
| 61 |
+
logger.warning(f"Input has extra leading dim >1 ({dim0}); taking the first slice.")
|
| 62 |
+
x = index_first(x)
|
| 63 |
+
return x
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _ensure_chw_float01(image: Union[np.ndarray, torch.Tensor], *, name: str = "image") -> torch.Tensor:
    """
    Convert ``image`` to a torch.FloatTensor CHW in [0, 1], stripping extra dims.

    Accepts shapes up to 5D (e.g. B,T,C,H,W / B,C,H,W / H,W,C / CHW / HW).
    Ambiguous multi-channel layouts fall back to the first channel with a warning.

    Raises:
        ValueError: if the input cannot be reduced to a 2D/3D image layout.
    """
    # Both numpy arrays and torch tensors expose .shape, so one expression
    # suffices (the original had an if/else whose branches were identical).
    orig_shape = tuple(image.shape)
    # Collapse leading batch/time dims down to at most 3 dims (CHW/HWC/HW).
    # After this call ndim <= 3, so no second stripping pass is needed.
    image = _strip_leading_extras_to_ndim(image, 3)

    if torch.is_tensor(image):
        t = image
        if t.ndim == 3:
            c0, _, c2 = t.shape
            if c0 in (1, 3, 4):
                pass  # already CHW
            elif c2 in (1, 3, 4):
                t = t.permute(2, 0, 1)  # HWC -> CHW
            else:
                # Ambiguous layout: assume HWC-like, then keep only channel 0.
                logger.warning(f"{name}: ambiguous 3D shape {tuple(t.shape)}; attempting HWC->CHW then selecting first channel.")
                t = t.permute(2, 0, 1)
                if t.shape[0] > 1:
                    t = t[0]
                t = t.unsqueeze(0)  # back to 1HW
        elif t.ndim == 2:
            t = t.unsqueeze(0)  # HW -> 1HW
        else:
            raise ValueError(f"{name}: unsupported tensor dims {tuple(t.shape)} after stripping.")

        t = t.to(dtype=torch.float32)
        if torch.max(t) > 1.5:
            # Heuristic: values above 1.5 indicate a 0-255 range.
            t = t / 255.0
        t = torch.clamp(t, 0.0, 1.0)
        logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (CHW)")
        return t

    # numpy path
    arr = np.asarray(image)
    if arr.ndim == 3:
        if arr.shape[0] in (1, 3, 4):  # CHW
            pass
        elif arr.shape[-1] in (1, 3, 4):  # HWC -> CHW
            arr = arr.transpose(2, 0, 1)
        else:
            logger.warning(f"{name}: ambiguous 3D shape {arr.shape}; trying HWC->CHW and selecting first channel.")
            arr = arr.transpose(2, 0, 1)  # HWC -> CHW
            if arr.shape[0] > 1:
                arr = arr[0:1, ...]  # 1HW
    elif arr.ndim == 2:
        arr = arr[None, ...]  # HW -> 1HW
    else:
        raise ValueError(f"{name}: unsupported numpy dims {arr.shape} after stripping.")

    arr = _to_float01_np(arr)
    t = torch.from_numpy(arr)
    logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (CHW)")
    return t
|
| 134 |
|
| 135 |
|
| 136 |
+
def _ensure_1hw_float01(mask: Union[np.ndarray, torch.Tensor], *, name: str = "mask") -> torch.Tensor:
    """
    Convert ``mask`` to a torch.FloatTensor 1HW in [0, 1], stripping extra dims.

    Accepts up to 4D inputs; collapses leading batch/time dims and picks the
    first slice/channel when the layout is multi-channel or ambiguous.

    Raises:
        ValueError: if the input cannot be reduced to a 1HW layout.
    """
    # Both numpy arrays and torch tensors expose .shape, so one expression
    # suffices (the original had an if/else whose branches were identical).
    orig_shape = tuple(mask.shape)
    mask = _strip_leading_extras_to_ndim(mask, 3)

    if torch.is_tensor(mask):
        m = mask
        if m.ndim == 3:
            # 1HW or CHW or HWC-like
            if m.shape[0] == 1:
                pass  # already 1HW
            elif m.shape[-1] == 1:
                m = m.permute(2, 0, 1)  # HW1 -> 1HW
            else:
                # Multi-channel: keep only the first channel.
                logger.warning(f"{name}: multi-channel {tuple(m.shape)}; using first channel.")
                if m.shape[0] in (3, 4):  # CHW
                    m = m[0:1, ...]
                elif m.shape[-1] in (3, 4):  # HWC
                    m = m.permute(2, 0, 1)[0:1, ...]
                else:
                    # Ambiguous -> slice along the first axis to get 1HW.
                    m = m[0:1, ...]
        elif m.ndim == 2:
            m = m.unsqueeze(0)  # HW -> 1HW
        else:
            raise ValueError(f"{name}: unsupported tensor dims {tuple(m.shape)} after stripping.")

        m = m.to(dtype=torch.float32)
        if torch.max(m) > 1.5:
            # Heuristic: values above 1.5 indicate a 0-255 range.
            m = m / 255.0
        m = torch.clamp(m, 0.0, 1.0)
        logger.debug(f"{name}: {orig_shape} -> {tuple(m.shape)} (1HW)")
        return m

    # numpy path
    arr = np.asarray(mask)
    if arr.ndim == 3:
        if arr.shape[0] == 1:
            pass  # 1HW
        elif arr.shape[-1] == 1:
            arr = arr.transpose(2, 0, 1)  # HW1 -> 1HW
        else:
            logger.warning(f"{name}: multi-channel {arr.shape}; using first channel.")
            if arr.shape[0] in (3, 4):
                arr = arr[0:1, ...]  # CHW -> 1HW
            elif arr.shape[-1] in (3, 4):
                arr = arr.transpose(2, 0, 1)[0:1, ...]  # HWC -> CHW -> 1HW
            else:
                arr = arr[0:1, ...]  # ambiguous -> 1HW by slice
    elif arr.ndim == 2:
        arr = arr[None, ...]  # 1HW
    else:
        raise ValueError(f"{name}: unsupported numpy dims {arr.shape} after stripping.")

    arr = _to_float01_np(arr)
    t = torch.from_numpy(arr)
    logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (1HW)")
    return t
|
| 199 |
|
| 200 |
|
| 201 |
def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
|
| 202 |
+
"""Extract a 2D alpha (H,W) float32 [0,1] from various outputs."""
|
|
|
|
|
|
|
|
|
|
| 203 |
if result is None:
|
| 204 |
return np.full((512, 512), 0.5, dtype=np.float32)
|
| 205 |
|
|
|
|
| 207 |
result = result.detach().float().cpu()
|
| 208 |
|
| 209 |
arr = np.asarray(result)
|
| 210 |
+
# Strip to <= 3 dims, then extract
|
| 211 |
+
while arr.ndim > 3:
|
| 212 |
+
if arr.shape[0] > 1:
|
| 213 |
+
logger.warning(f"Result has leading dim {arr.shape[0]}; taking first slice.")
|
| 214 |
+
arr = arr[0]
|
| 215 |
+
|
| 216 |
if arr.ndim == 2:
|
| 217 |
alpha = arr
|
| 218 |
elif arr.ndim == 3:
|
| 219 |
+
if arr.shape[0] in (1, 3, 4): # CHW -> take channel 0
|
|
|
|
| 220 |
alpha = arr[0]
|
| 221 |
+
elif arr.shape[-1] in (1, 3, 4): # HWC -> take channel 0
|
| 222 |
alpha = arr[..., 0]
|
| 223 |
else:
|
| 224 |
+
alpha = arr[0] # ambiguous
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
else:
|
| 226 |
+
# 1D or 0D shouldn't happen; fallback
|
| 227 |
alpha = np.full((512, 512), 0.5, dtype=np.float32)
|
| 228 |
|
| 229 |
alpha = alpha.astype(np.float32, copy=False)
|
|
|
|
| 232 |
|
| 233 |
|
| 234 |
def _hw_from_image_like(x: Union[np.ndarray, torch.Tensor]) -> Tuple[int, int]:
|
| 235 |
+
"""Best-effort infer (H, W) for fallback mask sizing."""
|
| 236 |
if torch.is_tensor(x):
|
| 237 |
shape = tuple(x.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
else:
|
| 239 |
+
shape = np.asarray(x).shape
|
| 240 |
+
|
| 241 |
+
# Try common orders
|
| 242 |
+
if len(shape) == 2: # HW
|
| 243 |
+
return shape[0], shape[1]
|
| 244 |
+
if len(shape) == 3:
|
| 245 |
+
if shape[0] in (1, 3, 4): # CHW
|
| 246 |
+
return shape[1], shape[2]
|
| 247 |
+
if shape[-1] in (1, 3, 4): # HWC
|
| 248 |
+
return shape[0], shape[1]
|
| 249 |
+
# Ambiguous -> treat as CHW
|
| 250 |
+
return shape[1], shape[2]
|
| 251 |
+
if len(shape) >= 4:
|
| 252 |
+
# Assume leading are batch/time; try BCHW first
|
| 253 |
+
if len(shape) >= 4 and (shape[1] in (1, 3, 4)):
|
| 254 |
+
return shape[2], shape[3]
|
| 255 |
+
# Else BHWC-ish
|
| 256 |
+
return shape[-3], shape[-2]
|
| 257 |
+
return 512, 512
|
| 258 |
|
| 259 |
|
| 260 |
# --------------------------- Callable Wrapper ---------------------------
|
|
|
|
| 267 |
- First call SHOULD include a mask (1HW). If not, returns neutral 0.5 alpha.
|
| 268 |
- Subsequent calls do not require mask.
|
| 269 |
- Returns 2D alpha (H,W) float32 in [0,1].
|
| 270 |
+
- Strips any extra dims from inputs before calling core.
|
| 271 |
"""
|
| 272 |
|
| 273 |
def __init__(self, inference_core, device: str = "cuda", mixed_precision: Optional[str] = "fp16"):
|
|
|
|
| 280 |
if self.device == "cuda" and self.mixed_precision in ("fp16", "bf16"):
|
| 281 |
dtype = torch.float16 if self.mixed_precision == "fp16" else torch.bfloat16
|
| 282 |
return torch.autocast(device_type="cuda", dtype=dtype)
|
| 283 |
+
# no-op ctx
|
| 284 |
class _NullCtx:
|
| 285 |
def __enter__(self): return None
|
| 286 |
def __exit__(self, *exc): return False
|
|
|
|
| 288 |
|
| 289 |
def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
| 290 |
try:
|
| 291 |
+
# Preprocess (unbatched)
|
| 292 |
+
img_chw = _ensure_chw_float01(image, name="image").to(self.device, non_blocking=True)
|
|
|
|
| 293 |
|
| 294 |
if not self.initialized:
|
| 295 |
if mask is None:
|
|
|
|
| 297 |
logger.warning("MatAnyone first frame called without mask; returning neutral alpha.")
|
| 298 |
return np.full((h, w), 0.5, dtype=np.float32)
|
| 299 |
|
| 300 |
+
m_1hw = _ensure_1hw_float01(mask, name="mask").to(self.device, non_blocking=True)
|
|
|
|
| 301 |
|
| 302 |
with torch.inference_mode():
|
| 303 |
with self._maybe_autocast():
|
| 304 |
if hasattr(self.core, "step"):
|
| 305 |
+
result = self.core.step(image=img_chw, mask=m_1hw, **kwargs)
|
| 306 |
elif hasattr(self.core, "process_frame"):
|
| 307 |
+
result = self.core.process_frame(img_chw, m_1hw, **kwargs)
|
| 308 |
else:
|
| 309 |
logger.warning("InferenceCore has no recognized frame API; echoing input mask.")
|
| 310 |
return _alpha_from_result(mask)
|
|
|
|
| 316 |
with torch.inference_mode():
|
| 317 |
with self._maybe_autocast():
|
| 318 |
if hasattr(self.core, "step"):
|
| 319 |
+
result = self.core.step(image=img_chw, **kwargs)
|
| 320 |
elif hasattr(self.core, "process_frame"):
|
| 321 |
+
result = self.core.process_frame(img_chw, **kwargs)
|
| 322 |
else:
|
| 323 |
h, w = _hw_from_image_like(image)
|
| 324 |
logger.warning("InferenceCore has no recognized frame API on subsequent call; returning neutral alpha.")
|
|
|
|
| 362 |
Usage:
|
| 363 |
loader = MatAnyoneLoader(device="cuda")
|
| 364 |
session = loader.load() # callable
|
| 365 |
+
alpha = session(frame, first_frame_mask) # returns (H, W) float32
|
| 366 |
"""
|
| 367 |
|
| 368 |
def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache",
|
|
|
|
| 411 |
logger.debug(f"ctor(model_id, device, cache_dir) failed: {e}")
|
| 412 |
|
| 413 |
# 3) Minimal ctor
|
| 414 |
+
core = InferenceCore(self.model_id)
|
| 415 |
+
logger.info("Loaded MatAnyone via InferenceCore(model_id) [minimal]")
|
| 416 |
+
return core
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
|
| 418 |
def load(self) -> Optional[MatAnyoneCallableWrapper]:
|
| 419 |
"""Load MatAnyone and return the callable wrapper."""
|
|
|
|
| 425 |
|
| 426 |
try:
|
| 427 |
self.processor = self._try_build_core()
|
| 428 |
+
# Optional device move
|
| 429 |
try:
|
| 430 |
if hasattr(self.processor, "to"):
|
| 431 |
self.processor.to(self.device)
|
|
|
|
| 506 |
return self.wrapper(image, mask, **kwargs)
|
| 507 |
|
| 508 |
|
| 509 |
+
# Backwards compatibility alias
|
| 510 |
_MatAnyoneSession = MatAnyoneCallableWrapper
|
| 511 |
|
| 512 |
__all__ = ["MatAnyoneLoader", "_MatAnyoneSession", "MatAnyoneCallableWrapper"]
|