Update models/loaders/matanyone_loader.py
Browse files
models/loaders/matanyone_loader.py
CHANGED
|
@@ -11,6 +11,8 @@
|
|
| 11 |
- Added: Force chunk_size=1, flip_aug=False in cfg to avoid dim mismatches
|
| 12 |
- Added: Pad to multiple of 16 to avoid transformer patch issues
|
| 13 |
- Added: Prefer fp16 over bf16 for Tesla T4 compatibility
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
from __future__ import annotations
|
| 16 |
import os
|
|
@@ -24,6 +26,34 @@
|
|
| 24 |
import inspect
|
| 25 |
import threading
|
| 26 |
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# ---------------------------------------------------------------------------
|
| 28 |
# Utilities (shapes, dtype, scaling)
|
| 29 |
# ---------------------------------------------------------------------------
|
|
@@ -34,10 +64,12 @@ def _select_device(pref: str) -> str:
|
|
| 34 |
if pref == "cpu":
|
| 35 |
return "cpu"
|
| 36 |
return "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 37 |
def _as_tensor_on_device(x, device: str) -> torch.Tensor:
|
| 38 |
if isinstance(x, torch.Tensor):
|
| 39 |
return x.to(device, non_blocking=True)
|
| 40 |
return torch.from_numpy(np.asarray(x)).to(device, non_blocking=True)
|
|
|
|
| 41 |
def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
|
| 42 |
"""
|
| 43 |
Normalize input to BCHW (image) or B1HW (mask).
|
|
@@ -72,10 +104,12 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
|
|
| 72 |
x = x.repeat(1, 3, 1, 1)
|
| 73 |
x = x.clamp_(0.0, 1.0)
|
| 74 |
return x
|
|
|
|
| 75 |
def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
|
| 76 |
if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
|
| 77 |
return img_bchw[0]
|
| 78 |
return img_bchw
|
|
|
|
| 79 |
def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
|
| 80 |
if msk_b1hw is None:
|
| 81 |
return None
|
|
@@ -84,6 +118,7 @@ def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
|
|
| 84 |
if msk_b1hw.ndim == 3 and msk_b1hw.shape[0] == 1:
|
| 85 |
return msk_b1hw
|
| 86 |
raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
|
|
|
|
| 87 |
def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: bool = False) -> Optional[torch.Tensor]:
|
| 88 |
if x is None:
|
| 89 |
return None
|
|
@@ -91,6 +126,7 @@ def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: b
|
|
| 91 |
return x
|
| 92 |
mode = "nearest" if is_mask else "bilinear"
|
| 93 |
return F.interpolate(x, size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
|
|
|
|
| 94 |
def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
|
| 95 |
t = torch.as_tensor(alpha, device=device).float()
|
| 96 |
if t.ndim == 2:
|
|
@@ -117,6 +153,7 @@ def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
|
|
| 117 |
if t.shape[1] != 1:
|
| 118 |
t = t[:, :1]
|
| 119 |
return t.clamp_(0.0, 1.0).contiguous()
|
|
|
|
| 120 |
def _to_2d_alpha_numpy(x) -> np.ndarray:
|
| 121 |
t = torch.as_tensor(x).float()
|
| 122 |
while t.ndim > 2:
|
|
@@ -129,6 +166,7 @@ def _to_2d_alpha_numpy(x) -> np.ndarray:
|
|
| 129 |
t = t.clamp_(0.0, 1.0)
|
| 130 |
out = t.detach().cpu().numpy().astype(np.float32)
|
| 131 |
return np.ascontiguousarray(out)
|
|
|
|
| 132 |
def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
|
| 133 |
if h <= 0 or w <= 0:
|
| 134 |
return h, w, 1.0
|
|
@@ -138,6 +176,7 @@ def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> T
|
|
| 138 |
nh = max(128, int(round(h * s))) # Force min 128 to avoid small-res bugs
|
| 139 |
nw = max(128, int(round(w * s)))
|
| 140 |
return nh, nw, s
|
|
|
|
| 141 |
def _pad_to_multiple(t: Optional[torch.Tensor], multiple: int = 16) -> Optional[torch.Tensor]:
|
| 142 |
if t is None:
|
| 143 |
return None
|
|
@@ -155,6 +194,7 @@ def _pad_to_multiple(t: Optional[torch.Tensor], multiple: int = 16) -> Optional[
|
|
| 155 |
if t.ndim == 2: # Shouldn't happen
|
| 156 |
t = t.squeeze(0)
|
| 157 |
return t
|
|
|
|
| 158 |
def debug_shapes(tag: str, image, mask) -> None:
|
| 159 |
def _info(name, v):
|
| 160 |
try:
|
|
@@ -166,6 +206,7 @@ def _info(name, v):
|
|
| 166 |
logger.info(f"[{tag}:{name}] type={type(v)} err={e}")
|
| 167 |
_info("image", image)
|
| 168 |
_info("mask", mask)
|
|
|
|
| 169 |
# ---------------------------------------------------------------------------
|
| 170 |
# Precision selection
|
| 171 |
# ---------------------------------------------------------------------------
|
|
@@ -181,6 +222,7 @@ def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dt
|
|
| 181 |
if bf16_ok:
|
| 182 |
return torch.bfloat16, True, torch.bfloat16
|
| 183 |
return torch.float32, False, None
|
|
|
|
| 184 |
# ---------------------------------------------------------------------------
|
| 185 |
# Stateful Adapter around InferenceCore
|
| 186 |
# ---------------------------------------------------------------------------
|
|
@@ -215,6 +257,7 @@ def __init__(
|
|
| 215 |
except Exception:
|
| 216 |
self._has_first_frame_pred = True
|
| 217 |
self._has_prob_to_mask = hasattr(self.core, "output_prob_to_mask")
|
|
|
|
| 218 |
def reset(self):
|
| 219 |
with self._lock:
|
| 220 |
try:
|
|
@@ -223,6 +266,7 @@ def reset(self):
|
|
| 223 |
except Exception:
|
| 224 |
pass
|
| 225 |
self.started = False
|
|
|
|
| 226 |
def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
|
| 227 |
nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
|
| 228 |
sizes = [(nh, nw)]
|
|
@@ -235,6 +279,7 @@ def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
|
|
| 235 |
if sizes[-1] != (cur_h, cur_w):
|
| 236 |
sizes.append((cur_h, cur_w))
|
| 237 |
return sizes
|
|
|
|
| 238 |
def _to_alpha(self, out_prob):
|
| 239 |
if self._has_prob_to_mask:
|
| 240 |
try:
|
|
@@ -247,6 +292,7 @@ def _to_alpha(self, out_prob):
|
|
| 247 |
if t.ndim == 3:
|
| 248 |
return t[0] if t.shape[0] >= 1 else t.mean(0)
|
| 249 |
return t
|
|
|
|
| 250 |
def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
| 251 |
"""
|
| 252 |
Returns a 2-D float32 alpha [H,W].
|
|
@@ -329,6 +375,7 @@ def __exit__(self, *a): return False
|
|
| 329 |
if mask_1hw is not None:
|
| 330 |
return _to_2d_alpha_numpy(mask_1hw)
|
| 331 |
return np.full((H, W), 0.5, dtype=np.float32)
|
|
|
|
| 332 |
# ---------------------------------------------------------------------------
|
| 333 |
# Loader
|
| 334 |
# ---------------------------------------------------------------------------
|
|
@@ -345,6 +392,7 @@ def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyo
|
|
| 345 |
self.adapter = None
|
| 346 |
self.model_id = "PeiqingYang/MatAnyone"
|
| 347 |
self.load_time = 0.0
|
|
|
|
| 348 |
# --- Robust imports (works with different packaging layouts) ---
|
| 349 |
def _import_model_and_core(self):
|
| 350 |
model_cls = core_cls = None
|
|
@@ -372,6 +420,7 @@ def _import_model_and_core(self):
|
|
| 372 |
if model_cls is None or core_cls is None:
|
| 373 |
raise ImportError("Could not import MatAnyone / InferenceCore: " + " | ".join(err_msgs))
|
| 374 |
return model_cls, core_cls
|
|
|
|
| 375 |
def load(self) -> Optional[Any]:
|
| 376 |
logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
|
| 377 |
t0 = time.time()
|
|
@@ -386,16 +435,111 @@ def load(self) -> Optional[Any]:
|
|
| 386 |
except Exception:
|
| 387 |
self.model = self.model.to(self.device)
|
| 388 |
self.model.eval()
|
| 389 |
-
#
|
| 390 |
default_cfg = {
|
| 391 |
-
|
| 392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
}
|
|
|
|
| 394 |
cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
|
| 395 |
if isinstance(cfg, dict):
|
| 396 |
-
cfg
|
| 397 |
-
|
| 398 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
# Inference core
|
| 400 |
try:
|
| 401 |
self.core = core_cls(self.model, cfg=cfg)
|
|
@@ -426,6 +570,7 @@ def load(self) -> Optional[Any]:
|
|
| 426 |
logger.error(f"Failed to load MatAnyone: {e}")
|
| 427 |
logger.debug(traceback.format_exc())
|
| 428 |
return None
|
|
|
|
| 429 |
def cleanup(self):
|
| 430 |
self.adapter = None
|
| 431 |
self.core = None
|
|
@@ -437,6 +582,7 @@ def cleanup(self):
|
|
| 437 |
self.model = None
|
| 438 |
if torch.cuda.is_available():
|
| 439 |
torch.cuda.empty_cache()
|
|
|
|
| 440 |
def get_info(self) -> Dict[str, Any]:
|
| 441 |
return {
|
| 442 |
"loaded": self.adapter is not None,
|
|
@@ -445,6 +591,7 @@ def get_info(self) -> Dict[str, Any]:
|
|
| 445 |
"load_time": self.load_time,
|
| 446 |
"model_type": type(self.model).__name__ if self.model else None,
|
| 447 |
}
|
|
|
|
| 448 |
def debug_shapes(self, image, mask, tag: str = ""):
|
| 449 |
try:
|
| 450 |
tv_img = torch.as_tensor(image)
|
|
@@ -454,6 +601,7 @@ def debug_shapes(self, image, mask, tag: str = ""):
|
|
| 454 |
logger.info(f"[{tag}:mask ] shape={tuple(tv_msk.shape)} dtype={tv_msk.dtype}")
|
| 455 |
except Exception as e:
|
| 456 |
logger.info(f"[{tag}] debug error: {e}")
|
|
|
|
| 457 |
# ---------------------------------------------------------------------------
|
| 458 |
# Public symbols
|
| 459 |
# ---------------------------------------------------------------------------
|
|
@@ -469,6 +617,7 @@ def debug_shapes(self, image, mask, tag: str = ""):
|
|
| 469 |
"_compute_scaled_size",
|
| 470 |
"debug_shapes",
|
| 471 |
]
|
|
|
|
| 472 |
# ---------------------------------------------------------------------------
|
| 473 |
# Optional CLI for quick testing (no circular imports)
|
| 474 |
# ---------------------------------------------------------------------------
|
|
|
|
| 11 |
- Added: Force chunk_size=1, flip_aug=False in cfg to avoid dim mismatches
|
| 12 |
- Added: Pad to multiple of 16 to avoid transformer patch issues
|
| 13 |
- Added: Prefer fp16 over bf16 for Tesla T4 compatibility
|
| 14 |
+
- New: EasyDict polyfill and conversion for cfg to fix 'dict no attribute' errors
|
| 15 |
+
- New: Full default cfg from official config.json to ensure keys like mem_every are present
|
| 16 |
"""
|
| 17 |
from __future__ import annotations
|
| 18 |
import os
|
|
|
|
| 26 |
import inspect
|
| 27 |
import threading
|
| 28 |
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
# EasyDict polyfill (recursive dict with dot access)
|
| 31 |
+
class EasyDict(dict):
|
| 32 |
+
def __init__(self, d=None, **kwargs):
|
| 33 |
+
if d is None:
|
| 34 |
+
d = {}
|
| 35 |
+
if kwargs:
|
| 36 |
+
d.update(**kwargs)
|
| 37 |
+
for k, v in d.items():
|
| 38 |
+
if isinstance(v, dict):
|
| 39 |
+
self[k] = EasyDict(v)
|
| 40 |
+
elif isinstance(v, list):
|
| 41 |
+
self[k] = [EasyDict(i) if isinstance(i, dict) else i for i in v]
|
| 42 |
+
else:
|
| 43 |
+
self[k] = v
|
| 44 |
+
|
| 45 |
+
def __getattr__(self, name):
|
| 46 |
+
try:
|
| 47 |
+
return self[name]
|
| 48 |
+
except KeyError:
|
| 49 |
+
raise AttributeError(name)
|
| 50 |
+
|
| 51 |
+
def __setattr__(self, name, value):
|
| 52 |
+
self[name] = value
|
| 53 |
+
|
| 54 |
+
def __delattr__(self, name):
|
| 55 |
+
del self[name]
|
| 56 |
+
|
| 57 |
# ---------------------------------------------------------------------------
|
| 58 |
# Utilities (shapes, dtype, scaling)
|
| 59 |
# ---------------------------------------------------------------------------
|
|
|
|
| 64 |
if pref == "cpu":
|
| 65 |
return "cpu"
|
| 66 |
return "cuda" if torch.cuda.is_available() else "cpu"
|
| 67 |
+
|
| 68 |
def _as_tensor_on_device(x, device: str) -> torch.Tensor:
|
| 69 |
if isinstance(x, torch.Tensor):
|
| 70 |
return x.to(device, non_blocking=True)
|
| 71 |
return torch.from_numpy(np.asarray(x)).to(device, non_blocking=True)
|
| 72 |
+
|
| 73 |
def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
|
| 74 |
"""
|
| 75 |
Normalize input to BCHW (image) or B1HW (mask).
|
|
|
|
| 104 |
x = x.repeat(1, 3, 1, 1)
|
| 105 |
x = x.clamp_(0.0, 1.0)
|
| 106 |
return x
|
| 107 |
+
|
| 108 |
def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
|
| 109 |
if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
|
| 110 |
return img_bchw[0]
|
| 111 |
return img_bchw
|
| 112 |
+
|
| 113 |
def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
|
| 114 |
if msk_b1hw is None:
|
| 115 |
return None
|
|
|
|
| 118 |
if msk_b1hw.ndim == 3 and msk_b1hw.shape[0] == 1:
|
| 119 |
return msk_b1hw
|
| 120 |
raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
|
| 121 |
+
|
| 122 |
def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: bool = False) -> Optional[torch.Tensor]:
|
| 123 |
if x is None:
|
| 124 |
return None
|
|
|
|
| 126 |
return x
|
| 127 |
mode = "nearest" if is_mask else "bilinear"
|
| 128 |
return F.interpolate(x, size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
|
| 129 |
+
|
| 130 |
def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
|
| 131 |
t = torch.as_tensor(alpha, device=device).float()
|
| 132 |
if t.ndim == 2:
|
|
|
|
| 153 |
if t.shape[1] != 1:
|
| 154 |
t = t[:, :1]
|
| 155 |
return t.clamp_(0.0, 1.0).contiguous()
|
| 156 |
+
|
| 157 |
def _to_2d_alpha_numpy(x) -> np.ndarray:
|
| 158 |
t = torch.as_tensor(x).float()
|
| 159 |
while t.ndim > 2:
|
|
|
|
| 166 |
t = t.clamp_(0.0, 1.0)
|
| 167 |
out = t.detach().cpu().numpy().astype(np.float32)
|
| 168 |
return np.ascontiguousarray(out)
|
| 169 |
+
|
| 170 |
def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
|
| 171 |
if h <= 0 or w <= 0:
|
| 172 |
return h, w, 1.0
|
|
|
|
| 176 |
nh = max(128, int(round(h * s))) # Force min 128 to avoid small-res bugs
|
| 177 |
nw = max(128, int(round(w * s)))
|
| 178 |
return nh, nw, s
|
| 179 |
+
|
| 180 |
def _pad_to_multiple(t: Optional[torch.Tensor], multiple: int = 16) -> Optional[torch.Tensor]:
|
| 181 |
if t is None:
|
| 182 |
return None
|
|
|
|
| 194 |
if t.ndim == 2: # Shouldn't happen
|
| 195 |
t = t.squeeze(0)
|
| 196 |
return t
|
| 197 |
+
|
| 198 |
def debug_shapes(tag: str, image, mask) -> None:
|
| 199 |
def _info(name, v):
|
| 200 |
try:
|
|
|
|
| 206 |
logger.info(f"[{tag}:{name}] type={type(v)} err={e}")
|
| 207 |
_info("image", image)
|
| 208 |
_info("mask", mask)
|
| 209 |
+
|
| 210 |
# ---------------------------------------------------------------------------
|
| 211 |
# Precision selection
|
| 212 |
# ---------------------------------------------------------------------------
|
|
|
|
| 222 |
if bf16_ok:
|
| 223 |
return torch.bfloat16, True, torch.bfloat16
|
| 224 |
return torch.float32, False, None
|
| 225 |
+
|
| 226 |
# ---------------------------------------------------------------------------
|
| 227 |
# Stateful Adapter around InferenceCore
|
| 228 |
# ---------------------------------------------------------------------------
|
|
|
|
| 257 |
except Exception:
|
| 258 |
self._has_first_frame_pred = True
|
| 259 |
self._has_prob_to_mask = hasattr(self.core, "output_prob_to_mask")
|
| 260 |
+
|
| 261 |
def reset(self):
|
| 262 |
with self._lock:
|
| 263 |
try:
|
|
|
|
| 266 |
except Exception:
|
| 267 |
pass
|
| 268 |
self.started = False
|
| 269 |
+
|
| 270 |
def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
|
| 271 |
nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
|
| 272 |
sizes = [(nh, nw)]
|
|
|
|
| 279 |
if sizes[-1] != (cur_h, cur_w):
|
| 280 |
sizes.append((cur_h, cur_w))
|
| 281 |
return sizes
|
| 282 |
+
|
| 283 |
def _to_alpha(self, out_prob):
|
| 284 |
if self._has_prob_to_mask:
|
| 285 |
try:
|
|
|
|
| 292 |
if t.ndim == 3:
|
| 293 |
return t[0] if t.shape[0] >= 1 else t.mean(0)
|
| 294 |
return t
|
| 295 |
+
|
| 296 |
def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
|
| 297 |
"""
|
| 298 |
Returns a 2-D float32 alpha [H,W].
|
|
|
|
| 375 |
if mask_1hw is not None:
|
| 376 |
return _to_2d_alpha_numpy(mask_1hw)
|
| 377 |
return np.full((H, W), 0.5, dtype=np.float32)
|
| 378 |
+
|
| 379 |
# ---------------------------------------------------------------------------
|
| 380 |
# Loader
|
| 381 |
# ---------------------------------------------------------------------------
|
|
|
|
| 392 |
self.adapter = None
|
| 393 |
self.model_id = "PeiqingYang/MatAnyone"
|
| 394 |
self.load_time = 0.0
|
| 395 |
+
|
| 396 |
# --- Robust imports (works with different packaging layouts) ---
|
| 397 |
def _import_model_and_core(self):
|
| 398 |
model_cls = core_cls = None
|
|
|
|
| 420 |
if model_cls is None or core_cls is None:
|
| 421 |
raise ImportError("Could not import MatAnyone / InferenceCore: " + " | ".join(err_msgs))
|
| 422 |
return model_cls, core_cls
|
| 423 |
+
|
| 424 |
def load(self) -> Optional[Any]:
|
| 425 |
logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
|
| 426 |
t0 = time.time()
|
|
|
|
| 435 |
except Exception:
|
| 436 |
self.model = self.model.to(self.device)
|
| 437 |
self.model.eval()
|
| 438 |
+
# Full default cfg from official config.json
|
| 439 |
default_cfg = {
|
| 440 |
+
"amp": False,
|
| 441 |
+
"chunk_size": -1,
|
| 442 |
+
"flip_aug": False,
|
| 443 |
+
"long_term": {
|
| 444 |
+
"buffer_tokens": 2000,
|
| 445 |
+
"count_usage": True,
|
| 446 |
+
"max_mem_frames": 10,
|
| 447 |
+
"max_num_tokens": 10000,
|
| 448 |
+
"min_mem_frames": 5,
|
| 449 |
+
"num_prototypes": 128
|
| 450 |
+
},
|
| 451 |
+
"max_internal_size": -1,
|
| 452 |
+
"max_mem_frames": 5,
|
| 453 |
+
"mem_every": 5,
|
| 454 |
+
"model": {
|
| 455 |
+
"aux_loss": {
|
| 456 |
+
"query": {
|
| 457 |
+
"enabled": True,
|
| 458 |
+
"weight": 0.01
|
| 459 |
+
},
|
| 460 |
+
"sensory": {
|
| 461 |
+
"enabled": True,
|
| 462 |
+
"weight": 0.01
|
| 463 |
+
}
|
| 464 |
+
},
|
| 465 |
+
"embed_dim": 256,
|
| 466 |
+
"key_dim": 64,
|
| 467 |
+
"mask_decoder": {
|
| 468 |
+
"up_dims": [256, 128, 128, 64, 16]
|
| 469 |
+
},
|
| 470 |
+
"mask_encoder": {
|
| 471 |
+
"final_dim": 256,
|
| 472 |
+
"type": "resnet18"
|
| 473 |
+
},
|
| 474 |
+
"object_summarizer": {
|
| 475 |
+
"add_pe": True,
|
| 476 |
+
"embed_dim": 256,
|
| 477 |
+
"num_summaries": 16
|
| 478 |
+
},
|
| 479 |
+
"object_transformer": {
|
| 480 |
+
"embed_dim": 256,
|
| 481 |
+
"ff_dim": 2048,
|
| 482 |
+
"num_blocks": 3,
|
| 483 |
+
"num_heads": 8,
|
| 484 |
+
"num_queries": 16,
|
| 485 |
+
"pixel_self_attention": {
|
| 486 |
+
"add_pe_to_qkv": [True, True, False]
|
| 487 |
+
},
|
| 488 |
+
"query_self_attention": {
|
| 489 |
+
"add_pe_to_qkv": [True, True, False]
|
| 490 |
+
},
|
| 491 |
+
"read_from_memory": {
|
| 492 |
+
"add_pe_to_qkv": [True, True, False]
|
| 493 |
+
},
|
| 494 |
+
"read_from_past": {
|
| 495 |
+
"add_pe_to_qkv": [True, True, False]
|
| 496 |
+
},
|
| 497 |
+
"read_from_pixel": {
|
| 498 |
+
"add_pe_to_qkv": [True, True, False],
|
| 499 |
+
"input_add_pe": False,
|
| 500 |
+
"input_norm": False
|
| 501 |
+
},
|
| 502 |
+
"read_from_query": {
|
| 503 |
+
"add_pe_to_qkv": [True, True, False],
|
| 504 |
+
"output_norm": False
|
| 505 |
+
}
|
| 506 |
+
},
|
| 507 |
+
"pixel_dim": 256,
|
| 508 |
+
"pixel_encoder": {
|
| 509 |
+
"ms_dims": [1024, 512, 256, 64, 3],
|
| 510 |
+
"type": "resnet50"
|
| 511 |
+
},
|
| 512 |
+
"pixel_mean": [0.485, 0.456, 0.406],
|
| 513 |
+
"pixel_pe_scale": 32,
|
| 514 |
+
"pixel_pe_temperature": 128,
|
| 515 |
+
"pixel_std": [0.229, 0.224, 0.225],
|
| 516 |
+
"pretrained_resnet": False,
|
| 517 |
+
"sensory_dim": 256,
|
| 518 |
+
"value_dim": 256
|
| 519 |
+
},
|
| 520 |
+
"output_dir": None,
|
| 521 |
+
"save_all": True,
|
| 522 |
+
"save_aux": False,
|
| 523 |
+
"save_scores": False,
|
| 524 |
+
"stagger_updates": 5,
|
| 525 |
+
"top_k": 30,
|
| 526 |
+
"use_all_masks": False,
|
| 527 |
+
"use_long_term": False,
|
| 528 |
+
"visualize": False,
|
| 529 |
+
"weights": "pretrained_models/matanyone.pth"
|
| 530 |
}
|
| 531 |
+
# Get cfg from model if available, else default
|
| 532 |
cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
|
| 533 |
if isinstance(cfg, dict):
|
| 534 |
+
cfg = dict(cfg) # Copy to avoid modifying model.cfg
|
| 535 |
+
# Override specific values
|
| 536 |
+
overrides = {
|
| 537 |
+
'chunk_size': 1,
|
| 538 |
+
'flip_aug': False,
|
| 539 |
+
}
|
| 540 |
+
cfg.update(overrides)
|
| 541 |
+
# Convert to EasyDict for dot access
|
| 542 |
+
cfg = EasyDict(cfg)
|
| 543 |
# Inference core
|
| 544 |
try:
|
| 545 |
self.core = core_cls(self.model, cfg=cfg)
|
|
|
|
| 570 |
logger.error(f"Failed to load MatAnyone: {e}")
|
| 571 |
logger.debug(traceback.format_exc())
|
| 572 |
return None
|
| 573 |
+
|
| 574 |
def cleanup(self):
|
| 575 |
self.adapter = None
|
| 576 |
self.core = None
|
|
|
|
| 582 |
self.model = None
|
| 583 |
if torch.cuda.is_available():
|
| 584 |
torch.cuda.empty_cache()
|
| 585 |
+
|
| 586 |
def get_info(self) -> Dict[str, Any]:
|
| 587 |
return {
|
| 588 |
"loaded": self.adapter is not None,
|
|
|
|
| 591 |
"load_time": self.load_time,
|
| 592 |
"model_type": type(self.model).__name__ if self.model else None,
|
| 593 |
}
|
| 594 |
+
|
| 595 |
def debug_shapes(self, image, mask, tag: str = ""):
|
| 596 |
try:
|
| 597 |
tv_img = torch.as_tensor(image)
|
|
|
|
| 601 |
logger.info(f"[{tag}:mask ] shape={tuple(tv_msk.shape)} dtype={tv_msk.dtype}")
|
| 602 |
except Exception as e:
|
| 603 |
logger.info(f"[{tag}] debug error: {e}")
|
| 604 |
+
|
| 605 |
# ---------------------------------------------------------------------------
|
| 606 |
# Public symbols
|
| 607 |
# ---------------------------------------------------------------------------
|
|
|
|
| 617 |
"_compute_scaled_size",
|
| 618 |
"debug_shapes",
|
| 619 |
]
|
| 620 |
+
|
| 621 |
# ---------------------------------------------------------------------------
|
| 622 |
# Optional CLI for quick testing (no circular imports)
|
| 623 |
# ---------------------------------------------------------------------------
|