Spaces:
Running
on
Zero
Running
on
Zero
Upload full code for Space
Browse files- README.md +1 -14
- backbones/__init__.py +24 -0
- backbones/__pycache__/__init__.cpython-312.pyc +0 -0
- backbones/__pycache__/dino_v2.cpython-312.pyc +0 -0
- backbones/__pycache__/dino_v3.cpython-312.pyc +0 -0
- backbones/dino_v2.py +22 -0
- backbones/dino_v3.py +33 -0
- dataset/__pycache__/dataloader.cpython-312.pyc +0 -0
- dataset/dataloader.py +28 -0
- evaluation/__pycache__/anomaly_evaluator.cpython-312.pyc +0 -0
- evaluation/anomaly_evaluator.py +111 -0
- main.py +135 -0
- models/__pycache__/model_bank_knn.cpython-312.pyc +0 -0
- models/model_bank_knn.py +239 -0
- segmenters/__init__.py +21 -0
- segmenters/__pycache__/__init__.cpython-312.pyc +0 -0
- segmenters/__pycache__/base_segmenter.cpython-312.pyc +0 -0
- segmenters/__pycache__/pca_segmenter.cpython-312.pyc +0 -0
- segmenters/__pycache__/sam3.cpython-312.pyc +0 -0
- segmenters/base_segmenter.py +14 -0
- segmenters/pca_segmenter.py +93 -0
- segmenters/sam2.py +105 -0
- segmenters/sam3.py +97 -0
- utils/__pycache__/visualize.cpython-312.pyc +0 -0
- utils/visualize.py +114 -0
README.md
CHANGED
|
@@ -1,14 +1 @@
|
|
| 1 |
-
|
| 2 |
-
title: AnomalyDetection
|
| 3 |
-
emoji: 📈
|
| 4 |
-
colorFrom: green
|
| 5 |
-
colorTo: yellow
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 6.1.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: apache-2.0
|
| 11 |
-
short_description: Anomaly detection using Dino v3
|
| 12 |
-
---
|
| 13 |
-
|
| 14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
## To run and reproduce the PatchKNN model, use the interactive demo notebook.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backbones/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import timm
|
| 2 |
+
|
| 3 |
+
from backbones.dino_v2 import build_dinov2_small, build_dinov2_base, build_dinov2_large
|
| 4 |
+
from backbones.dino_v3 import build_dinov3_small, build_dinov3_base, build_dinov3_large
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
Model registry for backbones
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Registry mapping public backbone names to their builder callables.
_BACKBONES = {
    "dinov2_small": build_dinov2_small,
    "dinov2_base": build_dinov2_base,
    "dinov2_large": build_dinov2_large,
    "dinov3_small": build_dinov3_small,
    "dinov3_base": build_dinov3_base,
    "dinov3_large": build_dinov3_large,
}


def get_backbone(name: str, **kwargs):
    """Build and return the backbone registered under *name*.

    Extra keyword arguments are forwarded to the builder. Raises ValueError
    for names not present in the registry.
    """
    try:
        builder = _BACKBONES[name]
    except KeyError:
        raise ValueError(f"Unknown backbone '{name}'. Available: {list(_BACKBONES)}") from None
    return builder(**kwargs)
|
backbones/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (919 Bytes). View file
|
|
|
backbones/__pycache__/dino_v2.cpython-312.pyc
ADDED
|
Binary file (940 Bytes). View file
|
|
|
backbones/__pycache__/dino_v3.cpython-312.pyc
ADDED
|
Binary file (1 kB). View file
|
|
|
backbones/dino_v2.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import timm
|
| 3 |
+
"""
|
| 4 |
+
DinoV2 backbones
|
| 5 |
+
"""
|
| 6 |
+
def build_dinov2_small(**kwargs):
|
| 7 |
+
|
| 8 |
+
model = timm.create_model('vit_small_patch14_reg4_dinov2.lvd142m',
|
| 9 |
+
pretrained=True, num_classes=0)
|
| 10 |
+
return model
|
| 11 |
+
|
| 12 |
+
def build_dinov2_base(**kwargs):
|
| 13 |
+
|
| 14 |
+
model = timm.create_model('vit_base_patch14_reg4_dinov2.lvd142m',
|
| 15 |
+
pretrained=True, num_classes=0)
|
| 16 |
+
return model
|
| 17 |
+
|
| 18 |
+
def build_dinov2_large(**kwargs):
|
| 19 |
+
|
| 20 |
+
model = timm.create_model('vit_large_patch14_dinov2.lvd142m',
|
| 21 |
+
pretrained=True, num_classes=0)
|
| 22 |
+
return model
|
backbones/dino_v3.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import timm
|
| 2 |
+
"""
|
| 3 |
+
DinoV3 backbones
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
def build_dinov3_small(**kwargs):
|
| 7 |
+
model = timm.create_model(
|
| 8 |
+
"vit_small_plus_patch16_dinov3.lvd1689m",
|
| 9 |
+
pretrained=True,
|
| 10 |
+
num_classes=0,
|
| 11 |
+
**kwargs,
|
| 12 |
+
)
|
| 13 |
+
return model
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def build_dinov3_base(**kwargs):
|
| 17 |
+
model = timm.create_model(
|
| 18 |
+
"vit_base_patch16_dinov3.lvd1689m",
|
| 19 |
+
pretrained=True,
|
| 20 |
+
num_classes=0,
|
| 21 |
+
**kwargs,
|
| 22 |
+
)
|
| 23 |
+
return model
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def build_dinov3_large(**kwargs):
|
| 27 |
+
model = timm.create_model(
|
| 28 |
+
"vit_large_patch16_dinov3.lvd1689m",
|
| 29 |
+
pretrained=True,
|
| 30 |
+
num_classes=0,
|
| 31 |
+
**kwargs,
|
| 32 |
+
)
|
| 33 |
+
return model
|
dataset/__pycache__/dataloader.cpython-312.pyc
ADDED
|
Binary file (921 Bytes). View file
|
|
|
dataset/dataloader.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
|
| 4 |
+
from anomalib.data import MVTec, MVTecAD
|
| 5 |
+
from anomalib.data.datasets.image import MVTecDataset
|
| 6 |
+
"""Loading logic :
|
| 7 |
+
Download MVTech AD dataset from the website and place it in data folder (Anomalib installation doesnt work, sorry you have to donwload manually)
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_mvtec(category: str, root: str = "./datasets/MVTec",) :
|
| 13 |
+
train_ds = MVTecDataset(
|
| 14 |
+
root=root,
|
| 15 |
+
category=category,
|
| 16 |
+
split="train"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
test_ds = MVTecDataset(
|
| 20 |
+
root=root,
|
| 21 |
+
category=category,
|
| 22 |
+
split="test"
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
train_paths = train_ds.samples["image_path"].tolist()
|
| 26 |
+
test_paths = test_ds.samples["image_path"].tolist()
|
| 27 |
+
|
| 28 |
+
return train_paths, test_paths
|
evaluation/__pycache__/anomaly_evaluator.cpython-312.pyc
ADDED
|
Binary file (5.71 kB). View file
|
|
|
evaluation/anomaly_evaluator.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
from sklearn.metrics import roc_auc_score, precision_recall_curve, average_precision_score
|
| 4 |
+
from skimage.measure import label, regionprops
|
| 5 |
+
|
| 6 |
+
class AnomalyEvaluator:
    """Accumulates per-image anomaly scores and per-pixel anomaly maps, then
    computes image-level (AUROC / AP / F1-max) and pixel-level
    (AUROC / F1-max / optional PRO) anomaly-detection metrics.
    """

    def __init__(self, pixel_subsample_rate=0.01, compute_pro=False):
        """
        Args:
            pixel_subsample_rate: fraction of pixels kept per image for the
                pixel-level metrics (memory saving); values >= 1.0 keep all pixels.
            compute_pro: if True, keep full-resolution maps/masks so the
                per-region-overlap (PRO) score can be computed.
        """
        self.subsample_rate = pixel_subsample_rate
        self.compute_pro = compute_pro
        self.reset()

    def reset(self):
        """Clear all accumulated predictions and ground truth."""
        self.img_preds = []
        self.img_labels = []

        self.pix_preds = []
        self.pix_labels = []

        self.full_amaps = []
        self.full_masks = []

    def update(self, image_score, gt_label, anomaly_map=None, gt_mask=None):
        """Record one image's score/label and optionally its pixel-level map/mask."""
        self.img_preds.append(image_score)
        self.img_labels.append(gt_label)

        if anomaly_map is not None and gt_mask is not None:
            self._update_pixel_metrics(anomaly_map, gt_mask)

    def _update_pixel_metrics(self, amap, mask):
        """Accumulate per-pixel predictions, resizing the GT mask if needed."""
        if mask.shape != amap.shape:
            mask = cv2.resize(mask, (amap.shape[1], amap.shape[0]), interpolation=cv2.INTER_NEAREST)

        # Binarize the ground-truth mask.
        mask = (mask > 0).astype(int)

        if self.compute_pro:
            self.full_amaps.append(amap)
            self.full_masks.append(mask)

        flat_amap = amap.flatten()
        flat_mask = mask.flatten()

        if self.compute_pro or self.subsample_rate >= 1.0:
            self.pix_preds.extend(flat_amap)
            self.pix_labels.extend(flat_mask)
        else:
            # Random subsampling to save memory.
            # Keep at least one pixel: int() alone truncates to 0 for small
            # maps, which would make this image contribute nothing.
            num_pixels = len(flat_mask)
            sample_size = max(1, int(num_pixels * self.subsample_rate))
            # NOTE(review): unseeded sampling makes pixel metrics slightly
            # nondeterministic between runs.
            indices = np.random.choice(num_pixels, sample_size, replace=False)
            self.pix_preds.extend(flat_amap[indices])
            self.pix_labels.extend(flat_mask[indices])

    def compute(self):
        """Compute all metrics from the accumulated state and return a dict.

        NOTE(review): roc_auc_score raises if every accumulated label belongs
        to the same class — callers must supply both normal and anomalous
        samples before calling compute().
        """
        results = {}
        y_true = np.array(self.img_labels)
        y_score = np.array(self.img_preds)

        results['image_auroc'] = roc_auc_score(y_true, y_score)

        results['image_ap'] = average_precision_score(y_true, y_score)

        # F1-max over the precision/recall curve (epsilon avoids 0/0).
        prec, rec, _ = precision_recall_curve(y_true, y_score)
        f1_scores = 2 * (prec * rec) / (prec + rec + 1e-8)
        results['image_f1_max'] = np.max(f1_scores)

        if len(self.pix_labels) > 0:
            pix_true = np.array(self.pix_labels)
            pix_score = np.array(self.pix_preds)

            results['pixel_auroc'] = roc_auc_score(pix_true, pix_score)

            prec_p, rec_p, thresholds_p = precision_recall_curve(pix_true, pix_score)
            f1_p = 2 * (prec_p * rec_p) / (prec_p + rec_p + 1e-8)
            best_idx = np.argmax(f1_p)
            # precision_recall_curve returns one fewer threshold than PR points,
            # so the last index has no threshold; fall back to 0.5.
            best_threshold = thresholds_p[best_idx] if best_idx < len(thresholds_p) else 0.5

            results['pixel_f1_max'] = np.max(f1_p)

            if self.compute_pro:
                results['pixel_pro'] = self._compute_pro(best_threshold)

        return results

    def _compute_pro(self, threshold):
        """Per-Region Overlap: mean recall over connected GT defect regions
        when the anomaly maps are binarized at *threshold*."""
        total_pro = 0
        n_defects = 0

        for i in range(len(self.full_amaps)):
            gt = self.full_masks[i]
            # Skip normal images
            if np.sum(gt) == 0:
                continue

            pred_mask = (self.full_amaps[i] >= threshold).astype(int)

            # Label connected components in Ground Truth
            labeled_gt = label(gt)
            regions = regionprops(labeled_gt)

            for region in regions:
                n_defects += 1
                blob_mask = (labeled_gt == region.label)
                overlap_pixels = np.sum(pred_mask & blob_mask)
                blob_area = region.area

                total_pro += (overlap_pixels / blob_area)

        return total_pro / n_defects if n_defects > 0 else 0.0
|
main.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from segmenters import PCASegmenter
|
| 9 |
+
from segmenters.sam3 import SAM3Segmenter
|
| 10 |
+
from backbones import get_backbone
|
| 11 |
+
from dataset.dataloader import load_mvtec
|
| 12 |
+
from models.model_bank_knn import PatchKNNDetector
|
| 13 |
+
from evaluation.anomaly_evaluator import AnomalyEvaluator
|
| 14 |
+
|
| 15 |
+
def main(
    category: str = "bottle",
    root: str | None = None,
    backbone_name: str = "dinov3_small",
    use_sam3: bool = True,
    use_pca: bool = False,
    pca_backbone_name: str | None = None,
    return_results: bool = False,
    backbone_model=None,
    segmenter_obj=None,
    n_ref: int = 1,
):
    """Run the full PatchKNN anomaly-detection pipeline on one MVTec category.

    Loads the category's train/test image paths, builds a segmenter (PCA,
    SAM3, or none) and a backbone, fits the PatchKNN detector on up to
    ``n_ref`` reference images, evaluates on the test set, and prints metrics.

    Args:
        category: MVTec AD category name (also used as the SAM3 text prompt).
        root: dataset root directory; defaults to dataset/mvtec_anomaly_detection.
        backbone_name: registry key passed to get_backbone().
        use_sam3: build a SAM3Segmenter when no segmenter_obj/use_pca is given.
        use_pca: takes precedence over use_sam3 when choosing a segmenter.
        pca_backbone_name: backbone for the PCA segmenter (defaults to backbone_name).
        return_results: if True, return the metrics dict instead of None.
        backbone_model: pre-built backbone to reuse (skips get_backbone).
        segmenter_obj: pre-built segmenter to reuse (skips segmenter selection).
        n_ref: number of reference (normal) images for the memory bank.

    Returns:
        Metrics dict from AnomalyEvaluator.compute() if return_results, else None.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1. Setup
    root = root or os.path.join("dataset", "mvtec_anomaly_detection")
    train_paths, test_paths = load_mvtec(category=category, root=root)
    # Subsample both splits 5x to keep runtime manageable (e.g. on a Space).
    test_paths = test_paths[::5]
    train_paths = train_paths[::5]
    print(f"{category}: {len(train_paths)} train, {len(test_paths)} test images")

    # 2. Initialize Evaluator with PRO support
    evaluator = AnomalyEvaluator(pixel_subsample_rate=0.01, compute_pro=False)

    # 3. Model Init
    segmenter = segmenter_obj
    if segmenter is None:
        if use_pca:
            pca_backbone = pca_backbone_name or backbone_name
            segmenter = PCASegmenter(backbone_name=pca_backbone, device=device)
            print(f"Using PCA segmenter (backbone={pca_backbone})")
        elif use_sam3:
            segmenter = SAM3Segmenter(text_prompt=category, device=device)
            print("Using SAM3 segmenter")
        else:
            print("No segmenter selected; using full-image foreground.")


    backbone = backbone_model or get_backbone(backbone_name)

    model = PatchKNNDetector(
        backbone=backbone,
        segmenter=segmenter,
        device=device,
        k_neighbors=1,
    )

    print(f"Fitting model... (backbone={backbone_name}, sam3={use_sam3})")
    model.fit(train_paths, n_ref=n_ref)

    # 4. Evaluation Loop
    print(f"Starting evaluation on {len(test_paths)} images...")

    for i, path in enumerate(test_paths):
        # Predict
        image, amap, score = model.predict(path)

        # Ground Truth Logic
        # NOTE(review): substring test — a path containing "good" anywhere
        # (e.g. a directory named "goods") would be mislabeled; MVTec's
        # standard layout (test/good/...) makes this safe in practice, but
        # checking the parent directory name would be more robust.
        is_anomaly = 0 if "good" in path else 1

        if is_anomaly == 0:
            gt_mask = np.zeros_like(amap)
        else:
            # MVTec convention: test/<defect>/<n>.png -> ground_truth/<defect>/<n>_mask.png
            mask_path = path.replace("test", "ground_truth").replace(".png", "_mask.png")
            if os.path.exists(mask_path):
                gt_mask = cv2.imread(mask_path, 0)
                if gt_mask.shape != amap.shape:
                    gt_mask = cv2.resize(gt_mask, (amap.shape[1], amap.shape[0]), interpolation=cv2.INTER_NEAREST)
                gt_mask = (gt_mask > 0).astype(int)
            else:
                # Anomalous image without a mask file: fall back to an empty mask.
                gt_mask = np.zeros_like(amap)

        # Update Evaluator
        evaluator.update(image_score=score, gt_label=is_anomaly, anomaly_map=amap, gt_mask=gt_mask)

        if i % 20 == 0:
            print(f"Processed {i}/{len(test_paths)}...")

    # 5. Compute & Print Results
    results = evaluator.compute()

    print("\n" + "="*40)
    print(f"FINAL RESULTS: {category}")
    print("-" * 40)

    # Image Level
    print(f"Image AUROC: {results['image_auroc']:.4f}")
    print(f"Image F1-Max: {results['image_f1_max']:.4f}")
    print(f"Image AP: {results['image_ap']:.4f}")
    print("-" * 40)

    # Pixel Level
    if 'pixel_auroc' in results:
        print(f"Pixel AUROC: {results['pixel_auroc']:.4f}")
        print(f"Pixel F1-Max: {results['pixel_f1_max']:.4f}")
        #print(f"PRO Score: {results['pixel_pro']:.4f}")
    print("="*40)

    return results if return_results else None
|
| 114 |
+
|
| 115 |
+
if __name__ == "__main__":
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
category = os.environ.get("MVTec_CATEGORY", "bottle")
|
| 119 |
+
root = os.environ.get("MVTec_ROOT", None)
|
| 120 |
+
backbone_name = os.environ.get("BACKBONE_NAME", "dinov3_small")
|
| 121 |
+
use_sam3_env = os.environ.get("USE_SAM3", "0").lower()
|
| 122 |
+
use_sam3 = use_sam3_env not in {"0", "false", "no"}
|
| 123 |
+
use_pca_env = os.environ.get("USE_PCA", "1").lower()
|
| 124 |
+
use_pca = use_pca_env in {"1", "true", "yes"}
|
| 125 |
+
pca_backbone_name = os.environ.get("PCA_BACKBONE", None)
|
| 126 |
+
|
| 127 |
+
main(
|
| 128 |
+
category=category,
|
| 129 |
+
root=root,
|
| 130 |
+
backbone_name=backbone_name,
|
| 131 |
+
use_sam3=use_sam3,
|
| 132 |
+
use_pca=use_pca,
|
| 133 |
+
pca_backbone_name=pca_backbone_name,
|
| 134 |
+
n_ref=int(os.environ.get("N_REF", 1)),
|
| 135 |
+
)
|
models/__pycache__/model_bank_knn.cpython-312.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
models/model_bank_knn.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import Iterable, Tuple, Optional
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from timm.data import resolve_data_config
|
| 10 |
+
import torchvision.transforms as T
|
| 11 |
+
from torchvision.transforms import InterpolationMode
|
| 12 |
+
|
| 13 |
+
from segmenters import BaseSegmenter
|
| 14 |
+
from utils.visualize import visualize_segmentation
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PatchKNNDetector:
    """PatchCore-style anomaly detector.

    Embeds image patches with a ViT backbone, stores reference (normal) patch
    embeddings in a memory bank, and scores test patches by k-NN distance to
    that bank — optionally restricted to foreground patches via a segmenter.
    """

    def __init__(self, backbone, segmenter=None, device="cuda", k_neighbors=1):
        """
        Args:
            backbone: timm ViT feature extractor (DINOv2/v3-style).
            segmenter: optional object exposing get_object_mask(); when None,
                every patch is treated as foreground.
            device: torch device string.
            k_neighbors: number of nearest neighbors averaged per patch.
        """
        self.device = device
        self.backbone = backbone.to(device)
        # Switch backbone to inference mode
        self.backbone.eval()

        self.segmenter = segmenter
        self.k_neighbors = k_neighbors

        # Prepare resize/normalize augmentations shared by DINO and SAM
        data_cfg = resolve_data_config({}, model=self.backbone)
        _, self.img_size, _ = data_cfg["input_size"]
        interp = data_cfg.get("interpolation", "bicubic")

        self.transform = T.Compose(
            [
                T.Resize(self.img_size, interpolation=getattr(InterpolationMode, interp.upper(), InterpolationMode.BICUBIC)),
                T.ToTensor(),
                T.Normalize(mean=data_cfg.get("mean", (0.485, 0.456, 0.406)),
                            std=data_cfg.get("std", (0.229, 0.224, 0.225))),
            ]
        )

        self.num_register_tokens = getattr(self.backbone, "num_register_tokens", 0)

        # Memory bank of foreground patch embeddings
        self.memory_bank = None
        self.patch_grid_size = None

    def fit(self, train_image_paths, n_ref=1):
        """Populate the memory bank from the first *n_ref* reference images."""
        selected_paths = list(train_image_paths)[:n_ref]
        all_patches = []

        for path in selected_paths:
            # Extracting features
            image = self._load_image(path)
            patch_feats, grid_size = self._extract_patch_features(image)

            # Applying foreground mask
            patch_mask = self._compute_patch_mask(image, grid_size)
            patch_feats_fg = patch_feats[patch_mask]

            all_patches.append(patch_feats_fg)
            self.patch_grid_size = grid_size

        self.memory_bank = np.concatenate(all_patches, axis=0)

    def predict(self, image_path):
        """Run anomaly detection inference on one image.

        Returns:
            (image_rgb, anomaly_map, image_score) — the loaded RGB image, a
            float32 per-pixel anomaly map at the original resolution, and a
            scalar image-level score.
        """
        image = self._load_image(image_path)
        patch_feats, grid_size = self._extract_patch_features(image)

        patch_mask = self._compute_patch_mask(image, grid_size)

        # Compute distances only on foreground patches
        scores_fg = self._knn_distances(patch_feats[patch_mask])

        # Put scores back into full patch grid (background patches score 0)
        scores_all = np.zeros(patch_feats.shape[0], dtype=np.float32)
        scores_all[patch_mask] = scores_fg
        patch_map = scores_all.reshape(grid_size)

        # Upsample to full image ( just for visualization)
        h, w = image.shape[:2]
        anomaly_map = cv2.resize(
            patch_map,
            (w, h),
            interpolation=cv2.INTER_CUBIC,
        ).astype(np.float32)

        # Using image-level score - mean of top 1% patch scores
        image_score = self._mean_top_percent(scores_fg, top_percent=1.0)

        return image, anomaly_map, image_score


    def _load_image(self, path):
        """Load an image from disk as an RGB uint8 array.

        Raises:
            FileNotFoundError: if the file is missing or unreadable.
                (cv2.imread returns None in that case, which previously
                surfaced as an opaque cvtColor error.)
        """
        bgr = cv2.imread(path, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return rgb

    @staticmethod
    def _l2_normalize(feats: np.ndarray, eps: float = 1e-12) -> np.ndarray:
        """Row-wise L2 normalization; eps guards against zero-norm rows."""
        norm = np.linalg.norm(feats, axis=1, keepdims=True)
        return feats / np.maximum(norm, eps)

    def _extract_patch_features(self, image: np.ndarray):
        """Run the backbone on one image and return (patch_features, grid_size).

        Features are L2-normalized float32 of shape (gh*gw, C); grid_size is
        the (gh, gw) patch-token grid.
        """
        pil_resized, _ = self._resize_for_model(image)
        x = self.transform(pil_resized).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            out = self.backbone.forward_features(x)

        # DINOv2-style dict output vs. plain tensor output.
        tokens = out.get("x_norm_patchtokens") if isinstance(out, dict) else out

        if tokens is None and isinstance(out, dict):
            tokens = out.get("x")
        if tokens is not None and tokens.ndim == 4:
            # Spatial feature map (B, C, H, W) -> token sequence (B, H*W, C).
            B, C, Hf, Wf = tokens.shape
            tokens = tokens.permute(0, 2, 3, 1).reshape(B, Hf * Wf, C)

        B, N, C = tokens.shape

        if hasattr(self.backbone, "patch_embed") and hasattr(self.backbone.patch_embed, "grid_size"):
            gh, gw = self.backbone.patch_embed.grid_size
        else:
            # Fallback: assume a (near-)square token grid.
            gh = int(np.sqrt(N))
            gw = max(1, N // max(1, gh))

        # Keep only the last gh*gw tokens — drops CLS/register prefix tokens.
        n_patches = gh * gw
        patch_tokens = tokens[:, -n_patches:, :]

        # Flatten and normalize
        feats = (
            patch_tokens.reshape(B * n_patches, C)
            .detach()
            .cpu()
            .numpy()
            .astype("float32")
        )
        feats = self._l2_normalize(feats)

        grid_size = (gh, gw)
        return feats, grid_size

    def _compute_patch_mask(self, image, grid_size):
        """Convert the segmenter's pixel-level mask to a patch-level bool mask.

        Returns a flat boolean array of length gh*gw (True = foreground).
        """
        h_p, w_p = grid_size
        n_patches = h_p * w_p

        if self.segmenter is None:
            return np.ones(n_patches, dtype=bool)

        # Resize image same way as in the backbone before sending to SAM
        pil_resized, np_resized = self._resize_for_model(image)
        full_mask = self.segmenter.get_object_mask(np_resized)

        # Optionally visualize in resized space
        visualize_segmentation(
            np_resized,
            full_mask,
            grid_size=None,
            title=f"Segmentation debug (resized {self.img_size})",
        )

        full_mask_uint8 = (full_mask.astype(np.uint8) * 255).astype(np.float32)

        # Downsample to patch grid with area interpolation for coverage
        mask_small = cv2.resize(
            full_mask_uint8,
            (w_p, h_p),
            interpolation=cv2.INTER_AREA,
        ) / 255.0

        patch_mask = (mask_small >= 0.5).reshape(-1)

        # Fallback if mask collapses (selects almost nothing or everything)
        fg_ratio = patch_mask.mean()
        if fg_ratio < 0.01 or fg_ratio > 0.99:
            patch_mask = np.ones(n_patches, dtype=bool)

        return patch_mask

    def _resize_for_model(self, image):
        """Resize to the backbone's square input size; returns (PIL, ndarray)."""
        pil = Image.fromarray(image)
        pil_resized = pil.resize((self.img_size, self.img_size), Image.BICUBIC)
        np_resized = np.array(pil_resized)
        return pil_resized, np_resized

    def _knn_distances(self, feats: np.ndarray) -> np.ndarray:
        """Distance of each query feature to its nearest memory-bank neighbors.

        Raises:
            RuntimeError: if fit() has not populated the memory bank.
        """
        if self.memory_bank is None:
            raise RuntimeError("Memory bank is empty.")


        a = feats
        b = self.memory_bank

        # Vectorized L2 distances: ||a-b||^2 = |a|^2 + |b|^2 - 2*a.b
        a2 = np.sum(a**2, axis=1, keepdims=True)
        b2 = np.sum(b**2, axis=1, keepdims=True).T
        ab = a @ b.T
        # Clip to avoid negative values (floating-point round-off)
        d2 = np.clip(a2 + b2 - 2.0 * ab, a_min=0.0, a_max=None)
        d = np.sqrt(d2)

        # kNN: take mean of k smallest distances per patch
        k = min(self.k_neighbors, d.shape[1])
        if k == 1:
            min_d = d.min(axis=1)
        else:
            # partial sort for efficiency
            part = np.partition(d, kth=k - 1, axis=1)[:, :k]
            min_d = part.mean(axis=1)

        return min_d.astype(np.float32)

    @staticmethod
    def _mean_top_percent(values: np.ndarray, top_percent: float = 1.0) -> float:
        """Mean of top p% values used as image level anomaly score."""
        if values.size == 0:
            return 0.0
        # At least one value is always kept, even for tiny inputs.
        k = max(1, int(round(values.size * (top_percent / 100.0))))
        part = np.partition(values, -k)[-k:]
        return float(part.mean())
|
segmenters/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from segmenters.base_segmenter import BaseSegmenter

# The registry below guards each segmenter with `is not None`, which only
# makes sense if the imports are allowed to fail — make them genuinely
# optional so a missing dependency (e.g. SAM3 weights/package) disables that
# segmenter instead of breaking the whole package import.
try:
    from segmenters.sam3 import SAM3Segmenter
except ImportError:
    SAM3Segmenter = None

try:
    from segmenters.pca_segmenter import PCASegmenter
except ImportError:
    PCASegmenter = None

"""
Model registry for segmenters
"""
_SEGMENTERS = {}
if SAM3Segmenter is not None:
    _SEGMENTERS["sam3"] = SAM3Segmenter
if PCASegmenter is not None:
    _SEGMENTERS["pca"] = PCASegmenter


def get_segmenter(name: str, **kwargs):
    """Build and return the segmenter registered under *name*.

    Extra keyword arguments are forwarded to the segmenter constructor.
    Raises ValueError for unknown (or unavailable) segmenter names.
    """
    if name not in _SEGMENTERS:
        raise ValueError(f"Unknown segmenter '{name}'. Available: {list(_SEGMENTERS)}")
    return _SEGMENTERS[name](**kwargs)
|
segmenters/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (799 Bytes). View file
|
|
|
segmenters/__pycache__/base_segmenter.cpython-312.pyc
ADDED
|
Binary file (811 Bytes). View file
|
|
|
segmenters/__pycache__/pca_segmenter.cpython-312.pyc
ADDED
|
Binary file (6.98 kB). View file
|
|
|
segmenters/__pycache__/sam3.cpython-312.pyc
ADDED
|
Binary file (4.61 kB). View file
|
|
|
segmenters/base_segmenter.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
class BaseSegmenter:
    """Abstract interface for foreground/object segmentation models."""

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """Compute a foreground mask for *image*.

        Args:
            image: input image array.

        Returns:
            Boolean mask where True marks the foreground object.

        Subclasses must override this method.
        """
        raise NotImplementedError
|
| 14 |
+
|
segmenters/pca_segmenter.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import cv2
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torchvision.transforms as T
|
| 7 |
+
from timm.data import resolve_data_config
|
| 8 |
+
|
| 9 |
+
from backbones import get_backbone
|
| 10 |
+
from segmenters import BaseSegmenter
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PCASegmenter(BaseSegmenter):
|
| 14 |
+
    def __init__(
        self,
        backbone_name: str = "dinov3_base",
        device: str | None = None,
        threshold: float = 2.5,
        kernel_size: int = 5,
        border: float = 0.2,
    ):
        """Segmenter that thresholds the first PCA component of ViT patch features.

        Args:
            backbone_name: registry key passed to get_backbone().
            device: torch device string; auto-selects CUDA when available.
            threshold: PC1-projection cutoff separating foreground from background.
            kernel_size: size of the square kernel for dilation/closing.
            border: fraction of the patch grid treated as border when checking
                whether the mask polarity should be flipped.
        """
        super().__init__()
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model = get_backbone(backbone_name).to(self.device)
        self.model.eval()


        # Derive input size / normalization from the backbone's timm config.
        cfg = resolve_data_config({}, model=self.model)
        _, img_size, _ = cfg["input_size"]
        # NOTE(review): convoluted guarded lookup of pretrained_cfg["architecture"];
        # the fallback lambda returns {} (not ""), but the isinstance(str) check
        # below makes that harmless. Presumably equivalent to
        # self.model.pretrained_cfg.get("architecture", "") when the attr exists.
        arch = getattr(getattr(self.model, "pretrained_cfg", {}), "get", lambda k, d=None: {})(  # type: ignore[arg-type]
            "architecture", ""
        )
        if isinstance(arch, str) and "dinov3" in arch:
            # DINOv3 benefits from a larger input; enforce at least 512px.
            img_size = max(img_size, 512)

        self.img_size = img_size
        interp = cfg.get("interpolation", "bicubic")
        # Resize/normalize pipeline matching the backbone's training config.
        self.transform = T.Compose(
            [
                T.Resize((self.img_size, self.img_size), interpolation=getattr(T.InterpolationMode, interp.upper(), T.InterpolationMode.BICUBIC)),
                T.ToTensor(),
                T.Normalize(mean=cfg.get("mean", (0.485, 0.456, 0.406)), std=cfg.get("std", (0.229, 0.224, 0.225))),
            ]
        )
        self.threshold = threshold
        self.border = border
        # Morphology kernel used to clean up the thresholded patch mask.
        self.kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
|
| 48 |
+
|
| 49 |
+
def get_object_mask(self, image: np.ndarray) -> np.ndarray:
|
| 50 |
+
h0, w0 = image.shape[:2]
|
| 51 |
+
pil = Image.fromarray(image.astype(np.uint8))
|
| 52 |
+
x = self.transform(pil).unsqueeze(0).to(self.device)
|
| 53 |
+
|
| 54 |
+
with torch.inference_mode():
|
| 55 |
+
out = self.model.forward_features(x)
|
| 56 |
+
|
| 57 |
+
tokens = out.get("x_norm_patchtokens") if isinstance(out, dict) else out
|
| 58 |
+
|
| 59 |
+
if tokens is None and isinstance(out, dict):
|
| 60 |
+
tokens = out.get("x")
|
| 61 |
+
if tokens is not None and tokens.ndim == 4:
|
| 62 |
+
B, C, Hf, Wf = tokens.shape
|
| 63 |
+
tokens = tokens.permute(0, 2, 3, 1).reshape(B, Hf * Wf, C)
|
| 64 |
+
|
| 65 |
+
gh_dyn = int(np.sqrt(tokens.shape[1]))
|
| 66 |
+
gw_dyn = max(1, tokens.shape[1] // max(1, gh_dyn))
|
| 67 |
+
gh, gw = gh_dyn, gw_dyn
|
| 68 |
+
|
| 69 |
+
if hasattr(self.model, "patch_embed") and hasattr(self.model.patch_embed, "grid_size"):
|
| 70 |
+
gh0, gw0 = self.model.patch_embed.grid_size
|
| 71 |
+
if gh0 * gw0 == tokens.shape[1]:
|
| 72 |
+
gh, gw = gh0, gw0
|
| 73 |
+
n_patches = gh * gw
|
| 74 |
+
tokens = tokens[:, -n_patches:, :]
|
| 75 |
+
|
| 76 |
+
feats = tokens.squeeze(0).detach().cpu().numpy().astype(np.float32)
|
| 77 |
+
feats -= feats.mean(0, keepdims=True)
|
| 78 |
+
u, s, vh = np.linalg.svd(feats, full_matrices=False)
|
| 79 |
+
pc1 = vh[0]
|
| 80 |
+
scores = feats @ pc1
|
| 81 |
+
mask = scores > self.threshold
|
| 82 |
+
m_grid = mask.reshape(gh, gw)
|
| 83 |
+
bh = int(gh * self.border)
|
| 84 |
+
bw = int(gw * self.border)
|
| 85 |
+
inner = m_grid[bh : gh - bh, bw : gw - bw]
|
| 86 |
+
if inner.size > 0 and inner.mean() <= 0.35:
|
| 87 |
+
mask = scores < -self.threshold
|
| 88 |
+
m_grid = mask.reshape(gh, gw)
|
| 89 |
+
mask = m_grid.astype(np.uint8)
|
| 90 |
+
mask = cv2.dilate(mask, self.kernel, iterations=1)
|
| 91 |
+
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, self.kernel)
|
| 92 |
+
mask = cv2.resize(mask, (w0, h0), interpolation=cv2.INTER_NEAREST)
|
| 93 |
+
return mask.astype(bool)
|
segmenters/sam2.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from transformers import Sam2Processor, Sam2Model
|
| 6 |
+
|
| 7 |
+
from segmenters import BaseSegmenter
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SAM2Segmenter(BaseSegmenter):
    """
    SAM2 wrapper.

    - Uses Sam2Model (e.g. `facebook/sam2.1-hiera-large`).
    - Segments (approximately) all objects in the image by prompting
      with a full-image bounding box and returns a single boolean mask
      given by the union of all predicted masks.
    """

    def __init__(
        self,
        text_prompt: str | None = None,
        model_name: str = "facebook/sam2.1-hiera-large",
        device: str = "cuda",
        mask_threshold: float = 0.5,
    ) -> None:
        """
        Args:
            text_prompt: kept for compatibility with SAM3Segmenter, but unused.
            model_name: HF repo id for the SAM2 model, e.g. "facebook/sam2.1-hiera-large".
            device: "cuda" or "cpu".
            mask_threshold: pixel threshold for masks (after SAM2 post-processing).
        """
        super().__init__()

        if torch.cuda.is_available() and device.startswith("cuda"):
            self.device = torch.device(device)
        else:
            self.device = torch.device("cpu")

        # BUG FIX: these parameters were accepted but never stored, so
        # get_object_mask crashed with AttributeError on self.mask_threshold.
        self.text_prompt = text_prompt
        self.mask_threshold = mask_threshold

        # Load SAM2 model + processor
        self.model = Sam2Model.from_pretrained(model_name).to(self.device)
        self.model.eval()
        self.processor = Sam2Processor.from_pretrained(model_name)

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """
        Run SAM2 and return a single foreground mask.

        - Convert image to PIL.
        - Use a single bounding box covering the whole image as prompt.
        - Run SAM2, post-process masks to image resolution.
        - Threshold and union all masks into one boolean (H, W) array.
        """
        # Ensure PIL image
        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
        else:
            pil_image = image

        W, H = pil_image.size  # PIL: (W, H)

        # Full image bounding box: [x_min, y_min, x_max, y_max]
        input_boxes = [[[0, 0, W, H]]]

        # Build inputs for SAM2
        inputs = self.processor(
            images=pil_image,
            input_boxes=input_boxes,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            # multimask_output=False -> one mask per box
            outputs = self.model(**inputs, multimask_output=False)

        # Post-process masks to original resolution
        masks = self.processor.post_process_masks(
            outputs.pred_masks.cpu(),  # (B, num_masks, H', W')
            inputs["original_sizes"],
        )[0]

        # Shapes can be:
        # - (num_masks, H, W)
        # - or (1, num_masks, H, W) depending on version
        if masks.ndim == 4:
            # (B, num_masks, H, W) -> (num_masks, H, W) for B=1
            masks = masks[0]

        if masks.ndim == 2:
            # Single mask: (H, W)
            full_mask = (masks > self.mask_threshold).numpy().astype(bool)
            return full_mask

        if masks.ndim != 3:
            # Failsafe: if something weird happens, keep everything
            return np.ones((H, W), dtype=bool)

        # masks: (num_masks, H, W)
        masks_bin = masks > self.mask_threshold
        combined = masks_bin.any(dim=0)  # (H, W)
        full_mask = combined.numpy().astype(bool)

        return full_mask
|
segmenters/sam3.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from transformers import Sam3Processor, Sam3Model
|
| 7 |
+
|
| 8 |
+
from segmenters import BaseSegmenter
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SAM3Segmenter(BaseSegmenter):
    """Text-prompted SAM3 wrapper.

    Segments every instance matching the text prompt and returns the union
    of all surviving instance masks as one boolean (H, W) array. When no
    instance passes the thresholds, the whole image is kept as foreground.
    """

    def __init__(
        self,
        text_prompt: str,
        model_name: str = "facebook/sam3",
        device: str = "cuda",
        score_threshold: float = 0.5,
        mask_threshold: float = 0.5 ):
        """
        Args:
            text_prompt: description of what to segment.
            model_name: HF repo id for the SAM3 model.
            device: "cuda" or "cpu".
            score_threshold: minimum detection score to keep an instance.
            mask_threshold: per-pixel threshold applied to instance masks.
        """
        super().__init__()

        use_cuda = torch.cuda.is_available() and device.startswith("cuda")
        self.device = torch.device(device) if use_cuda else torch.device("cpu")

        # Normalise prompts such as "metal_nut" to "metal nut".
        self.text_prompt = text_prompt.replace("_", " ")
        self.score_threshold = score_threshold
        self.mask_threshold = mask_threshold

        # Model and matching processor from the same checkpoint.
        self.model = Sam3Model.from_pretrained(model_name).to(self.device)
        self.model.eval()
        self.processor = Sam3Processor.from_pretrained(model_name)

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """Run SAM3 with the stored text prompt and return one boolean mask."""
        # Accept either an ndarray or an already-constructed PIL image.
        img = (
            Image.fromarray(image.astype(np.uint8)).convert("RGB")
            if isinstance(image, np.ndarray)
            else image
        )

        # Tokenise prompt + preprocess image in one call.
        inputs = self.processor(
            images=img,
            text=self.text_prompt,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Resize predicted instances back to the original image size.
        sizes = inputs.get("original_sizes").tolist()
        detections = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=self.score_threshold,
            mask_threshold=self.mask_threshold,
            target_sizes=sizes,
        )[0]

        instance_masks = detections.get("masks", None)
        instance_scores = detections.get("scores", None)

        height, width = img.size[1], img.size[0]

        # Nothing detected at all: fall back to keeping every pixel.
        if instance_masks is None or instance_masks.numel() == 0:
            return np.ones((height, width), dtype=bool)

        # Filter by detection confidence; fall back again if nothing survives.
        if instance_scores is not None:
            keep = instance_scores >= self.score_threshold
            if keep.sum() == 0:
                return np.ones((height, width), dtype=bool)
            instance_masks = instance_masks[keep]

        # Binarise each instance and take the pixel-wise union.
        binary = instance_masks > self.mask_threshold
        union = binary.any(dim=0)
        return union.cpu().numpy().astype(bool)
|
utils/__pycache__/visualize.cpython-312.pyc
ADDED
|
Binary file (5.12 kB). View file
|
|
|
utils/visualize.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import cv2
|
| 5 |
+
|
| 6 |
+
# Completely vibecoded
|
| 7 |
+
def visualize_prediction(
    image: np.ndarray,
    anomaly_map: np.ndarray,
    image_score: float,
    threshold_percentile: float = 95.0,
    title: str | None = None,
) -> None:
    """Render a three-panel anomaly figure.

    Panels: the original image, the anomaly heatmap overlaid on it, and a
    binary overlay where the (min-max normalised) map exceeds the given
    percentile. The figure title defaults to the image-level score.
    """
    # Min-max normalise the map into [0, 1] so colours are comparable.
    heat = anomaly_map.astype(np.float32)
    heat -= heat.min()
    peak = heat.max()
    if peak > 0:
        heat /= peak

    cutoff = np.percentile(heat, threshold_percentile)
    above = heat >= cutoff

    fig, (ax_img, ax_heat, ax_bin) = plt.subplots(1, 3, figsize=(14, 4))

    ax_img.imshow(image)
    ax_img.set_title("Input image")
    ax_img.axis("off")

    ax_heat.imshow(image)
    overlay = ax_heat.imshow(heat, cmap="jet", alpha=0.5)
    ax_heat.set_title("Anomaly heatmap")
    ax_heat.axis("off")
    fig.colorbar(overlay, ax=ax_heat, fraction=0.046, pad=0.04)

    ax_bin.imshow(image)
    ax_bin.imshow(above, cmap="gray", alpha=0.5)
    ax_bin.set_title(f"Thresholded (>{threshold_percentile:.0f}%)")
    ax_bin.axis("off")

    fig.suptitle(title if title is not None else f"Image anomaly score: {image_score:.3f}")
    plt.tight_layout()
    plt.show()
|
| 52 |
+
|
| 53 |
+
def visualize_segmentation(
    image: np.ndarray,
    full_mask: np.ndarray,
    grid_size: tuple[int, int] | None = None,
    title: str | None = None,
) -> None:
    """
    Visualize SAM2 segmentation.

    Args:
        image: (H, W, 3) RGB uint8
        full_mask: (H, W) bool or 0/1 array from SAM2
        grid_size: optional (H_patches, W_patches) to also show patch-level mask
        title: optional title string
    """
    mask = full_mask.astype(bool)
    H, W = mask.shape

    # Three panels, plus a fourth when a patch grid was requested.
    panels = 3 if grid_size is None else 4
    fig, axes = plt.subplots(1, panels, figsize=(4 * panels, 4))

    # 1) input image
    axes[0].imshow(image)
    axes[0].set_title("Input image")
    axes[0].axis("off")

    # 2) raw binary mask
    axes[1].imshow(mask, cmap="gray")
    axes[1].set_title("SAM2 mask (full-res)")
    axes[1].axis("off")

    # 3) overlay mask on image
    axes[2].imshow(image)
    axes[2].imshow(mask, cmap="Reds", alpha=0.4)
    axes[2].set_title("Mask overlay")
    axes[2].axis("off")

    # 4) optional patch-level mask: downsample to the grid and back up,
    # showing the effective patch resolution of the mask.
    if grid_size is not None:
        gh, gw = grid_size
        coarse = cv2.resize(
            mask.astype(np.uint8), (gw, gh), interpolation=cv2.INTER_NEAREST
        ).astype(bool)
        coarse_full = cv2.resize(
            coarse.astype(np.uint8),
            (W, H),
            interpolation=cv2.INTER_NEAREST,
        ).astype(bool)

        axes[3].imshow(image)
        axes[3].imshow(coarse_full, cmap="Blues", alpha=0.4)
        axes[3].set_title("Patch-level mask (after downsample)")
        axes[3].axis("off")

    if title is not None:
        fig.suptitle(title)

    plt.tight_layout()
    plt.show()
|