MogensR committed on
Commit
9923851
·
1 Parent(s): 72f4052

agent 2.8

Browse files
Files changed (1) hide show
  1. models/matanyone_loader.py +73 -50
models/matanyone_loader.py CHANGED
@@ -83,7 +83,7 @@ def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
83
  if mask.shape[:2] != (H, W):
84
  mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
85
  maskf = (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
86
- return maskf # shape: (H, W), float32
87
 
88
 
89
  def _to_hwc01(img_bgr: np.ndarray) -> np.ndarray:
@@ -146,9 +146,10 @@ def __init__(self, device: Optional[str] = None, precision: str = "auto"):
146
  self._api_mode = None
147
  self._initialized = False
148
 
149
- # chosen layouts after first frame succeeds
150
- self._img_layout: Optional[str] = None # 'HWC' or 'CHW'
151
- self._mask_layout: Optional[str] = None # 'HW', '1HW', or None
 
152
 
153
  self._lazy_init()
154
 
@@ -224,38 +225,60 @@ def _core_call(self, img_t: torch.Tensor, mask_t: Optional[torch.Tensor]):
224
  return self._core.process_frame(img_t, mask_t) if mask_t is not None else self._core.process_frame(img_t)
225
  raise MatAnyError("Internal error: unknown API mode")
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  def _run_frame(self, frame_bgr: np.ndarray, seed_hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
228
  """
229
  Returns alpha matte as 2D np.float32 in [0,1].
230
  - On first frame, try several (image,mask) layout combos and remember the winner.
231
- - On later frames, use the recorded layout (mask is None).
232
  """
233
  self._validate_input_frame(frame_bgr)
234
 
235
- # Build both image layouts
236
- img_hwc = _to_hwc01(frame_bgr) # (H,W,3) float32
237
- img_chw = _to_chw01(frame_bgr) # (3,H,W) float32
238
-
239
- # Build mask layouts if seed is provided on first frame
240
- mask_hw = None
241
- mask_1hw = None
242
- if is_first and seed_hw is not None:
243
- if seed_hw.ndim != 2:
244
- raise MatAnyError(f"Internal: seed_hw must be HW; got {seed_hw.shape}")
245
- mask_hw = seed_hw.astype(np.float32, copy=False) # (H,W)
246
- mask_1hw = mask_hw[None, ...] # (1,H,W)
247
-
248
- # If layout already chosen, use it
249
- if self._img_layout is not None and (not is_first):
250
- img_t = None
251
- if self._img_layout == "HWC":
252
- img_t = torch.from_numpy(img_hwc).to(self.device, dtype=torch.float32, non_blocking=True)
253
- else:
254
- img_t = torch.from_numpy(img_chw).to(self.device, dtype=torch.float32, non_blocking=True)
255
-
256
  with torch.no_grad(), self._maybe_amp():
257
  out = self._core_call(img_t, None)
258
-
259
  alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy() if isinstance(out, torch.Tensor) \
260
  else np.asarray(out, dtype=np.float32)
261
  if alpha_np.max() > 1.0:
@@ -265,32 +288,35 @@ def _run_frame(self, frame_bgr: np.ndarray, seed_hw: Optional[np.ndarray], is_fi
265
  raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
266
  return alpha_np.astype(np.float32)
267
 
268
- # Otherwise, probe possible combos on the first frame
269
- attempts = []
270
- if is_first:
271
- attempts = [
272
- ("HWC", "HW", img_hwc, mask_hw),
273
- ("HWC", "1HW", img_hwc, mask_1hw),
274
- ("CHW", "HW", img_chw, mask_hw),
275
- ("CHW", "1HW", img_chw, mask_1hw),
276
- ]
277
- else:
278
- # Should never reach here (later frames handled above), but keep a safe default
279
- attempts = [("HWC", None, img_hwc, None), ("CHW", None, img_chw, None)]
280
 
281
  last_err = None
282
- for img_layout, m_layout, img_np, m_np in attempts:
283
  try:
284
- img_t = torch.from_numpy(img_np).to(self.device, dtype=torch.float32, non_blocking=True)
285
- mask_t = None if m_np is None else torch.from_numpy(m_np).to(self.device, dtype=torch.float32, non_blocking=True)
 
 
 
 
 
286
 
287
  with torch.no_grad(), self._maybe_amp():
288
  out = self._core_call(img_t, mask_t)
289
 
290
- # success → remember layout for subsequent frames
291
- self._img_layout = img_layout
292
- self._mask_layout = m_layout
293
- log.info(f"[MATANY] Selected layouts: image={img_layout}, mask={m_layout}")
 
 
294
 
295
  alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy() if isinstance(out, torch.Tensor) \
296
  else np.asarray(out, dtype=np.float32)
@@ -303,10 +329,8 @@ def _run_frame(self, frame_bgr: np.ndarray, seed_hw: Optional[np.ndarray], is_fi
303
 
304
  except Exception as e:
305
  last_err = e
306
- emsg = str(e)
307
- log.warning(f"[MATANY] Layout attempt failed (image={img_layout}, mask={m_layout}): {emsg}")
308
 
309
- # If we’re here, all attempts failed
310
  snap = _cuda_snapshot(self.device)
311
  raise MatAnyError(f"MatAnyone first-frame probe failed for all layouts. Last error: {last_err} | {snap}")
312
 
@@ -386,7 +410,6 @@ def process_stream(
386
  # Compose outputs
387
  alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
388
  alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
389
- # alpha_hw already [0,1]
390
  fg_bgr = (frame.astype(np.float32) * alpha_hw[..., None]).clip(0, 255).astype(np.uint8)
391
 
392
  alpha_writer.write(alpha_bgr)
 
83
  if mask.shape[:2] != (H, W):
84
  mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
85
  maskf = (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
86
+ return maskf # (H, W)
87
 
88
 
89
  def _to_hwc01(img_bgr: np.ndarray) -> np.ndarray:
 
146
  self._api_mode = None
147
  self._initialized = False
148
 
149
+ # chosen builders after first frame succeeds
150
+ self._build_img = None # Callable[[np.ndarray], torch.Tensor]
151
+ self._build_msk = None # Optional[Callable[[np.ndarray], Optional[torch.Tensor]]]
152
+ self._layout_name = None
153
 
154
  self._lazy_init()
155
 
 
225
  return self._core.process_frame(img_t, mask_t) if mask_t is not None else self._core.process_frame(img_t)
226
  raise MatAnyError("Internal error: unknown API mode")
227
 
228
+ # ---- builders for probing ----
229
def _mk_builder_bchw(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
    """
    Create the batched (image, mask) tensor builders used when probing layouts.

    Returns:
        A ``(name, b_img, b_msk)`` triple where ``name`` is ``"BCHW+B1HW"``,
        ``b_img`` converts a frame via ``_to_chw01`` and prepends a batch
        axis (intended shape [1,3,H,W]), and ``b_msk`` prepends two axes to
        the seed mask (intended shape [1,1,H,W] for a 2D seed). Both builders
        return float32 tensors placed on ``self.device``.
    """
    def _as_f32_tensor(arr: np.ndarray) -> torch.Tensor:
        # Normalise on the NumPy side first: torch.from_numpy only accepts a
        # fixed set of dtypes and rejects negatively-strided views, whereas
        # ascontiguousarray yields a plain C-ordered float32 buffer
        # (and is a no-op when the input already is one).
        return torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float32))

    def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
        chw = _as_f32_tensor(_to_chw01(frame_bgr))
        # Add the batch axis -> [1,3,H,W]
        return chw.unsqueeze(0).to(self.device, dtype=torch.float32, non_blocking=True)

    def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
        seed = _as_f32_tensor(seed_hw)
        # Add batch and channel axes -> [1,1,H,W]
        return seed.unsqueeze(0).unsqueeze(0).to(self.device, dtype=torch.float32, non_blocking=True)

    return "BCHW+B1HW", b_img, b_msk
236
+
237
def _mk_builder_bchw_nomask(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
    """
    Create builders for the batched-image layout that passes no mask at all.

    Returns:
        A ``(name, b_img, b_msk)`` triple where ``name`` is ``"BCHW+None"``,
        ``b_img`` converts a frame via ``_to_chw01`` and prepends a batch
        axis (intended shape [1,3,H,W]) as a float32 tensor on
        ``self.device``, and ``b_msk`` ignores its argument and always
        returns ``None``.
    """
    def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
        # Normalise on the NumPy side first: torch.from_numpy only accepts a
        # fixed set of dtypes and rejects negatively-strided views.
        chw = np.ascontiguousarray(_to_chw01(frame_bgr), dtype=np.float32)
        return torch.from_numpy(chw).unsqueeze(0).to(self.device, dtype=torch.float32, non_blocking=True)

    def b_msk(_: np.ndarray) -> Optional[torch.Tensor]:
        # This layout deliberately feeds no mask to the core.
        return None

    return "BCHW+None", b_img, b_msk
244
+
245
def _mk_builder_btchw(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
    """
    Create builders for the batch+time layout used when probing layouts.

    Returns:
        A ``(name, b_img, b_msk)`` triple where ``name`` is ``"BTCHW+B1THW"``,
        ``b_img`` converts a frame via ``_to_chw01`` and prepends two axes
        (intended shape [1,1,3,H,W]), and ``b_msk`` prepends three axes to
        the seed mask (intended shape [1,1,1,H,W] for a 2D seed). Both
        builders return float32 tensors placed on ``self.device``.
    """
    def _as_f32_tensor(arr: np.ndarray) -> torch.Tensor:
        # Normalise on the NumPy side first: torch.from_numpy only accepts a
        # fixed set of dtypes and rejects negatively-strided views.
        return torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float32))

    def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
        chw = _as_f32_tensor(_to_chw01(frame_bgr))
        # Prepend batch and time axes -> [1,1,3,H,W]
        return chw[None, None].to(self.device, dtype=torch.float32, non_blocking=True)

    def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
        seed = _as_f32_tensor(seed_hw)
        # Prepend three singleton axes -> [1,1,1,H,W]
        return seed[None, None, None].to(self.device, dtype=torch.float32, non_blocking=True)

    return "BTCHW+B1THW", b_img, b_msk
252
+
253
def _mk_builder_chw(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
    """
    Create builders for the unbatched channel-first layout.

    Returns:
        A ``(name, b_img, b_msk)`` triple where ``name`` is ``"CHW+1HW"``,
        ``b_img`` converts a frame via ``_to_chw01`` without adding axes
        (intended shape [3,H,W]), and ``b_msk`` prepends one axis to the
        seed mask (intended shape [1,H,W] for a 2D seed). Both builders
        return float32 tensors placed on ``self.device``.
    """
    def _as_f32_tensor(arr: np.ndarray) -> torch.Tensor:
        # Normalise on the NumPy side first: torch.from_numpy only accepts a
        # fixed set of dtypes and rejects negatively-strided views.
        return torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float32))

    def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
        chw = _as_f32_tensor(_to_chw01(frame_bgr))
        return chw.to(self.device, dtype=torch.float32, non_blocking=True)

    def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
        seed = _as_f32_tensor(seed_hw)
        # Add the channel axis -> [1,H,W]
        return seed.unsqueeze(0).to(self.device, dtype=torch.float32, non_blocking=True)

    return "CHW+1HW", b_img, b_msk
260
+
261
def _mk_builder_hwc(self) -> Tuple[str, Callable[[np.ndarray], torch.Tensor], Callable[[np.ndarray], Optional[torch.Tensor]]]:
    """
    Create builders for the unbatched channel-last layout.

    Returns:
        A ``(name, b_img, b_msk)`` triple where ``name`` is ``"HWC+HW"``,
        ``b_img`` converts a frame via ``_to_hwc01`` without adding axes
        (intended shape [H,W,3]), and ``b_msk`` passes the seed mask through
        with its shape unchanged (intended shape [H,W]). Both builders
        return float32 tensors placed on ``self.device``.
    """
    def _as_f32_tensor(arr: np.ndarray) -> torch.Tensor:
        # Normalise on the NumPy side first: torch.from_numpy only accepts a
        # fixed set of dtypes and rejects negatively-strided views.
        return torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float32))

    def b_img(frame_bgr: np.ndarray) -> torch.Tensor:
        hwc = _as_f32_tensor(_to_hwc01(frame_bgr))
        return hwc.to(self.device, dtype=torch.float32, non_blocking=True)

    def b_msk(seed_hw: np.ndarray) -> torch.Tensor:
        seed = _as_f32_tensor(seed_hw)
        return seed.to(self.device, dtype=torch.float32, non_blocking=True)

    return "HWC+HW", b_img, b_msk
268
+
269
  def _run_frame(self, frame_bgr: np.ndarray, seed_hw: Optional[np.ndarray], is_first: bool) -> np.ndarray:
270
  """
271
  Returns alpha matte as 2D np.float32 in [0,1].
272
  - On first frame, try several (image,mask) layout combos and remember the winner.
273
+ - On later frames, use the recorded builders (mask is None).
274
  """
275
  self._validate_input_frame(frame_bgr)
276
 
277
+ # Later frames: use the memorized builders
278
+ if self._build_img is not None and not is_first:
279
+ img_t = self._build_img(frame_bgr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  with torch.no_grad(), self._maybe_amp():
281
  out = self._core_call(img_t, None)
 
282
  alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy() if isinstance(out, torch.Tensor) \
283
  else np.asarray(out, dtype=np.float32)
284
  if alpha_np.max() > 1.0:
 
288
  raise MatAnyError(f"Expected 2D alpha matte; got shape {alpha_np.shape}")
289
  return alpha_np.astype(np.float32)
290
 
291
+ # First frame: probe combos
292
+ attempts = [
293
+ self._mk_builder_bchw(), # [1,3,H,W] + [1,1,H,W]
294
+ self._mk_builder_bchw_nomask(), # [1,3,H,W] + None
295
+ self._mk_builder_btchw(), # [1,1,3,H,W] + [1,1,1,H,W]
296
+ self._mk_builder_chw(), # [3,H,W] + [1,H,W]
297
+ self._mk_builder_hwc(), # [H,W,3] + [H,W]
298
+ ]
 
 
 
 
299
 
300
  last_err = None
301
+ for name, mk_img, mk_msk in attempts:
302
  try:
303
+ img_t = mk_img(frame_bgr)
304
+ mask_t = None
305
+ if seed_hw is not None:
306
+ mask_t = mk_msk(seed_hw)
307
+
308
+ log.info(f"[MATANY] Trying layout: {name} | img.shape={tuple(img_t.shape)}"
309
+ f"{'' if mask_t is None else ' mask.shape=' + str(tuple(mask_t.shape))}")
310
 
311
  with torch.no_grad(), self._maybe_amp():
312
  out = self._core_call(img_t, mask_t)
313
 
314
+ # success → remember builders for subsequent frames
315
+ self._build_img = mk_img
316
+ # after first frame, we won't pass mask anymore
317
+ self._build_msk = mk_msk
318
+ self._layout_name = name
319
+ log.info(f"[MATANY] Selected layout: {name}")
320
 
321
  alpha_np = out.detach().float().clamp(0, 1).squeeze().cpu().numpy() if isinstance(out, torch.Tensor) \
322
  else np.asarray(out, dtype=np.float32)
 
329
 
330
  except Exception as e:
331
  last_err = e
332
+ log.warning(f"[MATANY] Layout attempt failed ({name}): {e}")
 
333
 
 
334
  snap = _cuda_snapshot(self.device)
335
  raise MatAnyError(f"MatAnyone first-frame probe failed for all layouts. Last error: {last_err} | {snap}")
336
 
 
410
  # Compose outputs
411
  alpha_u8 = (alpha_hw * 255.0 + 0.5).astype(np.uint8)
412
  alpha_bgr = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2BGR)
 
413
  fg_bgr = (frame.astype(np.float32) * alpha_hw[..., None]).clip(0, 255).astype(np.uint8)
414
 
415
  alpha_writer.write(alpha_bgr)