Instructions to use Gertlek/DetectiveSAM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sam2
How to use Gertlek/DetectiveSAM with sam2:
```python
# Use SAM2 with images
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>)
```

```python
# Use SAM2 with videos
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```
- Notebooks
- Google Colab
- Kaggle
File size: 5,081 Bytes
from __future__ import annotations
import math
from pathlib import Path
import pytest
from detectivesam_inference.checkpoint import resolve_checkpoint_path
from detectivesam_inference.dataset import PairDataset, prepare_sample
from detectivesam_inference.metrics import compute_f1, compute_iou, summarize_results
from detectivesam_inference.runtime import DetectiveSAMRunner, get_repo_root
def assert_exact(value: float | None, expected: float | None, *, abs_tol: float = 1e-12) -> None:
    """Assert that *value* matches *expected* to within an absolute tolerance.

    Both arguments must be non-``None``. The comparison is ``math.isclose``
    with ``rel_tol=0.0``, so only ``abs_tol`` decides whether the two agree.

    Args:
        value: The measured number.
        expected: The reference number.
        abs_tol: Largest permitted absolute difference (default ``1e-12``).

    Raises:
        AssertionError: If either argument is ``None`` or the two numbers
            differ by more than ``abs_tol``. The message names the failing
            precondition or reports both numbers and their difference.
    """
    assert value is not None, f"expected {expected!r}, but the measured value is None"
    assert expected is not None, f"no reference value to compare {value!r} against"
    assert math.isclose(value, expected, rel_tol=0.0, abs_tol=abs_tol), (
        f"got {value!r}, expected {expected!r} "
        f"(|diff| = {abs(value - expected)!r} exceeds abs_tol = {abs_tol!r})"
    )
@pytest.fixture(scope="module")
def repo_root() -> Path:
    """Provide the path returned by ``get_repo_root()``, created once per test module."""
    root = get_repo_root()
    return root
@pytest.fixture(scope="module")
def baseline_runner() -> DetectiveSAMRunner:
    """Provide a module-scoped runner built from the ``"detective_sam"`` checkpoint on CPU."""
    runner = DetectiveSAMRunner(
        checkpoint_path="detective_sam",
        device="cpu",
    )
    return runner
@pytest.fixture(scope="module")
def sota_runner() -> DetectiveSAMRunner:
    """Provide a module-scoped runner built from the ``"detective_sam_sota"`` checkpoint on CPU."""
    runner = DetectiveSAMRunner(
        checkpoint_path="detective_sam_sota",
        device="cpu",
    )
    return runner
def predict_metrics(
    runner: DetectiveSAMRunner,
    *,
    source_path: Path,
    target_path: Path,
    mask_path: Path,
) -> tuple[float, float]:
    """Run ``runner`` on one source/target/mask triple and score the predicted mask.

    The sample is prepared with the image size and perturbation settings
    taken from ``runner.config``, and prediction uses a threshold of 0.5.

    Args:
        runner: Runner used for prediction and as the source of the config.
        source_path: Path passed to ``prepare_sample`` as ``source_path``.
        target_path: Path passed to ``prepare_sample`` as ``target_path``.
        mask_path: Path passed to ``prepare_sample`` as ``mask_path``.

    Returns:
        ``(iou, f1)`` as computed by ``compute_iou`` and ``compute_f1`` on the
        predicted mask versus the sample's own mask.
    """
    config = runner.config
    prepared = prepare_sample(
        source_path=source_path,
        target_path=target_path,
        mask_path=mask_path,
        img_size=config.img_size,
        perturbation_type=config.perturbation_type,
        perturbation_intensity=config.perturbation_intensity,
    )
    outcome = runner.predict_sample(prepared, threshold=0.5)

    # The reference mask is squeezed and cast to uint8 before scoring
    # (presumably a tensor -> NumPy conversion; confirm against prepare_sample).
    reference = prepared.mask.squeeze().numpy().astype("uint8")

    iou = compute_iou(outcome.pred_mask, reference)
    f1 = compute_f1(outcome.pred_mask, reference)
    return iou, f1
def test_checkpoint_alias_resolution(repo_root: Path) -> None:
    """Verify that each checkpoint alias resolves to its file under ``checkpoints/``."""
    checkpoint_dir = repo_root / "checkpoints"
    filename_by_alias = {
        "detective_sam": "model_epoch22_batch999_score1.1114.pth",
        "detective_sam_sota": "detective_sam_sota.pth",
    }
    for alias, filename in filename_by_alias.items():
        assert resolve_checkpoint_path(alias, repo_root) == checkpoint_dir / filename
def test_baseline_banana_demo_metrics(repo_root: Path, baseline_runner: DetectiveSAMRunner) -> None:
    """Pin IoU and F1 of the baseline runner on ``demo/cocoglide`` sample ``banana_28809``."""
    demo_root = repo_root / "demo" / "cocoglide"
    filename = "banana_28809.png"

    # The same file name lives in the source/, target/ and mask/ subfolders.
    sample_paths = {
        f"{role}_path": demo_root / role / filename
        for role in ("source", "target", "mask")
    }
    iou, f1 = predict_metrics(baseline_runner, **sample_paths)

    assert_exact(iou, 0.8566427949370513)
    assert_exact(f1, 0.9227868680750683)
def test_sota_flux_demo_metrics(repo_root: Path, sota_runner: DetectiveSAMRunner) -> None:
    """Pin IoU and F1 of the SOTA runner on ``demo/flux_test`` sample ``548``."""
    demo_root = repo_root / "demo" / "flux_test"
    filename = "548.png"

    # The same file name lives in the source/, target/ and mask/ subfolders.
    sample_paths = {
        f"{role}_path": demo_root / role / filename
        for role in ("source", "target", "mask")
    }
    iou, f1 = predict_metrics(sota_runner, **sample_paths)

    assert_exact(iou, 0.8703024868799283)
    assert_exact(f1, 0.9306542583192329)
def test_sota_qwen_demo_metrics(repo_root: Path, sota_runner: DetectiveSAMRunner) -> None:
    """Pin IoU and F1 of the SOTA runner on ``demo/qwen_test`` sample ``166``."""
    demo_root = repo_root / "demo" / "qwen_test"
    filename = "166.png"

    # The same file name lives in the source/, target/ and mask/ subfolders.
    sample_paths = {
        f"{role}_path": demo_root / role / filename
        for role in ("source", "target", "mask")
    }
    iou, f1 = predict_metrics(sota_runner, **sample_paths)

    assert_exact(iou, 0.8297306693388413)
    assert_exact(f1, 0.9069429542203147)
def test_baseline_cocoglide_eval_summary(repo_root: Path, baseline_runner: DetectiveSAMRunner) -> None:
    """Pin the aggregate and per-sample IoU/F1 of the baseline runner on ``demo/cocoglide``.

    Every sample yielded by ``PairDataset`` is predicted at threshold 0.5 and
    scored. The summary from ``summarize_results`` and each individual score
    are then compared against stored reference values.
    """
    # Reference (IoU, F1) pairs keyed by sample name.
    expected_by_name = {
        "airplane_139871": (0.41829717560376584, 0.5898582931686339),
        "banana_28809": (0.8566427949370513, 0.9227868680750683),
        "giraffe_296969": (0.22833093957714018, 0.3717743031951054),
        "train_221213": (0.547253866814856, 0.7073872989458688),
        "tv_453722": (0.49576175830770386, 0.662888665997994),
    }

    config = baseline_runner.config
    dataset = PairDataset(
        root_dir=repo_root / "demo" / "cocoglide",
        img_size=config.img_size,
        perturbation_type=config.perturbation_type,
        perturbation_intensity=config.perturbation_intensity,
    )

    def score(sample) -> dict[str, float | str | None]:
        """Predict one sample and return its name together with IoU and F1."""
        outcome = baseline_runner.predict_sample(sample, threshold=0.5)
        reference = sample.mask.squeeze().numpy().astype("uint8")
        return {
            "name": sample.name,
            "iou": compute_iou(outcome.pred_mask, reference),
            "f1": compute_f1(outcome.pred_mask, reference),
        }

    per_sample_results: list[dict[str, float | str | None]] = [score(sample) for sample in dataset]

    # Aggregate checks come first, then the per-sample values.
    summary = summarize_results(per_sample_results)
    assert summary["num_samples"] == 5
    assert summary["num_samples_with_gt"] == 5
    assert_exact(summary["mean_iou"], 0.5092573070481035)
    assert_exact(summary["mean_f1"], 0.6509390858765342)

    for entry in per_sample_results:
        reference_scores = expected_by_name[entry["name"]]
        assert_exact(entry["iou"], reference_scores[0])
        assert_exact(entry["f1"], reference_scores[1])
|