fix 11
Browse files- pipeline.py +78 -17
pipeline.py
CHANGED
|
@@ -44,6 +44,7 @@
|
|
| 44 |
from typing import Optional, Tuple, Dict, Any, Union
|
| 45 |
|
| 46 |
import numpy as np
|
|
|
|
| 47 |
|
| 48 |
# Try to apply GPU/perf tuning early if present
|
| 49 |
try:
|
|
@@ -405,6 +406,57 @@ def _build_stage_a_checkerboard_from_mask(
|
|
| 405 |
writer.release()
|
| 406 |
return ok_any
|
| 407 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
# --------------------------------------------------------------------------------------
|
| 409 |
# SAM2 Integration (robust to different build_sam2 signatures)
|
| 410 |
# --------------------------------------------------------------------------------------
|
|
@@ -414,6 +466,8 @@ def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
|
|
| 414 |
- config_file vs model_cfg
|
| 415 |
- checkpoint vs ckpt_path vs weights
|
| 416 |
- optional device kwarg
|
|
|
|
|
|
|
| 417 |
"""
|
| 418 |
meta = {"sam2_import_ok": False, "sam2_init_ok": False}
|
| 419 |
try:
|
|
@@ -425,22 +479,18 @@ def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
|
|
| 425 |
return None, False, meta
|
| 426 |
|
| 427 |
device = _pick_device("SAM2_DEVICE")
|
| 428 |
-
|
|
|
|
| 429 |
ckpt = os.environ.get("SAM2_CHECKPOINT", "")
|
| 430 |
|
| 431 |
-
|
| 432 |
params = set(inspect.signature(build_sam2).parameters.keys())
|
| 433 |
kwargs = {}
|
| 434 |
-
|
| 435 |
# Config arg
|
| 436 |
if "config_file" in params:
|
| 437 |
-
kwargs["config_file"] =
|
| 438 |
elif "model_cfg" in params:
|
| 439 |
-
kwargs["model_cfg"] =
|
| 440 |
-
else:
|
| 441 |
-
# if neither is present, try positional later
|
| 442 |
-
pass
|
| 443 |
-
|
| 444 |
# Checkpoint arg
|
| 445 |
if ckpt:
|
| 446 |
if "checkpoint" in params:
|
|
@@ -449,22 +499,33 @@ def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
|
|
| 449 |
kwargs["ckpt_path"] = ckpt
|
| 450 |
elif "weights" in params:
|
| 451 |
kwargs["weights"] = ckpt
|
| 452 |
-
|
| 453 |
-
# Device (if supported via kwarg)
|
| 454 |
if "device" in params:
|
| 455 |
kwargs["device"] = device
|
| 456 |
-
|
| 457 |
-
# Try keyword call first
|
| 458 |
try:
|
| 459 |
-
|
| 460 |
except TypeError:
|
| 461 |
-
|
| 462 |
-
pos = [cfg]
|
| 463 |
if ckpt:
|
| 464 |
pos.append(ckpt)
|
| 465 |
if "device" not in kwargs:
|
| 466 |
pos.append(device)
|
| 467 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
|
| 469 |
predictor = SAM2ImagePredictor(sam)
|
| 470 |
meta.update({
|
|
|
|
| 44 |
from typing import Optional, Tuple, Dict, Any, Union
|
| 45 |
|
| 46 |
import numpy as np
|
| 47 |
+
import yaml # for SAM2 config introspection
|
| 48 |
|
| 49 |
# Try to apply GPU/perf tuning early if present
|
| 50 |
try:
|
|
|
|
| 406 |
writer.release()
|
| 407 |
return ok_any
|
| 408 |
|
| 409 |
+
# --------------------------------------------------------------------------------------
|
| 410 |
+
# SAM2 helpers (config resolution & robust loader)
|
| 411 |
+
# --------------------------------------------------------------------------------------
|
| 412 |
+
def _resolve_sam2_cfg(cfg_str: str) -> str:
|
| 413 |
+
"""Make the SAM2 config path absolute (prefer inside TP_SAM2)."""
|
| 414 |
+
cfg_path = Path(cfg_str)
|
| 415 |
+
if not cfg_path.is_absolute():
|
| 416 |
+
candidate = TP_SAM2 / cfg_path
|
| 417 |
+
if candidate.exists():
|
| 418 |
+
return str(candidate)
|
| 419 |
+
if cfg_path.exists():
|
| 420 |
+
return str(cfg_path)
|
| 421 |
+
# Last resort: common defaults inside the repo
|
| 422 |
+
for name in ["configs/sam2/sam2_hiera_l.yaml", "configs/sam2/sam2_hiera_b.yaml", "configs/sam2/sam2_hiera_s.yaml"]:
|
| 423 |
+
p = TP_SAM2 / name
|
| 424 |
+
if p.exists():
|
| 425 |
+
return str(p)
|
| 426 |
+
return str(cfg_str) # let build_sam2 raise a clear error
|
| 427 |
+
|
| 428 |
+
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
|
| 429 |
+
"""
|
| 430 |
+
If the given config references 'hieradet', try to find a 'hiera' config in the repo and return it.
|
| 431 |
+
"""
|
| 432 |
+
try:
|
| 433 |
+
with open(cfg_path, "r") as f:
|
| 434 |
+
data = yaml.safe_load(f)
|
| 435 |
+
# Look for target under model.image_encoder.trunk._target_ (Hydra style)
|
| 436 |
+
target = None
|
| 437 |
+
model = data.get("model", {})
|
| 438 |
+
enc = (model.get("image_encoder") or {})
|
| 439 |
+
trunk = (enc.get("trunk") or {})
|
| 440 |
+
target = trunk.get("_target_") or trunk.get("target")
|
| 441 |
+
if isinstance(target, str) and "hieradet" in target:
|
| 442 |
+
# Search all yaml files under TP_SAM2/configs for those that reference ".hiera."
|
| 443 |
+
for y in TP_SAM2.rglob("*.yaml"):
|
| 444 |
+
try:
|
| 445 |
+
with open(y, "r") as f2:
|
| 446 |
+
d2 = yaml.safe_load(f2)
|
| 447 |
+
m2 = (d2 or {}).get("model", {})
|
| 448 |
+
e2 = (m2.get("image_encoder") or {})
|
| 449 |
+
t2 = (e2.get("trunk") or {})
|
| 450 |
+
tgt2 = t2.get("_target_") or t2.get("target")
|
| 451 |
+
if isinstance(tgt2, str) and ".hiera." in tgt2:
|
| 452 |
+
logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
|
| 453 |
+
return str(y)
|
| 454 |
+
except Exception:
|
| 455 |
+
continue
|
| 456 |
+
except Exception:
|
| 457 |
+
pass
|
| 458 |
+
return None
|
| 459 |
+
|
| 460 |
# --------------------------------------------------------------------------------------
|
| 461 |
# SAM2 Integration (robust to different build_sam2 signatures)
|
| 462 |
# --------------------------------------------------------------------------------------
|
|
|
|
| 466 |
- config_file vs model_cfg
|
| 467 |
- checkpoint vs ckpt_path vs weights
|
| 468 |
- optional device kwarg
|
| 469 |
+
- absolute config resolution (inside third_party/sam2)
|
| 470 |
+
- auto-fix if config references 'hieradet' but repo has 'hiera'
|
| 471 |
"""
|
| 472 |
meta = {"sam2_import_ok": False, "sam2_init_ok": False}
|
| 473 |
try:
|
|
|
|
| 479 |
return None, False, meta
|
| 480 |
|
| 481 |
device = _pick_device("SAM2_DEVICE")
|
| 482 |
+
cfg_env = os.environ.get("SAM2_MODEL_CFG", "configs/sam2/sam2_hiera_l.yaml")
|
| 483 |
+
cfg = _resolve_sam2_cfg(cfg_env)
|
| 484 |
ckpt = os.environ.get("SAM2_CHECKPOINT", "")
|
| 485 |
|
| 486 |
+
def _try_build(cfg_path: str):
|
| 487 |
params = set(inspect.signature(build_sam2).parameters.keys())
|
| 488 |
kwargs = {}
|
|
|
|
| 489 |
# Config arg
|
| 490 |
if "config_file" in params:
|
| 491 |
+
kwargs["config_file"] = cfg_path
|
| 492 |
elif "model_cfg" in params:
|
| 493 |
+
kwargs["model_cfg"] = cfg_path
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
# Checkpoint arg
|
| 495 |
if ckpt:
|
| 496 |
if "checkpoint" in params:
|
|
|
|
| 499 |
kwargs["ckpt_path"] = ckpt
|
| 500 |
elif "weights" in params:
|
| 501 |
kwargs["weights"] = ckpt
|
| 502 |
+
# Device
|
|
|
|
| 503 |
if "device" in params:
|
| 504 |
kwargs["device"] = device
|
| 505 |
+
# Try keywords first, then positional fallback
|
|
|
|
| 506 |
try:
|
| 507 |
+
return build_sam2(**kwargs)
|
| 508 |
except TypeError:
|
| 509 |
+
pos = [cfg_path]
|
|
|
|
| 510 |
if ckpt:
|
| 511 |
pos.append(ckpt)
|
| 512 |
if "device" not in kwargs:
|
| 513 |
pos.append(device)
|
| 514 |
+
return build_sam2(*pos)
|
| 515 |
+
|
| 516 |
+
try:
|
| 517 |
+
try:
|
| 518 |
+
sam = _try_build(cfg)
|
| 519 |
+
except Exception as e1:
|
| 520 |
+
msg = str(e1)
|
| 521 |
+
# If the config is using 'hieradet', try to swap to a 'hiera' config
|
| 522 |
+
alt_cfg = _find_hiera_config_if_hieradet(cfg)
|
| 523 |
+
if alt_cfg:
|
| 524 |
+
logger.info(f"SAM2: retrying with alt config: {alt_cfg}")
|
| 525 |
+
sam = _try_build(alt_cfg)
|
| 526 |
+
cfg = alt_cfg
|
| 527 |
+
else:
|
| 528 |
+
raise
|
| 529 |
|
| 530 |
predictor = SAM2ImagePredictor(sam)
|
| 531 |
meta.update({
|