Spaces:
Configuration error
Configuration error
agent 2.8
Browse files- models/matanyone_loader.py +73 -50
models/matanyone_loader.py
CHANGED
|
@@ -83,7 +83,7 @@ def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
|
|
| 83 |
if mask.shape[:2] != (H, W):
|
| 84 |
mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
|
| 85 |
maskf = (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
|
| 86 |
-
return maskf #
|
| 87 |
|
| 88 |
|
| 89 |
def _to_hwc01(img_bgr: np.ndarray) -> np.ndarray:
|
|
@@ -146,9 +146,10 @@ def __init__(self, device: Optional[str] = None, precision: str = "auto"):
|
|
| 146 |
self._api_mode = None
|
| 147 |
self._initialized = False
|
| 148 |
|
| 149 |
-
# chosen
|
| 150 |
-
self.
|
| 151 |
-
self.
|
|
|
|
| 152 |
|
| 153 |
self._lazy_init()
|
| 154 |
|
|
@@ -224,38 +225,60 @@ def _core_call(self, img_t: torch.Tensor, mask_t: Optional[torch.Tensor]):
|
|
| 224 |
return self._core.process_frame(img_t, mask_t) if mask_t is not None else self._core.process_frame(img_t)
|
| 225 |
raise MatAnyError("Internal error: unknown API mode")
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
def _run_frame(self, frame_bgr: np.ndarray, seed_hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
|
| 228 |
"""
|
| 229 |
Returns alpha matte as 2D np.float32 in [0,1].
|
| 230 |
- On first frame, try several (image,mask) layout combos and remember the winner.
|
| 231 |
-
- On later frames, use the recorded
|
| 232 |
"""
|
| 233 |
self._validate_input_frame(frame_bgr)
|
| 234 |
|
| 235 |
-
#
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
# Build mask layouts if seed is provided on first frame
|
| 240 |
-
mask_hw = None
|
| 241 |
-
mask_1hw = None
|
| 242 |
-
if is_first and seed_hw is not None:
|
| 243 |
-
if seed_hw.ndim != 2:
|
| 244 |
-
raise MatAnyError(f"Internal: seed_hw must be HW; got {seed_hw.shape}")
|
| 245 |
-
mask_hw = seed_hw.astype(np.float32, copy=False) # (H,W)
|
| 246 |
-
mask_1hw = mask_hw[None, ...] # (1,H,W)
|
| 247 |
-
|
| 248 |
-
# If layout already chosen, use it
|
| 249 |
-
if self._img_layout is not None and (not is_first):
|
| 250 |
-
img_t = None
|
| 251 |
-
if self._img_layout == "HWC":
|
| 252 |
-
img_t = torch.from_numpy(img_hwc).to(self.device, dtype=torch.float32, non_blocking=True)
|
| 253 |
-
else:
|
| 254 |
-
img_t = torch.from_numpy(img_chw).to(self.device, dtype=torch.float32, non_blocking=True)
|
| 255 |
-
|
| 256 |
with torch.no_grad(), self._maybe_amp():
|
| 257 |
out = self._core_call(img_t, None)
|
| 258 |
-
|
| 259 |
alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy() if isinstance(out, torch.Tensor) \
|
| 260 |
else np.asarray(out, dtype=np.float32)
|
| 261 |
if alpha_np.max() > 1.0:
|
|
@@ -265,32 +288,35 @@ def _run_frame(self, frame_bgr: np.ndarray, seed_hw: Optional[np.ndarray], is_fi
|
|
| 265 |
raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
|
| 266 |
return alpha_np.astype(np.float32)
|
| 267 |
|
| 268 |
-
#
|
| 269 |
-
attempts = [
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
]
|
| 277 |
-
else:
|
| 278 |
-
# Should never reach here (later frames handled above), but keep a safe default
|
| 279 |
-
attempts = [("HWC", None, img_hwc, None), ("CHW", None, img_chw, None)]
|
| 280 |
|
| 281 |
last_err = None
|
| 282 |
-
for
|
| 283 |
try:
|
| 284 |
-
img_t =
|
| 285 |
-
mask_t = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
with torch.no_grad(), self._maybe_amp():
|
| 288 |
out = self._core_call(img_t, mask_t)
|
| 289 |
|
| 290 |
-
# success → remember
|
| 291 |
-
self.
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
| 294 |
|
| 295 |
alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy() if isinstance(out, torch.Tensor) \
|
| 296 |
else np.asarray(out, dtype=np.float32)
|
|
@@ -303,10 +329,8 @@ def _run_frame(self, frame_bgr: np.ndarray, seed_hw: Optional[np.ndarray], is_fi
|
|
| 303 |
|
| 304 |
except Exception as e:
|
| 305 |
last_err = e
|
| 306 |
-
|
| 307 |
-
log.warning(f"[MATANY] Layout attempt failed (image={img_layout}, mask={m_layout}): {emsg}")
|
| 308 |
|
| 309 |
-
# If we’re here, all attempts failed
|
| 310 |
snap = _cuda_snapshot(self.device)
|
| 311 |
raise MatAnyError(f"MatAnyone first-frame probe failed for all layouts. Last error: {last_err} | {snap}")
|
| 312 |
|
|
@@ -386,7 +410,6 @@ def process_stream(
|
|
| 386 |
# Compose outputs
|
| 387 |
alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
|
| 388 |
alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
|
| 389 |
-
# alpha_hw already [0,1]
|
| 390 |
fg_bgr = (frame.astype(np.float32) * alpha_hw[..., None]).clip(0, 255).astype(np.uint8)
|
| 391 |
|
| 392 |
alpha_writer.write(alpha_bgr)
|
|
|
|
| 83 |
if mask.shape[:2] != (H, W):
|
| 84 |
mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
|
| 85 |
maskf = (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
|
| 86 |
+
return maskf # (H, W)
|
| 87 |
|
| 88 |
|
| 89 |
def _to_hwc01(img_bgr: np.ndarray) -> np.ndarray:
|
|
|
|
| 146 |
self._api_mode = None
|
| 147 |
self._initialized = False
|
| 148 |
|
| 149 |
+
# chosen builders after first frame succeeds
|
| 150 |
+
self._build_img = None # Callable[[np.ndarray], torch.Tensor]
|
| 151 |
+
self._build_msk = None # Optional[Callable[[np.ndarray], Optional[torch.Tensor]]]
|
| 152 |
+
self._layout_name = None
|
| 153 |
|
| 154 |
self._lazy_init()
|
| 155 |
|
|
|
|
| 225 |
return self._core.process_frame(img_t, mask_t) if mask_t is not None else self._core.process_frame(img_t)
|
| 226 |
raise MatAnyError("Internal error: unknown API mode")
|
| 227 |
|
| 228 |
+
# ---- builders for probing ----
|
| 229 |
+
def _mk_builder_bchw(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
    """Build the batched channel-first layout candidate for the first-frame probe.

    Returns:
        A ``(layout_name, image_builder, mask_builder)`` triple. Both builders
        produce float32 tensors placed on ``self.device``.
    """

    def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
        # _to_chw01 lives elsewhere in this module; presumably it yields a
        # channel-first array scaled to [0,1] -- confirm against its definition.
        frame_t = torch.from_numpy(_to_chw01(frame_bgr))
        batched = frame_t.unsqueeze(0).contiguous()  # [1,3,H,W]
        return batched.to(self.device, dtype=torch.float32, non_blocking=True)

    def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
        seed_t = torch.from_numpy(seed_hw)
        batched = seed_t.unsqueeze(0).unsqueeze(0).contiguous()  # [1,1,H,W]
        return batched.to(self.device, dtype=torch.float32, non_blocking=True)

    return "BCHW+B1HW", b_img, b_msk
|
| 236 |
+
|
| 237 |
+
def _mk_builder_bchw_nomask(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
    """Build the batched channel-first layout candidate that supplies no mask.

    Returns:
        A ``(layout_name, image_builder, mask_builder)`` triple. The image
        builder yields a float32 tensor on ``self.device``; the mask builder
        ignores its argument and always returns ``None``.
    """

    def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
        # _to_chw01 lives elsewhere in this module; presumably it yields a
        # channel-first array scaled to [0,1] -- confirm against its definition.
        frame_t = torch.from_numpy(_to_chw01(frame_bgr))
        batched = frame_t.unsqueeze(0).contiguous()  # leading batch axis added
        return batched.to(self.device, dtype=torch.float32, non_blocking=True)

    def b_msk(_: np.ndarray) -> Optional[torch.Tensor]:
        # This candidate deliberately probes the core without a seed mask.
        return None

    return "BCHW+None", b_img, b_msk
|
| 244 |
+
|
| 245 |
+
def _mk_builder_btchw(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
    """Build the batch+time channel-first layout candidate for the first-frame probe.

    Returns:
        A ``(layout_name, image_builder, mask_builder)`` triple. Both builders
        produce float32 tensors placed on ``self.device``.
    """

    def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
        # _to_chw01 lives elsewhere in this module; presumably it yields a
        # channel-first array scaled to [0,1] -- confirm against its definition.
        frame_t = torch.from_numpy(_to_chw01(frame_bgr))
        # Insert axis 0 first, then axis 1, exactly as the probe expects.
        expanded = frame_t.unsqueeze(0).unsqueeze(1).contiguous()  # [1,1,3,H,W]
        return expanded.to(self.device, dtype=torch.float32, non_blocking=True)

    def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
        seed_t = torch.from_numpy(seed_hw)
        expanded = seed_t.unsqueeze(0).unsqueeze(0).unsqueeze(0).contiguous()  # [1,1,1,H,W]
        return expanded.to(self.device, dtype=torch.float32, non_blocking=True)

    return "BTCHW+B1THW", b_img, b_msk
|
| 252 |
+
|
| 253 |
+
def _mk_builder_chw(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
    """Build the unbatched channel-first layout candidate for the first-frame probe.

    Returns:
        A ``(layout_name, image_builder, mask_builder)`` triple. Both builders
        produce float32 tensors placed on ``self.device``.
    """

    def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
        # _to_chw01 lives elsewhere in this module; presumably it yields a
        # channel-first array scaled to [0,1] -- confirm against its definition.
        frame_t = torch.from_numpy(_to_chw01(frame_bgr)).contiguous()  # [3,H,W]
        return frame_t.to(self.device, dtype=torch.float32, non_blocking=True)

    def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
        seed_t = torch.from_numpy(seed_hw)
        with_channel = seed_t.unsqueeze(0).contiguous()  # [1,H,W]
        return with_channel.to(self.device, dtype=torch.float32, non_blocking=True)

    return "CHW+1HW", b_img, b_msk
|
| 260 |
+
|
| 261 |
+
def _mk_builder_hwc(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
    """Build the unbatched channel-last layout candidate for the first-frame probe.

    Returns:
        A ``(layout_name, image_builder, mask_builder)`` triple. Both builders
        produce float32 tensors placed on ``self.device``; no axes are added.
    """

    def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
        # _to_hwc01 lives elsewhere in this module; presumably it yields a
        # channel-last array scaled to [0,1] -- confirm against its definition.
        frame_t = torch.from_numpy(_to_hwc01(frame_bgr)).contiguous()  # [H,W,3]
        return frame_t.to(self.device, dtype=torch.float32, non_blocking=True)

    def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
        seed_t = torch.from_numpy(seed_hw).contiguous()  # [H,W]
        return seed_t.to(self.device, dtype=torch.float32, non_blocking=True)

    return "HWC+HW", b_img, b_msk
|
| 268 |
+
|
| 269 |
def _run_frame(self, frame_bgr: np.ndarray, seed_hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
|
| 270 |
"""
|
| 271 |
Returns alpha matte as 2D np.float32 in [0,1].
|
| 272 |
- On first frame, try several (image,mask) layout combos and remember the winner.
|
| 273 |
+
- On later frames, use the recorded builders (mask is None).
|
| 274 |
"""
|
| 275 |
self._validate_input_frame(frame_bgr)
|
| 276 |
|
| 277 |
+
# Later frames: use the memorized builders
|
| 278 |
+
if self._build_img is not None and not is_first:
|
| 279 |
+
img_t = self._build_img(frame_bgr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
with torch.no_grad(), self._maybe_amp():
|
| 281 |
out = self._core_call(img_t, None)
|
|
|
|
| 282 |
alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy() if isinstance(out, torch.Tensor) \
|
| 283 |
else np.asarray(out, dtype=np.float32)
|
| 284 |
if alpha_np.max() > 1.0:
|
|
|
|
| 288 |
raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
|
| 289 |
return alpha_np.astype(np.float32)
|
| 290 |
|
| 291 |
+
# First frame: probe combos
|
| 292 |
+
attempts = [
|
| 293 |
+
self._mk_builder_bchw(), # [1,3,H,W] + [1,1,H,W]
|
| 294 |
+
self._mk_builder_bchw_nomask(), # [1,3,H,W] + None
|
| 295 |
+
self._mk_builder_btchw(), # [1,1,3,H,W] + [1,1,1,H,W]
|
| 296 |
+
self._mk_builder_chw(), # [3,H,W] + [1,H,W]
|
| 297 |
+
self._mk_builder_hwc(), # [H,W,3] + [H,W]
|
| 298 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
last_err = None
|
| 301 |
+
for name, mk_img, mk_msk in attempts:
|
| 302 |
try:
|
| 303 |
+
img_t = mk_img(frame_bgr)
|
| 304 |
+
mask_t = None
|
| 305 |
+
if seed_hw is not None:
|
| 306 |
+
mask_t = mk_msk(seed_hw)
|
| 307 |
+
|
| 308 |
+
log.info(f"[MATANY] Trying layout: {name} | img.shape={tuple(img_t.shape)}"
|
| 309 |
+
f"{'' if mask_t is None else ' mask.shape=' + str(tuple(mask_t.shape))}")
|
| 310 |
|
| 311 |
with torch.no_grad(), self._maybe_amp():
|
| 312 |
out = self._core_call(img_t, mask_t)
|
| 313 |
|
| 314 |
+
# success → remember builders for subsequent frames
|
| 315 |
+
self._build_img = mk_img
|
| 316 |
+
# after first frame, we won't pass mask anymore
|
| 317 |
+
self._build_msk = mk_msk
|
| 318 |
+
self._layout_name = name
|
| 319 |
+
log.info(f"[MATANY] Selected layout: {name}")
|
| 320 |
|
| 321 |
alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy() if isinstance(out, torch.Tensor) \
|
| 322 |
else np.asarray(out, dtype=np.float32)
|
|
|
|
| 329 |
|
| 330 |
except Exception as e:
|
| 331 |
last_err = e
|
| 332 |
+
log.warning(f"[MATANY] Layout attempt failed ({name}): {e}")
|
|
|
|
| 333 |
|
|
|
|
| 334 |
snap = _cuda_snapshot(self.device)
|
| 335 |
raise MatAnyError(f"MatAnyone first-frame probe failed for all layouts. Last error: {last_err} | {snap}")
|
| 336 |
|
|
|
|
| 410 |
# Compose outputs
|
| 411 |
alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
|
| 412 |
alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
|
|
|
|
| 413 |
fg_bgr = (frame.astype(np.float32) * alpha_hw[..., None]).clip(0, 255).astype(np.uint8)
|
| 414 |
|
| 415 |
alpha_writer.write(alpha_bgr)
|