MogensR commited on
Commit
f03fb20
·
1 Parent(s): 8a850cc
Files changed (1) hide show
  1. pipeline.py +78 -17
pipeline.py CHANGED
@@ -44,6 +44,7 @@
44
  from typing import Optional, Tuple, Dict, Any, Union
45
 
46
  import numpy as np
 
47
 
48
  # Try to apply GPU/perf tuning early if present
49
  try:
@@ -405,6 +406,57 @@ def _build_stage_a_checkerboard_from_mask(
405
  writer.release()
406
  return ok_any
407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  # --------------------------------------------------------------------------------------
409
  # SAM2 Integration (robust to different build_sam2 signatures)
410
  # --------------------------------------------------------------------------------------
@@ -414,6 +466,8 @@ def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
414
  - config_file vs model_cfg
415
  - checkpoint vs ckpt_path vs weights
416
  - optional device kwarg
 
 
417
  """
418
  meta = {"sam2_import_ok": False, "sam2_init_ok": False}
419
  try:
@@ -425,22 +479,18 @@ def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
425
  return None, False, meta
426
 
427
  device = _pick_device("SAM2_DEVICE")
428
- cfg = os.environ.get("SAM2_MODEL_CFG", "configs/sam2/sam2_hiera_l.yaml")
 
429
  ckpt = os.environ.get("SAM2_CHECKPOINT", "")
430
 
431
- try:
432
  params = set(inspect.signature(build_sam2).parameters.keys())
433
  kwargs = {}
434
-
435
  # Config arg
436
  if "config_file" in params:
437
- kwargs["config_file"] = cfg
438
  elif "model_cfg" in params:
439
- kwargs["model_cfg"] = cfg
440
- else:
441
- # if neither is present, try positional later
442
- pass
443
-
444
  # Checkpoint arg
445
  if ckpt:
446
  if "checkpoint" in params:
@@ -449,22 +499,33 @@ def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
449
  kwargs["ckpt_path"] = ckpt
450
  elif "weights" in params:
451
  kwargs["weights"] = ckpt
452
-
453
- # Device (if supported via kwarg)
454
  if "device" in params:
455
  kwargs["device"] = device
456
-
457
- # Try keyword call first
458
  try:
459
- sam = build_sam2(**kwargs)
460
  except TypeError:
461
- # Fallback to positional (cfg, ckpt?, device?)
462
- pos = [cfg]
463
  if ckpt:
464
  pos.append(ckpt)
465
  if "device" not in kwargs:
466
  pos.append(device)
467
- sam = build_sam2(*pos)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
469
  predictor = SAM2ImagePredictor(sam)
470
  meta.update({
 
44
  from typing import Optional, Tuple, Dict, Any, Union
45
 
46
  import numpy as np
47
+ import yaml # for SAM2 config introspection
48
 
49
  # Try to apply GPU/perf tuning early if present
50
  try:
 
406
  writer.release()
407
  return ok_any
408
 
409
+ # --------------------------------------------------------------------------------------
410
+ # SAM2 helpers (config resolution & robust loader)
411
+ # --------------------------------------------------------------------------------------
412
+ def _resolve_sam2_cfg(cfg_str: str) -> str:
413
+ """Make the SAM2 config path absolute (prefer inside TP_SAM2)."""
414
+ cfg_path = Path(cfg_str)
415
+ if not cfg_path.is_absolute():
416
+ candidate = TP_SAM2 / cfg_path
417
+ if candidate.exists():
418
+ return str(candidate)
419
+ if cfg_path.exists():
420
+ return str(cfg_path)
421
+ # Last resort: common defaults inside the repo
422
+ for name in ["configs/sam2/sam2_hiera_l.yaml", "configs/sam2/sam2_hiera_b.yaml", "configs/sam2/sam2_hiera_s.yaml"]:
423
+ p = TP_SAM2 / name
424
+ if p.exists():
425
+ return str(p)
426
+ return str(cfg_str) # let build_sam2 raise a clear error
427
+
428
+ def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
429
+ """
430
+ If the given config references 'hieradet', try to find a 'hiera' config in the repo and return it.
431
+ """
432
+ try:
433
+ with open(cfg_path, "r") as f:
434
+ data = yaml.safe_load(f)
435
+ # Look for target under model.image_encoder.trunk._target_ (Hydra style)
436
+ target = None
437
+ model = data.get("model", {})
438
+ enc = (model.get("image_encoder") or {})
439
+ trunk = (enc.get("trunk") or {})
440
+ target = trunk.get("_target_") or trunk.get("target")
441
+ if isinstance(target, str) and "hieradet" in target:
442
+ # Search all yaml files under TP_SAM2/configs for those that reference ".hiera."
443
+ for y in TP_SAM2.rglob("*.yaml"):
444
+ try:
445
+ with open(y, "r") as f2:
446
+ d2 = yaml.safe_load(f2)
447
+ m2 = (d2 or {}).get("model", {})
448
+ e2 = (m2.get("image_encoder") or {})
449
+ t2 = (e2.get("trunk") or {})
450
+ tgt2 = t2.get("_target_") or t2.get("target")
451
+ if isinstance(tgt2, str) and ".hiera." in tgt2:
452
+ logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
453
+ return str(y)
454
+ except Exception:
455
+ continue
456
+ except Exception:
457
+ pass
458
+ return None
459
+
460
  # --------------------------------------------------------------------------------------
461
  # SAM2 Integration (robust to different build_sam2 signatures)
462
  # --------------------------------------------------------------------------------------
 
466
  - config_file vs model_cfg
467
  - checkpoint vs ckpt_path vs weights
468
  - optional device kwarg
469
+ - absolute config resolution (inside third_party/sam2)
470
+ - auto-fix if config references 'hieradet' but repo has 'hiera'
471
  """
472
  meta = {"sam2_import_ok": False, "sam2_init_ok": False}
473
  try:
 
479
  return None, False, meta
480
 
481
  device = _pick_device("SAM2_DEVICE")
482
+ cfg_env = os.environ.get("SAM2_MODEL_CFG", "configs/sam2/sam2_hiera_l.yaml")
483
+ cfg = _resolve_sam2_cfg(cfg_env)
484
  ckpt = os.environ.get("SAM2_CHECKPOINT", "")
485
 
486
+ def _try_build(cfg_path: str):
487
  params = set(inspect.signature(build_sam2).parameters.keys())
488
  kwargs = {}
 
489
  # Config arg
490
  if "config_file" in params:
491
+ kwargs["config_file"] = cfg_path
492
  elif "model_cfg" in params:
493
+ kwargs["model_cfg"] = cfg_path
 
 
 
 
494
  # Checkpoint arg
495
  if ckpt:
496
  if "checkpoint" in params:
 
499
  kwargs["ckpt_path"] = ckpt
500
  elif "weights" in params:
501
  kwargs["weights"] = ckpt
502
+ # Device
 
503
  if "device" in params:
504
  kwargs["device"] = device
505
+ # Try keywords first, then positional fallback
 
506
  try:
507
+ return build_sam2(**kwargs)
508
  except TypeError:
509
+ pos = [cfg_path]
 
510
  if ckpt:
511
  pos.append(ckpt)
512
  if "device" not in kwargs:
513
  pos.append(device)
514
+ return build_sam2(*pos)
515
+
516
+ try:
517
+ try:
518
+ sam = _try_build(cfg)
519
+ except Exception as e1:
520
+ msg = str(e1)
521
+ # If the config is using 'hieradet', try to swap to a 'hiera' config
522
+ alt_cfg = _find_hiera_config_if_hieradet(cfg)
523
+ if alt_cfg:
524
+ logger.info(f"SAM2: retrying with alt config: {alt_cfg}")
525
+ sam = _try_build(alt_cfg)
526
+ cfg = alt_cfg
527
+ else:
528
+ raise
529
 
530
  predictor = SAM2ImagePredictor(sam)
531
  meta.update({