MogensR committed on
Commit
2fe1a8c
·
1 Parent(s): 183c1c8

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +155 -6
models/loaders/matanyone_loader.py CHANGED
@@ -11,6 +11,8 @@
11
  - Added: Force chunk_size=1, flip_aug=False in cfg to avoid dim mismatches
12
  - Added: Pad to multiple of 16 to avoid transformer patch issues
13
  - Added: Prefer fp16 over bf16 for Tesla T4 compatibility
 
 
14
  """
15
  from __future__ import annotations
16
  import os
@@ -24,6 +26,34 @@
24
  import inspect
25
  import threading
26
  logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # ---------------------------------------------------------------------------
28
  # Utilities (shapes, dtype, scaling)
29
  # ---------------------------------------------------------------------------
@@ -34,10 +64,12 @@ def _select_device(pref: str) -> str:
34
  if pref == "cpu":
35
  return "cpu"
36
  return "cuda" if torch.cuda.is_available() else "cpu"
 
37
  def _as_tensor_on_device(x, device: str) -> torch.Tensor:
38
  if isinstance(x, torch.Tensor):
39
  return x.to(device, non_blocking=True)
40
  return torch.from_numpy(np.asarray(x)).to(device, non_blocking=True)
 
41
  def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
42
  """
43
  Normalize input to BCHW (image) or B1HW (mask).
@@ -72,10 +104,12 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
72
  x = x.repeat(1, 3, 1, 1)
73
  x = x.clamp_(0.0, 1.0)
74
  return x
 
75
  def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
76
  if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
77
  return img_bchw[0]
78
  return img_bchw
 
79
  def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
80
  if msk_b1hw is None:
81
  return None
@@ -84,6 +118,7 @@ def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
84
  if msk_b1hw.ndim == 3 and msk_b1hw.shape[0] == 1:
85
  return msk_b1hw
86
  raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
 
87
  def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: bool = False) -> Optional[torch.Tensor]:
88
  if x is None:
89
  return None
@@ -91,6 +126,7 @@ def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: b
91
  return x
92
  mode = "nearest" if is_mask else "bilinear"
93
  return F.interpolate(x, size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
 
94
  def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
95
  t = torch.as_tensor(alpha, device=device).float()
96
  if t.ndim == 2:
@@ -117,6 +153,7 @@ def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
117
  if t.shape[1] != 1:
118
  t = t[:, :1]
119
  return t.clamp_(0.0, 1.0).contiguous()
 
120
  def _to_2d_alpha_numpy(x) -> np.ndarray:
121
  t = torch.as_tensor(x).float()
122
  while t.ndim > 2:
@@ -129,6 +166,7 @@ def _to_2d_alpha_numpy(x) -> np.ndarray:
129
  t = t.clamp_(0.0, 1.0)
130
  out = t.detach().cpu().numpy().astype(np.float32)
131
  return np.ascontiguousarray(out)
 
132
  def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
133
  if h <= 0 or w <= 0:
134
  return h, w, 1.0
@@ -138,6 +176,7 @@ def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> T
138
  nh = max(128, int(round(h * s))) # Force min 128 to avoid small-res bugs
139
  nw = max(128, int(round(w * s)))
140
  return nh, nw, s
 
141
  def _pad_to_multiple(t: Optional[torch.Tensor], multiple: int = 16) -> Optional[torch.Tensor]:
142
  if t is None:
143
  return None
@@ -155,6 +194,7 @@ def _pad_to_multiple(t: Optional[torch.Tensor], multiple: int = 16) -> Optional[
155
  if t.ndim == 2: # Shouldn't happen
156
  t = t.squeeze(0)
157
  return t
 
158
  def debug_shapes(tag: str, image, mask) -> None:
159
  def _info(name, v):
160
  try:
@@ -166,6 +206,7 @@ def _info(name, v):
166
  logger.info(f"[{tag}:{name}] type={type(v)} err={e}")
167
  _info("image", image)
168
  _info("mask", mask)
 
169
  # ---------------------------------------------------------------------------
170
  # Precision selection
171
  # ---------------------------------------------------------------------------
@@ -181,6 +222,7 @@ def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dt
181
  if bf16_ok:
182
  return torch.bfloat16, True, torch.bfloat16
183
  return torch.float32, False, None
 
184
  # ---------------------------------------------------------------------------
185
  # Stateful Adapter around InferenceCore
186
  # ---------------------------------------------------------------------------
@@ -215,6 +257,7 @@ def __init__(
215
  except Exception:
216
  self._has_first_frame_pred = True
217
  self._has_prob_to_mask = hasattr(self.core, "output_prob_to_mask")
 
218
  def reset(self):
219
  with self._lock:
220
  try:
@@ -223,6 +266,7 @@ def reset(self):
223
  except Exception:
224
  pass
225
  self.started = False
 
226
  def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
227
  nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
228
  sizes = [(nh, nw)]
@@ -235,6 +279,7 @@ def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
235
  if sizes[-1] != (cur_h, cur_w):
236
  sizes.append((cur_h, cur_w))
237
  return sizes
 
238
  def _to_alpha(self, out_prob):
239
  if self._has_prob_to_mask:
240
  try:
@@ -247,6 +292,7 @@ def _to_alpha(self, out_prob):
247
  if t.ndim == 3:
248
  return t[0] if t.shape[0] >= 1 else t.mean(0)
249
  return t
 
250
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
251
  """
252
  Returns a 2-D float32 alpha [H,W].
@@ -329,6 +375,7 @@ def __exit__(self, *a): return False
329
  if mask_1hw is not None:
330
  return _to_2d_alpha_numpy(mask_1hw)
331
  return np.full((H, W), 0.5, dtype=np.float32)
 
332
  # ---------------------------------------------------------------------------
333
  # Loader
334
  # ---------------------------------------------------------------------------
@@ -345,6 +392,7 @@ def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyo
345
  self.adapter = None
346
  self.model_id = "PeiqingYang/MatAnyone"
347
  self.load_time = 0.0
 
348
  # --- Robust imports (works with different packaging layouts) ---
349
  def _import_model_and_core(self):
350
  model_cls = core_cls = None
@@ -372,6 +420,7 @@ def _import_model_and_core(self):
372
  if model_cls is None or core_cls is None:
373
  raise ImportError("Could not import MatAnyone / InferenceCore: " + " | ".join(err_msgs))
374
  return model_cls, core_cls
 
375
  def load(self) -> Optional[Any]:
376
  logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
377
  t0 = time.time()
@@ -386,16 +435,111 @@ def load(self) -> Optional[Any]:
386
  except Exception:
387
  self.model = self.model.to(self.device)
388
  self.model.eval()
389
- # Override cfg to disable features causing dim mismatches
390
  default_cfg = {
391
- 'chunk_size': 1,
392
- 'flip_aug': False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  }
 
394
  cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
395
  if isinstance(cfg, dict):
396
- cfg.update(default_cfg) # Override
397
- else:
398
- cfg = default_cfg
 
 
 
 
 
 
399
  # Inference core
400
  try:
401
  self.core = core_cls(self.model, cfg=cfg)
@@ -426,6 +570,7 @@ def load(self) -> Optional[Any]:
426
  logger.error(f"Failed to load MatAnyone: {e}")
427
  logger.debug(traceback.format_exc())
428
  return None
 
429
  def cleanup(self):
430
  self.adapter = None
431
  self.core = None
@@ -437,6 +582,7 @@ def cleanup(self):
437
  self.model = None
438
  if torch.cuda.is_available():
439
  torch.cuda.empty_cache()
 
440
  def get_info(self) -> Dict[str, Any]:
441
  return {
442
  "loaded": self.adapter is not None,
@@ -445,6 +591,7 @@ def get_info(self) -> Dict[str, Any]:
445
  "load_time": self.load_time,
446
  "model_type": type(self.model).__name__ if self.model else None,
447
  }
 
448
  def debug_shapes(self, image, mask, tag: str = ""):
449
  try:
450
  tv_img = torch.as_tensor(image)
@@ -454,6 +601,7 @@ def debug_shapes(self, image, mask, tag: str = ""):
454
  logger.info(f"[{tag}:mask ] shape={tuple(tv_msk.shape)} dtype={tv_msk.dtype}")
455
  except Exception as e:
456
  logger.info(f"[{tag}] debug error: {e}")
 
457
  # ---------------------------------------------------------------------------
458
  # Public symbols
459
  # ---------------------------------------------------------------------------
@@ -469,6 +617,7 @@ def debug_shapes(self, image, mask, tag: str = ""):
469
  "_compute_scaled_size",
470
  "debug_shapes",
471
  ]
 
472
  # ---------------------------------------------------------------------------
473
  # Optional CLI for quick testing (no circular imports)
474
  # ---------------------------------------------------------------------------
 
11
  - Added: Force chunk_size=1, flip_aug=False in cfg to avoid dim mismatches
12
  - Added: Pad to multiple of 16 to avoid transformer patch issues
13
  - Added: Prefer fp16 over bf16 for Tesla T4 compatibility
14
+ - New: EasyDict polyfill and conversion for cfg to fix 'dict no attribute' errors
15
+ - New: Full default cfg from official config.json to ensure keys like mem_every are present
16
  """
17
  from __future__ import annotations
18
  import os
 
26
  import inspect
27
  import threading
28
  logger = logging.getLogger(__name__)
29
+
30
+ # EasyDict polyfill (recursive dict with dot access)
31
+ class EasyDict(dict):
32
+ def __init__(self, d=None, **kwargs):
33
+ if d is None:
34
+ d = {}
35
+ if kwargs:
36
+ d.update(**kwargs)
37
+ for k, v in d.items():
38
+ if isinstance(v, dict):
39
+ self[k] = EasyDict(v)
40
+ elif isinstance(v, list):
41
+ self[k] = [EasyDict(i) if isinstance(i, dict) else i for i in v]
42
+ else:
43
+ self[k] = v
44
+
45
+ def __getattr__(self, name):
46
+ try:
47
+ return self[name]
48
+ except KeyError:
49
+ raise AttributeError(name)
50
+
51
+ def __setattr__(self, name, value):
52
+ self[name] = value
53
+
54
+ def __delattr__(self, name):
55
+ del self[name]
56
+
57
  # ---------------------------------------------------------------------------
58
  # Utilities (shapes, dtype, scaling)
59
  # ---------------------------------------------------------------------------
 
64
  if pref == "cpu":
65
  return "cpu"
66
  return "cuda" if torch.cuda.is_available() else "cpu"
67
+
68
  def _as_tensor_on_device(x, device: str) -> torch.Tensor:
69
  if isinstance(x, torch.Tensor):
70
  return x.to(device, non_blocking=True)
71
  return torch.from_numpy(np.asarray(x)).to(device, non_blocking=True)
72
+
73
  def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
74
  """
75
  Normalize input to BCHW (image) or B1HW (mask).
 
104
  x = x.repeat(1, 3, 1, 1)
105
  x = x.clamp_(0.0, 1.0)
106
  return x
107
+
108
  def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
109
  if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
110
  return img_bchw[0]
111
  return img_bchw
112
+
113
  def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
114
  if msk_b1hw is None:
115
  return None
 
118
  if msk_b1hw.ndim == 3 and msk_b1hw.shape[0] == 1:
119
  return msk_b1hw
120
  raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
121
+
122
  def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask: bool = False) -> Optional[torch.Tensor]:
123
  if x is None:
124
  return None
 
126
  return x
127
  mode = "nearest" if is_mask else "bilinear"
128
  return F.interpolate(x, size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
129
+
130
  def _to_b1hw_alpha(alpha, device: str) -> torch.Tensor:
131
  t = torch.as_tensor(alpha, device=device).float()
132
  if t.ndim == 2:
 
153
  if t.shape[1] != 1:
154
  t = t[:, :1]
155
  return t.clamp_(0.0, 1.0).contiguous()
156
+
157
  def _to_2d_alpha_numpy(x) -> np.ndarray:
158
  t = torch.as_tensor(x).float()
159
  while t.ndim > 2:
 
166
  t = t.clamp_(0.0, 1.0)
167
  out = t.detach().cpu().numpy().astype(np.float32)
168
  return np.ascontiguousarray(out)
169
+
170
  def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
171
  if h <= 0 or w <= 0:
172
  return h, w, 1.0
 
176
  nh = max(128, int(round(h * s))) # Force min 128 to avoid small-res bugs
177
  nw = max(128, int(round(w * s)))
178
  return nh, nw, s
179
+
180
  def _pad_to_multiple(t: Optional[torch.Tensor], multiple: int = 16) -> Optional[torch.Tensor]:
181
  if t is None:
182
  return None
 
194
  if t.ndim == 2: # Shouldn't happen
195
  t = t.squeeze(0)
196
  return t
197
+
198
  def debug_shapes(tag: str, image, mask) -> None:
199
  def _info(name, v):
200
  try:
 
206
  logger.info(f"[{tag}:{name}] type={type(v)} err={e}")
207
  _info("image", image)
208
  _info("mask", mask)
209
+
210
  # ---------------------------------------------------------------------------
211
  # Precision selection
212
  # ---------------------------------------------------------------------------
 
222
  if bf16_ok:
223
  return torch.bfloat16, True, torch.bfloat16
224
  return torch.float32, False, None
225
+
226
  # ---------------------------------------------------------------------------
227
  # Stateful Adapter around InferenceCore
228
  # ---------------------------------------------------------------------------
 
257
  except Exception:
258
  self._has_first_frame_pred = True
259
  self._has_prob_to_mask = hasattr(self.core, "output_prob_to_mask")
260
+
261
  def reset(self):
262
  with self._lock:
263
  try:
 
266
  except Exception:
267
  pass
268
  self.started = False
269
+
270
  def _scaled_ladder(self, H: int, W: int) -> List[Tuple[int, int]]:
271
  nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
272
  sizes = [(nh, nw)]
 
279
  if sizes[-1] != (cur_h, cur_w):
280
  sizes.append((cur_h, cur_w))
281
  return sizes
282
+
283
  def _to_alpha(self, out_prob):
284
  if self._has_prob_to_mask:
285
  try:
 
292
  if t.ndim == 3:
293
  return t[0] if t.shape[0] >= 1 else t.mean(0)
294
  return t
295
+
296
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
297
  """
298
  Returns a 2-D float32 alpha [H,W].
 
375
  if mask_1hw is not None:
376
  return _to_2d_alpha_numpy(mask_1hw)
377
  return np.full((H, W), 0.5, dtype=np.float32)
378
+
379
  # ---------------------------------------------------------------------------
380
  # Loader
381
  # ---------------------------------------------------------------------------
 
392
  self.adapter = None
393
  self.model_id = "PeiqingYang/MatAnyone"
394
  self.load_time = 0.0
395
+
396
  # --- Robust imports (works with different packaging layouts) ---
397
  def _import_model_and_core(self):
398
  model_cls = core_cls = None
 
420
  if model_cls is None or core_cls is None:
421
  raise ImportError("Could not import MatAnyone / InferenceCore: " + " | ".join(err_msgs))
422
  return model_cls, core_cls
423
+
424
  def load(self) -> Optional[Any]:
425
  logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
426
  t0 = time.time()
 
435
  except Exception:
436
  self.model = self.model.to(self.device)
437
  self.model.eval()
438
+ # Full default cfg from official config.json
439
  default_cfg = {
440
+ "amp": False,
441
+ "chunk_size": -1,
442
+ "flip_aug": False,
443
+ "long_term": {
444
+ "buffer_tokens": 2000,
445
+ "count_usage": True,
446
+ "max_mem_frames": 10,
447
+ "max_num_tokens": 10000,
448
+ "min_mem_frames": 5,
449
+ "num_prototypes": 128
450
+ },
451
+ "max_internal_size": -1,
452
+ "max_mem_frames": 5,
453
+ "mem_every": 5,
454
+ "model": {
455
+ "aux_loss": {
456
+ "query": {
457
+ "enabled": True,
458
+ "weight": 0.01
459
+ },
460
+ "sensory": {
461
+ "enabled": True,
462
+ "weight": 0.01
463
+ }
464
+ },
465
+ "embed_dim": 256,
466
+ "key_dim": 64,
467
+ "mask_decoder": {
468
+ "up_dims": [256, 128, 128, 64, 16]
469
+ },
470
+ "mask_encoder": {
471
+ "final_dim": 256,
472
+ "type": "resnet18"
473
+ },
474
+ "object_summarizer": {
475
+ "add_pe": True,
476
+ "embed_dim": 256,
477
+ "num_summaries": 16
478
+ },
479
+ "object_transformer": {
480
+ "embed_dim": 256,
481
+ "ff_dim": 2048,
482
+ "num_blocks": 3,
483
+ "num_heads": 8,
484
+ "num_queries": 16,
485
+ "pixel_self_attention": {
486
+ "add_pe_to_qkv": [True, True, False]
487
+ },
488
+ "query_self_attention": {
489
+ "add_pe_to_qkv": [True, True, False]
490
+ },
491
+ "read_from_memory": {
492
+ "add_pe_to_qkv": [True, True, False]
493
+ },
494
+ "read_from_past": {
495
+ "add_pe_to_qkv": [True, True, False]
496
+ },
497
+ "read_from_pixel": {
498
+ "add_pe_to_qkv": [True, True, False],
499
+ "input_add_pe": False,
500
+ "input_norm": False
501
+ },
502
+ "read_from_query": {
503
+ "add_pe_to_qkv": [True, True, False],
504
+ "output_norm": False
505
+ }
506
+ },
507
+ "pixel_dim": 256,
508
+ "pixel_encoder": {
509
+ "ms_dims": [1024, 512, 256, 64, 3],
510
+ "type": "resnet50"
511
+ },
512
+ "pixel_mean": [0.485, 0.456, 0.406],
513
+ "pixel_pe_scale": 32,
514
+ "pixel_pe_temperature": 128,
515
+ "pixel_std": [0.229, 0.224, 0.225],
516
+ "pretrained_resnet": False,
517
+ "sensory_dim": 256,
518
+ "value_dim": 256
519
+ },
520
+ "output_dir": None,
521
+ "save_all": True,
522
+ "save_aux": False,
523
+ "save_scores": False,
524
+ "stagger_updates": 5,
525
+ "top_k": 30,
526
+ "use_all_masks": False,
527
+ "use_long_term": False,
528
+ "visualize": False,
529
+ "weights": "pretrained_models/matanyone.pth"
530
  }
531
+ # Get cfg from model if available, else default
532
  cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
533
  if isinstance(cfg, dict):
534
+ cfg = dict(cfg) # Copy to avoid modifying model.cfg
535
+ # Override specific values
536
+ overrides = {
537
+ 'chunk_size': 1,
538
+ 'flip_aug': False,
539
+ }
540
+ cfg.update(overrides)
541
+ # Convert to EasyDict for dot access
542
+ cfg = EasyDict(cfg)
543
  # Inference core
544
  try:
545
  self.core = core_cls(self.model, cfg=cfg)
 
570
  logger.error(f"Failed to load MatAnyone: {e}")
571
  logger.debug(traceback.format_exc())
572
  return None
573
+
574
  def cleanup(self):
575
  self.adapter = None
576
  self.core = None
 
582
  self.model = None
583
  if torch.cuda.is_available():
584
  torch.cuda.empty_cache()
585
+
586
  def get_info(self) -> Dict[str, Any]:
587
  return {
588
  "loaded": self.adapter is not None,
 
591
  "load_time": self.load_time,
592
  "model_type": type(self.model).__name__ if self.model else None,
593
  }
594
+
595
  def debug_shapes(self, image, mask, tag: str = ""):
596
  try:
597
  tv_img = torch.as_tensor(image)
 
601
  logger.info(f"[{tag}:mask ] shape={tuple(tv_msk.shape)} dtype={tv_msk.dtype}")
602
  except Exception as e:
603
  logger.info(f"[{tag}] debug error: {e}")
604
+
605
  # ---------------------------------------------------------------------------
606
  # Public symbols
607
  # ---------------------------------------------------------------------------
 
617
  "_compute_scaled_size",
618
  "debug_shapes",
619
  ]
620
+
621
  # ---------------------------------------------------------------------------
622
  # Optional CLI for quick testing (no circular imports)
623
  # ---------------------------------------------------------------------------