MogensR committed on
Commit
c29dcc4
·
1 Parent(s): c30a2cc

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +175 -281
models/loaders/matanyone_loader.py CHANGED
@@ -1,29 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
 
 
2
  """
3
  MatAnyone Loader + Stateful Adapter (OOM-resilient, spatially robust)
4
- - Canonical HF load (MatAnyone.from_pretrained -> InferenceCore(model, cfg))
5
  - Mixed precision (bf16/fp16) with safe fallback to fp32
6
- - Autocast + inference_mode around every call
7
- - Auto downscale with progressive retry on OOM, then upsample alpha back
8
- - Always aligns mask/image dimensions before inference to avoid all size errors
9
- - Returns 2-D float32 [H,W] alpha for OpenCV
10
  """
11
 
 
 
12
  import os
13
  import time
14
  import logging
15
  import traceback
16
- from typing import Optional, Dict, Any, Tuple
17
 
18
  import numpy as np
19
  import torch
20
  import torch.nn.functional as F
21
  import inspect
22
-
23
- logger = logging.getLogger(__name__)
24
-
25
- # ------------------------- Shape & dtype utilities ------------------------- #
26
-
27
  def _select_device(pref: str) -> str:
28
  pref = (pref or "").lower()
29
  if pref.startswith("cuda"):
@@ -32,130 +52,23 @@ def _select_device(pref: str) -> str:
32
  return "cpu"
33
  return "cuda" if torch.cuda.is_available() else "cpu"
34
 
35
- def _as_tensor_on_device(x, device: str) -> torch.Tensor:
36
- if isinstance(x, torch.Tensor):
37
- return x.to(device, non_blocking=True)
38
- return torch.from_numpy(np.asarray(x)).to(device, non_blocking=True)
39
-
40
- def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
41
- """
42
- Normalize input to BCHW (image) or B1HW (mask).
43
- Accepts: HWC, CHW, BCHW, BHWC, BTCHW/BTHWC, TCHW/THWC, HW.
44
- """
45
- x = _as_tensor_on_device(x, device)
46
- if x.dtype == torch.uint8:
47
- x = x.float().div_(255.0)
48
- elif x.dtype in (torch.int16, torch.int32, torch.int64):
49
- x = x.float()
50
- if x.ndim == 5:
51
- x = x[:, 0] # -> 4D
52
- if x.ndim == 4:
53
- if x.shape[-1] in (1, 3, 4) and x.shape[1] not in (1, 3, 4):
54
- x = x.permute(0, 3, 1, 2).contiguous()
55
- elif x.ndim == 3:
56
- if x.shape[-1] in (1, 3, 4):
57
- x = x.permute(2, 0, 1).contiguous()
58
- x = x.unsqueeze(0)
59
- elif x.ndim == 2:
60
- x = x.unsqueeze(0).unsqueeze(0)
61
- if not is_mask:
62
- x = x.repeat(1, 3, 1, 1)
63
- else:
64
- raise ValueError(f"Unsupported ndim={x.ndim}")
65
- if is_mask:
66
- if x.shape[1] > 1:
67
- x = x[:, :1]
68
- x = x.clamp_(0.0, 1.0).to(torch.float32)
69
- else:
70
- if x.shape[1] == 1:
71
- x = x.repeat(1, 3, 1, 1)
72
- x = x.clamp_(0.0, 1.0)
73
- return x
74
-
75
- def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
76
- if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
77
- return img_bchw[0]
78
- return img_bchw
79
-
80
- def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
81
- if msk_b1hw is None:
82
- return None
83
- if msk_b1hw.ndim == 4 and msk_b1hw.shape[1] == 1:
84
- return msk_b1hw[0] # -> [1,H,W]
85
- if msk_b1hw.ndim == 3 and msk_b1hw.shape[0] == 1:
86
- return msk_b1hw
87
- raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
88
-
89
- def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask=False) -> Optional[torch.Tensor]:
90
- if x is None:
91
- return None
92
- if x.shape[-2:] == size_hw:
93
- return x
94
- mode = "nearest" if is_mask else "bilinear"
95
- return F.interpolate(x, size=size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
96
-
97
- def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
98
- t = torch.as_tensor(alpha, device=device).float()
99
- if t.ndim == 2:
100
- t = t.unsqueeze(0).unsqueeze(0) # -> [1,1,H,W]
101
- elif t.ndim == 3:
102
- if t.shape[0] in (1, 3, 4):
103
- if t.shape[0] != 1:
104
- t = t[:1]
105
- t = t.unsqueeze(0)
106
- elif t.shape[-1] in (1, 3, 4):
107
- t = t[..., :1].permute(2, 0, 1).unsqueeze(0)
108
- else:
109
- t = t[:1].unsqueeze(0)
110
- elif t.ndim == 4:
111
- if t.shape[1] != 1:
112
- t = t[:, :1]
113
- if t.shape[0] != 1:
114
- t = t[:1]
115
- else:
116
- while t.ndim > 4:
117
- t = t.squeeze(0)
118
- while t.ndim < 4:
119
- t = t.unsqueeze(0)
120
- if t.shape[1] != 1:
121
- t = t[:, :1]
122
- return t.clamp_(0.0, 1.0).contiguous()
123
-
124
- def _to_2d_alpha_numpy(x) -> np.ndarray:
125
- t = torch.as_tensor(x).float()
126
- while t.ndim > 2:
127
- if t.ndim == 4 and t.shape[0] == 1 and t.shape[1] == 1:
128
- t = t[0, 0]
129
- elif t.ndim == 3 and t.shape[0] == 1:
130
- t = t[0]
131
- else:
132
- t = t.squeeze(0)
133
- t = t.clamp_(0.0, 1.0)
134
- out = t.detach().cpu().numpy().astype(np.float32)
135
- return np.ascontiguousarray(out)
136
-
137
- def debug_shapes(tag: str, image, mask) -> None:
138
- def _info(name, v):
139
- try:
140
- tv = torch.as_tensor(v)
141
- mn = float(tv.min()) if tv.numel() else float("nan")
142
- mx = float(tv.max()) if tv.numel() else float("nan")
143
- logger.info(f"[{tag}:{name}] shape={tuple(tv.shape)} dtype={tv.dtype} min={mn:.4f} max={mx:.4f}")
144
- except Exception as e:
145
- logger.info(f"[{tag}:{name}] type={type(v)} err={e}")
146
- _info("image", image)
147
- _info("mask", mask)
148
-
149
- # ------------------------------ Stateful Adapter --------------------------- #
150
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  class _MatAnyoneSession:
152
  """
153
  Stateful controller around InferenceCore with OOM-resilient inference.
154
- Usage:
155
- # frame 0 (has mask):
156
- alpha0 = session(frame0_rgb01, mask01)
157
- # frames 1..N (no mask):
158
- alpha = session(frame_rgb01)
159
  """
160
  def __init__(
161
  self,
@@ -165,7 +78,7 @@ def __init__(
165
  use_autocast: bool,
166
  autocast_dtype: Optional[torch.dtype],
167
  max_edge: int = 768,
168
- target_pixels: int = 600_000, # ~775x775 cap by area
169
  ):
170
  self.core = core
171
  self.device = device
@@ -175,8 +88,9 @@ def __init__(
175
  self.max_edge = int(max_edge)
176
  self.target_pixels = int(target_pixels)
177
  self.started = False
 
178
 
179
- # feature detection
180
  try:
181
  sig = inspect.signature(self.core.step)
182
  self._has_first_frame_pred = "first_frame_pred" in sig.parameters
@@ -185,22 +99,26 @@ def __init__(
185
  self._has_prob_to_mask = hasattr(self.core, "output_prob_to_mask")
186
 
187
  def reset(self):
188
- try:
189
- if hasattr(self.core, "clear_memory"):
190
- self.core.clear_memory()
191
- except Exception:
192
- pass
193
- self.started = False
 
194
 
195
- def _compute_scaled_size(self, h: int, w: int) -> Tuple[int, int, float]:
196
- if h <= 0 or w <= 0:
197
- return h, w, 1.0
198
- s1 = min(1.0, self.max_edge / max(h, w))
199
- s2 = min(1.0, (self.target_pixels / (h * w)) ** 0.5) if self.target_pixels > 0 else 1.0
200
- s = min(s1, s2)
201
- nh = max(1, int(round(h * s)))
202
- nw = max(1, int(round(w * s)))
203
- return nh, nw, s
 
 
 
204
 
205
  def _to_alpha(self, out_prob):
206
  if self._has_prob_to_mask:
@@ -210,120 +128,99 @@ def _to_alpha(self, out_prob):
210
  pass
211
  t = torch.as_tensor(out_prob).float()
212
  if t.ndim == 4:
213
- c = 0 if t.shape[1] > 0 else None
214
- b = 0 if t.shape[0] > 0 else None
215
- if b is not None and c is not None:
216
- return t[b, c]
217
  if t.ndim == 3:
218
  return t[0] if t.shape[0] >= 1 else t.mean(0)
219
  return t
220
 
221
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
222
  """
223
- Returns a 2-D float32 alpha [H,W]. On first call, provide a coarse mask.
224
- Subsequent calls propagate without a mask.
 
225
  """
226
- img_bchw = _to_bchw(image, self.device, is_mask=False) # [1,C,H,W]
227
- msk_b1hw = _to_bchw(mask, self.device, is_mask=True) if mask is not None else None
228
-
229
- H, W = img_bchw.shape[-2], img_bchw.shape[-1]
230
- # --- Guarantee same shape for mask/image at input resolution ---
231
- if msk_b1hw is not None and img_bchw.shape[-2:] != msk_b1hw.shape[-2:]:
232
- logger.warning(f"Fixing mask shape: {msk_b1hw.shape[-2:]} -> {img_bchw.shape[-2:]}")
233
- msk_b1hw = _resize_bchw(msk_b1hw, img_bchw.shape[-2:], is_mask=True)
234
-
235
- img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
236
- nh, nw, s = self._compute_scaled_size(H, W)
237
- scales = [(nh, nw)]
238
- if s < 1.0:
239
- f = 0.85
240
- cur_h, cur_w = nh, nw
241
- for _ in range(6):
242
- cur_h = max(128, int(cur_h * f))
243
- cur_w = max(128, int(cur_w * f))
244
- if (cur_h, cur_w) != scales[-1]:
245
- scales.append((cur_h, cur_w))
246
- if max(cur_h, cur_w) <= 192 or (cur_h * cur_w) <= 150_000:
247
- break
 
 
 
248
 
249
- last_exc = None
250
 
251
- for (th, tw) in scales:
252
- try:
253
- img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
254
- msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
255
- # --- Guarantee same shape for mask/image at each retry scale ---
256
- if msk_in is not None and img_in.shape[-2:] != msk_in.shape[-2:]:
257
- logger.warning(f"Progressive retry: resizing mask from {msk_in.shape[-2:]} to {img_in.shape[-2:]}")
258
- msk_in = _resize_bchw(msk_in, img_in.shape[-2:], is_mask=True)
259
-
260
- img_chw = _to_chw_image(img_in).contiguous()
261
- m_1hw = _to_1hw_mask(msk_in) if msk_in is not None else None
262
- mask_2d = m_1hw[0].contiguous() if m_1hw is not None else None
263
-
264
- with torch.inference_mode():
265
- if self.use_autocast:
266
- amp_ctx = torch.cuda.amp.autocast(dtype=self.autocast_dtype)
267
- else:
268
- class _NoOp:
269
- def __enter__(self): return None
270
- def __exit__(self, *args): return False
271
- amp_ctx = _NoOp()
272
- with amp_ctx:
273
- if not self.started:
274
- if mask_2d is None:
275
- logger.warning("First frame arrived without a mask; returning neutral alpha.")
276
- return np.full((H, W), 0.5, dtype=np.float32)
277
- _ = self.core.step(image=img_chw, mask=mask_2d)
278
- if self._has_first_frame_pred:
279
- out_prob = self.core.step(image=img_chw, first_frame_pred=True)
280
  else:
281
  out_prob = self.core.step(image=img_chw)
282
- alpha = self._to_alpha(out_prob)
283
- self.started = True
284
- else:
285
- out_prob = self.core.step(image=img_chw)
286
- alpha = self._to_alpha(out_prob)
287
-
288
- if (th, tw) != (H, W):
289
- a_b1hw = _to_b1hw_alpha(alpha, device=img_chw.device)
290
- a_b1hw = torch.nn.functional.interpolate(a_b1hw, size=(H, W), mode="bilinear", align_corners=False)
291
- alpha = a_b1hw[0, 0]
292
- return _to_2d_alpha_numpy(alpha)
293
-
294
- except torch.cuda.OutOfMemoryError as e:
295
- last_exc = e
296
- logger.warning(f"MatAnyone OOM at {th}x{tw}; retrying smaller. {e}")
297
- torch.cuda.empty_cache()
298
- continue
299
- except Exception as e:
300
- last_exc = e
301
- logger.debug(traceback.format_exc())
302
- logger.warning(f"MatAnyone call failed at {th}x{tw}; retrying smaller. {e}")
303
- torch.cuda.empty_cache()
304
- continue
305
-
306
- logger.warning(f"MatAnyone calls failed; returning input mask as fallback. {last_exc}")
307
- if msk_b1hw is not None:
308
- return _to_2d_alpha_numpy(msk_b1hw)
309
- return np.full((H, W), 0.5, dtype=np.float32)
310
- # -------------------------------- Loader ---------------------------------- #
311
-
312
- def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
313
- if device != "cuda":
314
- return torch.float32, False, None
315
- bf16_ok = hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
316
- cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
317
- fp16_ok = cc[0] >= 7 # Volta+
318
- if bf16_ok:
319
- return torch.bfloat16, True, torch.bfloat16
320
- if fp16_ok:
321
- return torch.float16, True, torch.float16
322
- return torch.float32, False, None
323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  class MatAnyoneLoader:
325
  """
326
- Official MatAnyone loader with stateful, OOM-resilient adapter.
327
  """
328
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
329
  self.device = _select_device(device)
@@ -335,6 +232,7 @@ def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyo
335
  self.model_id = "PeiqingYang/MatAnyone"
336
  self.load_time = 0.0
337
 
 
338
  def _import_model_and_core(self):
339
  model_cls = core_cls = None
340
  err_msgs = []
@@ -359,36 +257,40 @@ def _import_model_and_core(self):
359
  except Exception as e:
360
  err_msgs.append(f"core {mod}.{cls}: {e}")
361
  if model_cls is None or core_cls is None:
362
- msg = " | ".join(err_msgs)
363
- raise ImportError(f"Could not import MatAnyone/InferenceCore: {msg}")
364
  return model_cls, core_cls
365
 
366
  def load(self) -> Optional[Any]:
367
  logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
368
- start = time.time()
369
  try:
370
  model_cls, core_cls = self._import_model_and_core()
371
  model_dtype, use_autocast, autocast_dtype = _choose_precision(self.device)
372
  logger.info(f"MatAnyone precision: weights={model_dtype}, autocast={use_autocast and autocast_dtype}")
 
 
373
  self.model = model_cls.from_pretrained(self.model_id)
374
  try:
375
  self.model = self.model.to(self.device).to(model_dtype)
376
  except Exception:
377
  self.model = self.model.to(self.device)
378
  self.model.eval()
 
 
379
  try:
380
  cfg = getattr(self.model, "cfg", None)
381
- if cfg is not None:
382
- self.core = core_cls(self.model, cfg=cfg)
383
- else:
384
- self.core = core_cls(self.model)
385
  except TypeError:
386
  self.core = core_cls(self.model)
 
 
387
  try:
388
  if hasattr(self.core, "to"):
389
  self.core.to(self.device)
390
  except Exception:
391
  pass
 
 
392
  max_edge = int(os.environ.get("MATANYONE_MAX_EDGE", "768"))
393
  target_pixels = int(os.environ.get("MATANYONE_TARGET_PIXELS", "600000"))
394
  self.adapter = _MatAnyoneSession(
@@ -400,9 +302,10 @@ def load(self) -> Optional[Any]:
400
  max_edge=max_edge,
401
  target_pixels=target_pixels,
402
  )
403
- self.load_time = time.time() - start
404
  logger.info(f"MatAnyone loaded in {self.load_time:.2f}s")
405
  return self.adapter
 
406
  except Exception as e:
407
  logger.error(f"Failed to load MatAnyone: {e}")
408
  logger.debug(traceback.format_exc())
@@ -430,24 +333,14 @@ def get_info(self) -> Dict[str, Any]:
430
  }
431
 
432
  def debug_shapes(self, image, mask, tag: str = ""):
433
- debug_shapes(tag, image, mask)
434
-
435
- # -------------------------- Optional: Module-level symbols --------------------------
436
-
437
- __all__ = [
438
- "MatAnyoneLoader",
439
- "_MatAnyoneSession",
440
- "_to_bchw",
441
- "_resize_bchw",
442
- "_to_chw_image",
443
- "_to_1hw_mask",
444
- "_to_b1hw_alpha",
445
- "_to_2d_alpha_numpy",
446
- "debug_shapes"
447
- ]
448
-
449
- # -------------------------- (Optional) Simple CLI for quick testing --------------------------
450
-
451
  if __name__ == "__main__":
452
  import sys
453
  import cv2
@@ -457,15 +350,16 @@ def debug_shapes(self, image, mask, tag: str = ""):
457
 
458
  if len(sys.argv) < 2:
459
  print(f"Usage: {sys.argv[0]} image.jpg [mask.png]")
460
- sys.exit(1)
 
461
  image_path = sys.argv[1]
462
- mask_path = sys.argv[2] if len(sys.argv) > 2 else None
463
 
464
- img = cv2.imread(image_path, cv2.IMREAD_COLOR)
465
- if img is None:
466
  print(f"Could not load image {image_path}")
467
- sys.exit(2)
468
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
469
 
470
  mask = None
471
  if mask_path:
@@ -477,7 +371,7 @@ def debug_shapes(self, image, mask, tag: str = ""):
477
  session = loader.load()
478
  if not session:
479
  print("Failed to load MatAnyone")
480
- sys.exit(3)
481
 
482
  alpha = session(img_rgb, mask)
483
  cv2.imwrite("alpha_out.png", (np.clip(alpha, 0, 1) * 255).astype(np.uint8))
 
1
+ from matanyone_loader import MatAnyoneLoader
2
+ import cv2, numpy as np, torch
3
+
4
+ # Load session (stateful per video)
5
+ loader = MatAnyoneLoader(device="cuda")
6
+ session = loader.load()
7
+ assert session, "MatAnyone failed to load"
8
+
9
+ # Frame 0 (must supply a coarse mask, even a fallback like 0.5 or ones)
10
+ bgr0 = cv2.imread("frame0001.jpg")
11
+ rgb0 = cv2.cvtColor(bgr0, cv2.COLOR_BGR2RGB)
12
+ coarse0 = np.ones((rgb0.shape[0], rgb0.shape[1]), dtype=np.float32) # example fallback
13
+
14
+ alpha0 = session(rgb0, coarse0) # -> 2-D float32 [H,W]
15
+
16
+ # Frames 1..N (mask=None, stateful propagation)
17
+ for i in range(2, 6):
18
+ bgr = cv2.imread(f"frame000{i}.jpg")
19
+ rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
20
+ alpha = session(rgb, mask=None) # -> 2-D float32 [H,W]
21
  #!/usr/bin/env python3
22
+ # -*- coding: utf-8 -*-
23
+
24
  """
25
  MatAnyone Loader + Stateful Adapter (OOM-resilient, spatially robust)
26
+ - Canonical HF load (MatAnyone.from_pretrained -> InferenceCore(model, cfg))
27
  - Mixed precision (bf16/fp16) with safe fallback to fp32
28
+ - torch.autocast(device_type="cuda", dtype=...) + torch.inference_mode()
29
+ - Progressive downscale ladder with graceful fallback
30
+ - Strict image↔mask alignment on every path/scale
31
+ - Returns 2-D float32 [H,W] alpha (OpenCV-friendly)
32
  """
33
 
34
+ from __future__ import annotations
35
+
36
  import os
37
  import time
38
  import logging
39
  import traceback
40
+ from typing import Optional, Dict, Any, Tuple, List
41
 
42
  import numpy as np
43
  import torch
44
  import torch.nn.functional as F
45
  import inspect
46
+ import threading
 
 
 
 
47
  def _select_device(pref: str) -> str:
48
  pref = (pref or "").lower()
49
  if pref.startswith("cuda"):
 
52
  return "cpu"
53
  return "cuda" if torch.cuda.is_available() else "cpu"
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
57
+ """Pick model weight dtype + autocast dtype (bf16>fp16>fp32)."""
58
+ if device != "cuda":
59
+ return torch.float32, False, None
60
+ bf16_ok = hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
61
+ cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
62
+ fp16_ok = cc[0] >= 7 # Volta+
63
+ if bf16_ok:
64
+ return torch.bfloat16, True, torch.bfloat16
65
+ if fp16_ok:
66
+ return torch.float16, True, torch.float16
67
+ return torch.float32, False, None
68
  class _MatAnyoneSession:
69
  """
70
  Stateful controller around InferenceCore with OOM-resilient inference.
71
+ First call MUST supply a coarse mask (we enforce 1HW internally).
 
 
 
 
72
  """
73
  def __init__(
74
  self,
 
78
  use_autocast: bool,
79
  autocast_dtype: Optional[torch.dtype],
80
  max_edge: int = 768,
81
+ target_pixels: int = 600_000, # ~775x775 by area
82
  ):
83
  self.core = core
84
  self.device = device
 
88
  self.max_edge = int(max_edge)
89
  self.target_pixels = int(target_pixels)
90
  self.started = False
91
+ self._lock = threading.Lock()
92
 
93
+ # Introspect optional args
94
  try:
95
  sig = inspect.signature(self.core.step)
96
  self._has_first_frame_pred = "first_frame_pred" in sig.parameters
 
99
  self._has_prob_to_mask = hasattr(self.core, "output_prob_to_mask")
100
 
101
  def reset(self):
102
+ with self._lock:
103
+ try:
104
+ if hasattr(self.core, "clear_memory"):
105
+ self.core.clear_memory()
106
+ except Exception:
107
+ pass
108
+ self.started = False
109
 
110
+ def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
111
+ nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
112
+ sizes = [(nh, nw)]
113
+ if s < 1.0:
114
+ f_chain = (0.85, 0.70, 0.55, 0.40)
115
+ cur_h, cur_w = nh, nw
116
+ for f in f_chain:
117
+ cur_h = max(128, int(cur_h * f))
118
+ cur_w = max(128, int(cur_w * f))
119
+ if sizes[-1] != (cur_h, cur_w):
120
+ sizes.append((cur_h, cur_w))
121
+ return sizes
122
 
123
  def _to_alpha(self, out_prob):
124
  if self._has_prob_to_mask:
 
128
  pass
129
  t = torch.as_tensor(out_prob).float()
130
  if t.ndim == 4:
131
+ return t[0, 0] if t.shape[1] >= 1 else t[0].mean(0)
 
 
 
132
  if t.ndim == 3:
133
  return t[0] if t.shape[0] >= 1 else t.mean(0)
134
  return t
135
 
136
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
137
  """
138
+ Returns a 2-D float32 alpha [H,W].
139
+ - frame 0: provide coarse mask → session initialized
140
+ - frames 1..N: pass mask=None (propagation)
141
  """
142
+ with self._lock:
143
+ img_bchw = _to_bchw(image, self.device, is_mask=False) # [1,C,H,W]
144
+ H, W = img_bchw.shape[-2], img_bchw.shape[-1]
145
+ img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
146
+
147
+ # Normalize + align provided mask (if any) to **B1HW** at full res
148
+ msk_b1hw = _to_bchw(mask, self.device, is_mask=True) if mask is not None else None
149
+ if msk_b1hw is not None and msk_b1hw.shape[-2:] != (H, W):
150
+ msk_b1hw = _resize_bchw(msk_b1hw, (H, W), is_mask=True)
151
+ mask_1hw = _to_1hw_mask(msk_b1hw) if msk_b1hw is not None else None # ← 1HW!
152
+
153
+ sizes = self._scaled_ladder(H, W)
154
+ last_exc = None
155
+
156
+ for (th, tw) in sizes:
157
+ try:
158
+ img_in = img_bchw if (th, tw) == (H, W) else F.interpolate(
159
+ img_bchw, size=(th, tw), mode="bilinear", align_corners=False
160
+ )
161
+ msk_in = None
162
+ if mask_1hw is not None:
163
+ if (th, tw) == (H, W):
164
+ msk_in = mask_1hw
165
+ else:
166
+ msk_in = F.interpolate(mask_1hw.unsqueeze(0), size=(th, tw), mode="nearest")[0]
167
 
168
+ img_chw = _to_chw_image(img_in).contiguous() # [C,H,W]
169
 
170
+ with torch.inference_mode():
171
+ if self.use_autocast:
172
+ amp_ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
173
+ else:
174
+ class _NoOp:
175
+ def __enter__(self): return None
176
+ def __exit__(self, *a): return False
177
+ amp_ctx = _NoOp()
178
+
179
+ with amp_ctx:
180
+ if not self.started:
181
+ if msk_in is None:
182
+ # Should not happen when used correctly — still be defensive
183
+ logger.warning("First frame arrived without a mask; returning neutral alpha.")
184
+ return np.full((H, W), 0.5, dtype=np.float32)
185
+ # CRITICAL: pass **1HW** to .step(mask=...)
186
+ _ = self.core.step(image=img_chw, mask=msk_in)
187
+ if self._has_first_frame_pred:
188
+ out_prob = self.core.step(image=img_chw, first_frame_pred=True)
189
+ else:
190
+ out_prob = self.core.step(image=img_chw)
191
+ self.started = True
 
 
 
 
 
 
 
192
  else:
193
  out_prob = self.core.step(image=img_chw)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ alpha = self._to_alpha(out_prob)
196
+
197
+ # Upsample alpha back if we ran at a smaller scale
198
+ if (th, tw) != (H, W):
199
+ a_b1hw = _to_b1hw_alpha(alpha, device=img_bchw.device)
200
+ a_b1hw = F.interpolate(a_b1hw, size=(H, W), mode="bilinear", align_corners=False)
201
+ alpha = a_b1hw[0, 0]
202
+
203
+ return _to_2d_alpha_numpy(alpha)
204
+
205
+ except torch.cuda.OutOfMemoryError as e:
206
+ last_exc = e
207
+ torch.cuda.empty_cache()
208
+ logger.warning(f"MatAnyone OOM at {th}x{tw}; retrying smaller. {e}")
209
+ continue
210
+ except Exception as e:
211
+ last_exc = e
212
+ torch.cuda.empty_cache()
213
+ logger.debug(traceback.format_exc())
214
+ logger.warning(f"MatAnyone call failed at {th}x{tw}; retrying smaller. {e}")
215
+ continue
216
+
217
+ logger.warning(f"MatAnyone calls failed; returning input mask or neutral alpha. {last_exc}")
218
+ if mask_1hw is not None:
219
+ return _to_2d_alpha_numpy(mask_1hw)
220
+ return np.full((H, W), 0.5, dtype=np.float32)
221
  class MatAnyoneLoader:
222
  """
223
+ Official MatAnyone loader with stateful, OOM-resilient session adapter.
224
  """
225
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
226
  self.device = _select_device(device)
 
232
  self.model_id = "PeiqingYang/MatAnyone"
233
  self.load_time = 0.0
234
 
235
+ # --- Robust imports (works with different packaging layouts) ---
236
  def _import_model_and_core(self):
237
  model_cls = core_cls = None
238
  err_msgs = []
 
257
  except Exception as e:
258
  err_msgs.append(f"core {mod}.{cls}: {e}")
259
  if model_cls is None or core_cls is None:
260
+ raise ImportError("Could not import MatAnyone / InferenceCore: " + " | ".join(err_msgs))
 
261
  return model_cls, core_cls
262
 
263
  def load(self) -> Optional[Any]:
264
  logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
265
+ t0 = time.time()
266
  try:
267
  model_cls, core_cls = self._import_model_and_core()
268
  model_dtype, use_autocast, autocast_dtype = _choose_precision(self.device)
269
  logger.info(f"MatAnyone precision: weights={model_dtype}, autocast={use_autocast and autocast_dtype}")
270
+
271
+ # HF weights (safetensors); keep trust defaults of library itself
272
  self.model = model_cls.from_pretrained(self.model_id)
273
  try:
274
  self.model = self.model.to(self.device).to(model_dtype)
275
  except Exception:
276
  self.model = self.model.to(self.device)
277
  self.model.eval()
278
+
279
+ # Inference core (cfg may or may not exist on the model)
280
  try:
281
  cfg = getattr(self.model, "cfg", None)
282
+ self.core = core_cls(self.model, cfg=cfg) if cfg is not None else core_cls(self.model)
 
 
 
283
  except TypeError:
284
  self.core = core_cls(self.model)
285
+
286
+ # Some versions expose .to(), some don’t — best effort
287
  try:
288
  if hasattr(self.core, "to"):
289
  self.core.to(self.device)
290
  except Exception:
291
  pass
292
+
293
+ # Build stateful adapter
294
  max_edge = int(os.environ.get("MATANYONE_MAX_EDGE", "768"))
295
  target_pixels = int(os.environ.get("MATANYONE_TARGET_PIXELS", "600000"))
296
  self.adapter = _MatAnyoneSession(
 
302
  max_edge=max_edge,
303
  target_pixels=target_pixels,
304
  )
305
+ self.load_time = time.time() - t0
306
  logger.info(f"MatAnyone loaded in {self.load_time:.2f}s")
307
  return self.adapter
308
+
309
  except Exception as e:
310
  logger.error(f"Failed to load MatAnyone: {e}")
311
  logger.debug(traceback.format_exc())
 
333
  }
334
 
335
  def debug_shapes(self, image, mask, tag: str = ""):
336
+ try:
337
+ tv_img = torch.as_tensor(image)
338
+ tv_msk = torch.as_tensor(mask) if mask is not None else None
339
+ logger.info(f"[{tag}:image] shape={tuple(tv_img.shape)} dtype={tv_img.dtype}")
340
+ if tv_msk is not None:
341
+ logger.info(f"[{tag}:mask ] shape={tuple(tv_msk.shape)} dtype={tv_msk.dtype}")
342
+ except Exception as e:
343
+ logger.info(f"[{tag}] debug error: {e}")
 
 
 
 
 
 
 
 
 
 
344
  if __name__ == "__main__":
345
  import sys
346
  import cv2
 
350
 
351
  if len(sys.argv) < 2:
352
  print(f"Usage: {sys.argv[0]} image.jpg [mask.png]")
353
+ raise SystemExit(1)
354
+
355
  image_path = sys.argv[1]
356
+ mask_path = sys.argv[2] if len(sys.argv) > 2 else None
357
 
358
+ img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
359
+ if img_bgr is None:
360
  print(f"Could not load image {image_path}")
361
+ raise SystemExit(2)
362
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
363
 
364
  mask = None
365
  if mask_path:
 
371
  session = loader.load()
372
  if not session:
373
  print("Failed to load MatAnyone")
374
+ raise SystemExit(3)
375
 
376
  alpha = session(img_rgb, mask)
377
  cv2.imwrite("alpha_out.png", (np.clip(alpha, 0, 1) * 255).astype(np.uint8))