V4ldeLund commited on
Commit
1834bc0
·
verified ·
1 Parent(s): 0633258

Upload full code for Space

Browse files
README.md CHANGED
@@ -1,14 +1 @@
1
- ---
2
- title: AnomalyDetection
3
- emoji: 📈
4
- colorFrom: green
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 6.1.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: Anomaly detection using Dino v3
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ## To run and reproduce the PatchKNN model, use the interactive demo notebook.
 
 
 
 
 
 
 
 
 
 
 
 
 
backbones/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import timm
# NOTE(review): `timm` appears unused in this module (builders import it
# themselves) — confirm before removing.

from backbones.dino_v2 import build_dinov2_small, build_dinov2_base, build_dinov2_large
from backbones.dino_v3 import build_dinov3_small, build_dinov3_base, build_dinov3_large

"""
Model registry for backbones
"""


# Maps a user-facing backbone name to its factory function.
# Add an entry here to register a new feature extractor.
_BACKBONES = {
    "dinov2_small": build_dinov2_small,
    "dinov2_base": build_dinov2_base,
    "dinov2_large": build_dinov2_large,
    "dinov3_small": build_dinov3_small,
    "dinov3_base": build_dinov3_base,
    "dinov3_large": build_dinov3_large,

}
20
+
21
def get_backbone(name: str, **kwargs):
    """Instantiate the backbone registered under *name*, forwarding **kwargs."""
    try:
        builder = _BACKBONES[name]
    except KeyError:
        raise ValueError(f"Unknown backbone '{name}'. Available: {list(_BACKBONES)}") from None
    return builder(**kwargs)
backbones/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (919 Bytes). View file
 
backbones/__pycache__/dino_v2.cpython-312.pyc ADDED
Binary file (940 Bytes). View file
 
backbones/__pycache__/dino_v3.cpython-312.pyc ADDED
Binary file (1 kB). View file
 
backbones/dino_v2.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import timm
3
+ """
4
+ DinoV2 backbones
5
+ """
6
def build_dinov2_small(**kwargs):
    """Build a pretrained DINOv2 ViT-S/14 (with registers) feature extractor.

    The classifier head is removed (``num_classes=0``). Extra keyword
    arguments are forwarded to ``timm.create_model`` — previously they were
    accepted but silently ignored, unlike the dino_v3 builders.
    """
    return timm.create_model(
        'vit_small_patch14_reg4_dinov2.lvd142m',
        pretrained=True,
        num_classes=0,
        **kwargs,
    )
11
+
12
def build_dinov2_base(**kwargs):
    """Build a pretrained DINOv2 ViT-B/14 (with registers) feature extractor.

    The classifier head is removed (``num_classes=0``). Extra keyword
    arguments are forwarded to ``timm.create_model`` — previously they were
    accepted but silently ignored, unlike the dino_v3 builders.
    """
    return timm.create_model(
        'vit_base_patch14_reg4_dinov2.lvd142m',
        pretrained=True,
        num_classes=0,
        **kwargs,
    )
17
+
18
def build_dinov2_large(**kwargs):
    """Build a pretrained DINOv2 ViT-L/14 feature extractor.

    The classifier head is removed (``num_classes=0``). Extra keyword
    arguments are forwarded to ``timm.create_model`` — previously they were
    accepted but silently ignored, unlike the dino_v3 builders.
    """
    return timm.create_model(
        'vit_large_patch14_dinov2.lvd142m',
        pretrained=True,
        num_classes=0,
        **kwargs,
    )
backbones/dino_v3.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ """
3
+ DinoV3 backbones
4
+ """
5
+
6
def build_dinov3_small(**kwargs):
    """Build a pretrained DINOv3 ViT-S+/16 feature extractor (no head)."""
    return timm.create_model(
        "vit_small_plus_patch16_dinov3.lvd1689m",
        pretrained=True,
        num_classes=0,
        **kwargs,
    )
14
+
15
+
16
def build_dinov3_base(**kwargs):
    """Build a pretrained DINOv3 ViT-B/16 feature extractor (no head)."""
    return timm.create_model(
        "vit_base_patch16_dinov3.lvd1689m",
        pretrained=True,
        num_classes=0,
        **kwargs,
    )
24
+
25
+
26
def build_dinov3_large(**kwargs):
    """Build a pretrained DINOv3 ViT-L/16 feature extractor (no head)."""
    return timm.create_model(
        "vit_large_patch16_dinov3.lvd1689m",
        pretrained=True,
        num_classes=0,
        **kwargs,
    )
dataset/__pycache__/dataloader.cpython-312.pyc ADDED
Binary file (921 Bytes). View file
 
dataset/dataloader.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Tuple
3
+
4
+ from anomalib.data import MVTec, MVTecAD
5
+ from anomalib.data.datasets.image import MVTecDataset
6
+ """Loading logic :
7
+ Download MVTech AD dataset from the website and place it in data folder (Anomalib installation doesnt work, sorry you have to donwload manually)
8
+
9
+ """
10
+
11
+
12
def load_mvtec(category: str, root: str = "./datasets/MVTec"):
    """Return (train_paths, test_paths) lists of image paths for one MVTec AD category."""
    paths_by_split = {}
    for split in ("train", "test"):
        dataset = MVTecDataset(
            root=root,
            category=category,
            split=split,
        )
        paths_by_split[split] = dataset.samples["image_path"].tolist()

    return paths_by_split["train"], paths_by_split["test"]
evaluation/__pycache__/anomaly_evaluator.cpython-312.pyc ADDED
Binary file (5.71 kB). View file
 
evaluation/anomaly_evaluator.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from sklearn.metrics import roc_auc_score, precision_recall_curve, average_precision_score
4
+ from skimage.measure import label, regionprops
5
+
6
class AnomalyEvaluator:
    """Accumulates image- and pixel-level anomaly predictions and computes metrics.

    Pixel scores are randomly subsampled (``pixel_subsample_rate``) to bound
    memory; complete anomaly maps and masks are kept only when
    ``compute_pro=True``, since PRO needs per-region overlap on full masks.
    """

    def __init__(self, pixel_subsample_rate=0.01, compute_pro=False):
        # Fraction of pixels kept per image for pixel AUROC / F1.
        self.subsample_rate = pixel_subsample_rate
        self.compute_pro = compute_pro
        self.reset()

    def reset(self):
        """Clear all accumulated predictions and ground truth."""
        self.img_preds = []
        self.img_labels = []

        self.pix_preds = []
        self.pix_labels = []

        self.full_amaps = []
        self.full_masks = []

    def update(self, image_score, gt_label, anomaly_map=None, gt_mask=None):
        """Record one image's score/label; pixel stats only if both map and mask are given."""
        self.img_preds.append(image_score)
        self.img_labels.append(gt_label)

        if anomaly_map is not None and gt_mask is not None:
            self._update_pixel_metrics(anomaly_map, gt_mask)

    def _update_pixel_metrics(self, amap, mask):
        # Align the GT mask to the anomaly-map resolution (nearest keeps it binary).
        if mask.shape != amap.shape:
            mask = cv2.resize(mask, (amap.shape[1], amap.shape[0]), interpolation=cv2.INTER_NEAREST)

        mask = (mask > 0).astype(int)

        if self.compute_pro:
            self.full_amaps.append(amap)
            self.full_masks.append(mask)

        flat_amap = amap.flatten()
        flat_mask = mask.flatten()

        if self.compute_pro or self.subsample_rate >= 1.0:
            self.pix_preds.extend(flat_amap)
            self.pix_labels.extend(flat_mask)
        else:
            # Random subsampling to save memory. Keep at least one pixel:
            # previously int(n * rate) could round to 0 for small maps,
            # silently dropping the image from pixel-level metrics.
            num_pixels = len(flat_mask)
            sample_size = max(1, int(num_pixels * self.subsample_rate))
            indices = np.random.choice(num_pixels, sample_size, replace=False)
            self.pix_preds.extend(flat_amap[indices])
            self.pix_labels.extend(flat_mask[indices])

    def compute(self):
        """Return a dict of image-level (and, when available, pixel-level) metrics."""
        results = {}
        y_true = np.array(self.img_labels)
        y_score = np.array(self.img_preds)

        results['image_auroc'] = roc_auc_score(y_true, y_score)

        results['image_ap'] = average_precision_score(y_true, y_score)

        prec, rec, _ = precision_recall_curve(y_true, y_score)
        f1_scores = 2 * (prec * rec) / (prec + rec + 1e-8)  # eps guards 0/0
        results['image_f1_max'] = np.max(f1_scores)

        if len(self.pix_labels) > 0:
            pix_true = np.array(self.pix_labels)
            pix_score = np.array(self.pix_preds)

            results['pixel_auroc'] = roc_auc_score(pix_true, pix_score)

            prec_p, rec_p, thresholds_p = precision_recall_curve(pix_true, pix_score)
            f1_p = 2 * (prec_p * rec_p) / (prec_p + rec_p + 1e-8)
            best_idx = np.argmax(f1_p)
            # precision_recall_curve returns one more (prec, rec) point than
            # thresholds; fall back to 0.5 when the best point has none.
            best_threshold = thresholds_p[best_idx] if best_idx < len(thresholds_p) else 0.5

            results['pixel_f1_max'] = np.max(f1_p)

            if self.compute_pro:
                results['pixel_pro'] = self._compute_pro(best_threshold)

        return results

    def _compute_pro(self, threshold):
        """Per-Region Overlap: mean recall over connected GT defect regions."""
        total_pro = 0
        n_defects = 0

        for i in range(len(self.full_amaps)):
            gt = self.full_masks[i]
            # Skip normal images — they contain no defect regions.
            if np.sum(gt) == 0:
                continue

            pred_mask = (self.full_amaps[i] >= threshold).astype(int)

            # Label connected components in the ground truth.
            labeled_gt = label(gt)
            regions = regionprops(labeled_gt)

            for region in regions:
                n_defects += 1
                blob_mask = (labeled_gt == region.label)
                overlap_pixels = np.sum(pred_mask & blob_mask)
                blob_area = region.area

                total_pro += (overlap_pixels / blob_area)

        return total_pro / n_defects if n_defects > 0 else 0.0
main.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+
8
+ from segmenters import PCASegmenter
9
+ from segmenters.sam3 import SAM3Segmenter
10
+ from backbones import get_backbone
11
+ from dataset.dataloader import load_mvtec
12
+ from models.model_bank_knn import PatchKNNDetector
13
+ from evaluation.anomaly_evaluator import AnomalyEvaluator
14
+
15
def main(
    category: str = "bottle",
    root: str | None = None,
    backbone_name: str = "dinov3_small",
    use_sam3: bool = True,
    use_pca: bool = False,
    pca_backbone_name: str | None = None,
    return_results: bool = False,
    backbone_model=None,
    segmenter_obj=None,
    n_ref: int = 1,
):
    """Run PatchKNN anomaly detection on one MVTec AD category and print metrics.

    Args:
        category: MVTec AD category name (also used as the SAM3 text prompt).
        root: dataset root; defaults to dataset/mvtec_anomaly_detection.
        backbone_name: key into the backbone registry (see backbones.get_backbone).
        use_sam3: use the SAM3 text-prompt segmenter (ignored if use_pca is True).
        use_pca: use the PCA segmenter; takes precedence over use_sam3.
        pca_backbone_name: backbone for the PCA segmenter (defaults to backbone_name).
        return_results: if True, return the metrics dict instead of None.
        backbone_model: pre-built backbone to reuse (skips get_backbone).
        segmenter_obj: pre-built segmenter to reuse (skips segmenter selection).
        n_ref: number of reference (train) images placed in the memory bank.

    Returns:
        The metrics dict from AnomalyEvaluator.compute() when return_results
        is True, otherwise None.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1. Setup
    root = root or os.path.join("dataset", "mvtec_anomaly_detection")
    train_paths, test_paths = load_mvtec(category=category, root=root)
    # Keep every 5th image to speed up the demo run.
    test_paths = test_paths[::5]
    train_paths = train_paths[::5]
    print(f"{category}: {len(train_paths)} train, {len(test_paths)} test images")

    # 2. Initialize Evaluator with PRO support
    evaluator = AnomalyEvaluator(pixel_subsample_rate=0.01, compute_pro=False)

    # 3. Model Init — an explicitly passed segmenter wins over the flags.
    segmenter = segmenter_obj
    if segmenter is None:
        if use_pca:
            pca_backbone = pca_backbone_name or backbone_name
            segmenter = PCASegmenter(backbone_name=pca_backbone, device=device)
            print(f"Using PCA segmenter (backbone={pca_backbone})")
        elif use_sam3:
            segmenter = SAM3Segmenter(text_prompt=category, device=device)
            print("Using SAM3 segmenter")
        else:
            print("No segmenter selected; using full-image foreground.")


    backbone = backbone_model or get_backbone(backbone_name)

    model = PatchKNNDetector(
        backbone=backbone,
        segmenter=segmenter,
        device=device,
        k_neighbors=1,
    )

    print(f"Fitting model... (backbone={backbone_name}, sam3={use_sam3})")
    model.fit(train_paths, n_ref=n_ref)

    # 4. Evaluation Loop
    print(f"Starting evaluation on {len(test_paths)} images...")

    for i, path in enumerate(test_paths):
        # Predict
        image, amap, score = model.predict(path)

        # Ground Truth Logic — assumes the MVTec layout where normal test
        # images live under a "good" directory. TODO confirm no category
        # path contains "good" elsewhere.
        is_anomaly = 0 if "good" in path else 1

        if is_anomaly == 0:
            gt_mask = np.zeros_like(amap)
        else:
            # MVTec convention: <root>/<cat>/ground_truth/<defect>/<name>_mask.png
            mask_path = path.replace("test", "ground_truth").replace(".png", "_mask.png")
            if os.path.exists(mask_path):
                gt_mask = cv2.imread(mask_path, 0)
                if gt_mask.shape != amap.shape:
                    gt_mask = cv2.resize(gt_mask, (amap.shape[1], amap.shape[0]), interpolation=cv2.INTER_NEAREST)
                gt_mask = (gt_mask > 0).astype(int)
            else:
                # No mask file found: fall back to an empty mask so the
                # image still counts at image level.
                gt_mask = np.zeros_like(amap)

        # Update Evaluator
        evaluator.update(image_score=score, gt_label=is_anomaly, anomaly_map=amap, gt_mask=gt_mask)

        if i % 20 == 0:
            print(f"Processed {i}/{len(test_paths)}...")

    # 5. Compute & Print Results
    results = evaluator.compute()

    print("\n" + "="*40)
    print(f"FINAL RESULTS: {category}")
    print("-" * 40)

    # Image Level
    print(f"Image AUROC: {results['image_auroc']:.4f}")
    print(f"Image F1-Max: {results['image_f1_max']:.4f}")
    print(f"Image AP: {results['image_ap']:.4f}")
    print("-" * 40)

    # Pixel Level
    if 'pixel_auroc' in results:
        print(f"Pixel AUROC: {results['pixel_auroc']:.4f}")
        print(f"Pixel F1-Max: {results['pixel_f1_max']:.4f}")
        #print(f"PRO Score: {results['pixel_pro']:.4f}")
    print("="*40)

    return results if return_results else None
114
+
115
+ if __name__ == "__main__":
116
+
117
+
118
+ category = os.environ.get("MVTec_CATEGORY", "bottle")
119
+ root = os.environ.get("MVTec_ROOT", None)
120
+ backbone_name = os.environ.get("BACKBONE_NAME", "dinov3_small")
121
+ use_sam3_env = os.environ.get("USE_SAM3", "0").lower()
122
+ use_sam3 = use_sam3_env not in {"0", "false", "no"}
123
+ use_pca_env = os.environ.get("USE_PCA", "1").lower()
124
+ use_pca = use_pca_env in {"1", "true", "yes"}
125
+ pca_backbone_name = os.environ.get("PCA_BACKBONE", None)
126
+
127
+ main(
128
+ category=category,
129
+ root=root,
130
+ backbone_name=backbone_name,
131
+ use_sam3=use_sam3,
132
+ use_pca=use_pca,
133
+ pca_backbone_name=pca_backbone_name,
134
+ n_ref=int(os.environ.get("N_REF", 1)),
135
+ )
models/__pycache__/model_bank_knn.cpython-312.pyc ADDED
Binary file (11.5 kB). View file
 
models/model_bank_knn.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Iterable, Tuple, Optional
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ import torch
9
+ from timm.data import resolve_data_config
10
+ import torchvision.transforms as T
11
+ from torchvision.transforms import InterpolationMode
12
+
13
+ from segmenters import BaseSegmenter
14
+ from utils.visualize import visualize_segmentation
15
+
16
+
17
class PatchKNNDetector:
    """Patch-level k-NN anomaly detector over frozen ViT features.

    A memory bank of L2-normalized foreground patch embeddings is built from
    a few reference (normal) images; at inference each test patch is scored
    by its distance to the nearest bank entries.
    """

    def __init__(self, backbone, segmenter = None, device = "cuda", k_neighbors = 1):
        """
        Args:
            backbone: timm-style ViT exposing ``forward_features``.
            segmenter: optional foreground segmenter; None keeps all patches.
            device: torch device string for the backbone.
            k_neighbors: number of nearest bank entries averaged per patch.
        """
        self.device = device
        self.backbone = backbone.to(device)
        # Switch backbone to inference mode (no dropout / BN updates).
        self.backbone.eval()

        self.segmenter = segmenter
        self.k_neighbors = k_neighbors

        # Prepare resize/normalize augmentations shared by DINO and SAM.
        data_cfg = resolve_data_config({}, model=self.backbone)
        _, self.img_size, _ = data_cfg["input_size"]
        interp = data_cfg.get("interpolation", "bicubic")

        self.transform = T.Compose(
            [
                T.Resize(self.img_size, interpolation=getattr(InterpolationMode, interp.upper(), InterpolationMode.BICUBIC)),
                T.ToTensor(),
                T.Normalize(mean=data_cfg.get("mean", (0.485, 0.456, 0.406)),
                            std=data_cfg.get("std", (0.229, 0.224, 0.225))),
            ]
        )

        # NOTE(review): stored but not read anywhere in this class — confirm
        # whether callers rely on it before removing.
        self.num_register_tokens = getattr(self.backbone, "num_register_tokens", 0)

        # Memory bank of foreground patch embeddings (filled by fit()).
        self.memory_bank = None
        self.patch_grid_size = None


    def fit(self, train_image_paths, n_ref = 1):
        """Populate the memory bank with the first *n_ref* reference images."""

        selected_paths = list(train_image_paths)[:n_ref]
        all_patches = []

        for path in selected_paths:

            # Extracting features
            image = self._load_image(path)
            patch_feats, grid_size = self._extract_patch_features(image)

            # Applying foreground mask
            patch_mask = self._compute_patch_mask(image, grid_size)
            patch_feats_fg = patch_feats[patch_mask]

            all_patches.append(patch_feats_fg)
            self.patch_grid_size = grid_size

        if not all_patches:
            # np.concatenate([]) would raise a cryptic error; make it clear.
            raise ValueError("fit() received no reference images (n_ref <= 0 or empty path list).")

        self.memory_bank = np.concatenate(all_patches, axis=0)

    def predict(self, image_path):
        """Run anomaly detection on one image.

        Returns:
            (image, anomaly_map, image_score): the RGB image, a float32
            per-pixel anomaly map at image resolution, and a scalar score.
        """
        image = self._load_image(image_path)
        patch_feats, grid_size = self._extract_patch_features(image)

        patch_mask = self._compute_patch_mask(image, grid_size)

        # Compute distances only on foreground patches.
        scores_fg = self._knn_distances(patch_feats[patch_mask])

        # Put scores back into the full patch grid (background stays 0).
        scores_all = np.zeros(patch_feats.shape[0], dtype=np.float32)
        scores_all[patch_mask] = scores_fg
        patch_map = scores_all.reshape(grid_size)

        # Upsample to full image resolution (just for visualization).
        h, w = image.shape[:2]
        anomaly_map = cv2.resize(
            patch_map,
            (w, h),
            interpolation=cv2.INTER_CUBIC,
        ).astype(np.float32)

        # Image-level score: mean of the top 1% patch scores.
        image_score = self._mean_top_percent(scores_fg, top_percent=1.0)

        return image, anomaly_map, image_score


    def _load_image(self, path):
        """Load an image from disk as RGB uint8; raise on unreadable paths."""
        bgr = cv2.imread(path, cv2.IMREAD_COLOR)
        if bgr is None:
            # cv2.imread silently returns None for missing/corrupt files,
            # which previously crashed later inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return rgb

    @staticmethod
    def _l2_normalize(feats: np.ndarray, eps: float = 1e-12) -> np.ndarray:
        """Row-wise L2 normalization with an epsilon floor against zero rows."""
        norm = np.linalg.norm(feats, axis=1, keepdims=True)
        return feats / np.maximum(norm, eps)

    def _extract_patch_features(self, image: np.ndarray):
        """Run the backbone on a single image and return (features, grid_size).

        Features are (gh*gw, C) float32, L2-normalized.
        """
        pil_resized, _ = self._resize_for_model(image)
        x = self.transform(pil_resized).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            out = self.backbone.forward_features(x)

        # DINO-style dicts expose patch tokens under "x_norm_patchtokens";
        # other models may return the tensor directly.
        tokens = out.get("x_norm_patchtokens") if isinstance(out, dict) else out

        if tokens is None and isinstance(out, dict):
            tokens = out.get("x")
        if tokens is not None and tokens.ndim == 4:
            # CNN-like (B, C, Hf, Wf) output -> token sequence (B, Hf*Wf, C).
            B, C, Hf, Wf = tokens.shape
            tokens = tokens.permute(0, 2, 3, 1).reshape(B, Hf * Wf, C)

        B, N, C = tokens.shape

        if hasattr(self.backbone, "patch_embed") and hasattr(self.backbone.patch_embed, "grid_size"):
            gh, gw = self.backbone.patch_embed.grid_size
        else:
            # Fall back to an approximately square grid.
            gh = int(np.sqrt(N))
            gw = max(1, N // max(1, gh))

        # Keep only the trailing gh*gw tokens, dropping CLS/register tokens
        # that precede the patch tokens in the sequence.
        n_patches = gh * gw
        patch_tokens = tokens[:, -n_patches:, :]

        # Flatten and normalize.
        feats = (
            patch_tokens.reshape(B * n_patches, C)
            .detach()
            .cpu()
            .numpy()
            .astype("float32")
        )
        feats = self._l2_normalize(feats)

        grid_size = (gh, gw)
        return feats, grid_size

    def _compute_patch_mask(self, image, grid_size):
        """Compute a boolean foreground mask at patch-grid resolution."""
        h_p, w_p = grid_size
        n_patches = h_p * w_p

        # Without a segmenter, every patch is treated as foreground.
        if self.segmenter is None:
            return np.ones(n_patches, dtype=bool)

        # Resize the image the same way as for the backbone before segmenting,
        # so the pixel mask aligns with the patch grid.
        pil_resized, np_resized = self._resize_for_model(image)
        full_mask = self.segmenter.get_object_mask(np_resized)

        # Optionally visualize in resized space (debug aid).
        visualize_segmentation(
            np_resized,
            full_mask,
            grid_size=None,
            title=f"Segmentation debug (resized {self.img_size})",
        )

        full_mask_uint8 = (full_mask.astype(np.uint8) * 255).astype(np.float32)

        # Downsample to the patch grid; INTER_AREA averages pixel coverage.
        mask_small = cv2.resize(
            full_mask_uint8,
            (w_p, h_p),
            interpolation=cv2.INTER_AREA,
        ) / 255.0

        # A patch is foreground when at least half its pixels are.
        patch_mask = (mask_small >= 0.5).reshape(-1)

        # Fallback if the mask collapses to (almost) nothing or everything.
        fg_ratio = patch_mask.mean()
        if fg_ratio < 0.01 or fg_ratio > 0.99:
            patch_mask = np.ones(n_patches, dtype=bool)

        return patch_mask

    def _resize_for_model(self, image):
        """Resize to the backbone's square input size; return (PIL, numpy) views."""
        pil = Image.fromarray(image)
        pil_resized = pil.resize((self.img_size, self.img_size), Image.BICUBIC)
        np_resized = np.array(pil_resized)
        return pil_resized, np_resized

    def _knn_distances(self, feats: np.ndarray) -> np.ndarray:
        """Distance of each query feature to its nearest memory-bank entries.

        Returns the mean of the k smallest L2 distances per query row.
        """
        if self.memory_bank is None:
            raise RuntimeError("Memory bank is empty.")


        a = feats
        b = self.memory_bank

        # Vectorized pairwise L2 distances: ||a-b||^2 = a^2 + b^2 - 2ab.
        a2 = np.sum(a**2, axis=1, keepdims=True)
        b2 = np.sum(b**2, axis=1, keepdims=True).T
        ab = a @ b.T
        # Clip to avoid tiny negative values from floating-point error.
        d2 = np.clip(a2 + b2 - 2.0 * ab, a_min=0.0, a_max=None)
        d = np.sqrt(d2)

        # kNN: take the mean of the k smallest distances per patch.
        k = min(self.k_neighbors, d.shape[1])
        if k == 1:
            min_d = d.min(axis=1)
        else:
            # Partial sort for efficiency — full sort is unnecessary.
            part = np.partition(d, kth=k - 1, axis=1)[:, :k]
            min_d = part.mean(axis=1)

        return min_d.astype(np.float32)

    @staticmethod
    def _mean_top_percent(values: np.ndarray, top_percent: float = 1.0) -> float:
        """Mean of the top p% values, used as the image-level anomaly score."""
        if values.size == 0:
            return 0.0
        k = max(1, int(round(values.size * (top_percent / 100.0))))
        part = np.partition(values, -k)[-k:]
        return float(part.mean())
segmenters/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from segmenters.base_segmenter import BaseSegmenter

# SAM3 and PCA segmenters pull in heavy optional dependencies (transformers,
# timm). Import them defensively so the registry still works when one is
# unavailable — previously the `is not None` guards below were dead code,
# because a failed `from ... import` raises before ever reaching them.
try:
    from segmenters.sam3 import SAM3Segmenter
except ImportError:
    SAM3Segmenter = None

try:
    from segmenters.pca_segmenter import PCASegmenter
except ImportError:
    PCASegmenter = None

"""
Model registry for segmenters
"""
_SEGMENTERS = {}
if SAM3Segmenter is not None:
    _SEGMENTERS["sam3"] = SAM3Segmenter
if PCASegmenter is not None:
    _SEGMENTERS["pca"] = PCASegmenter
16
+
17
+
18
def get_segmenter(name: str, **kwargs):
    """Instantiate the segmenter registered under *name*, forwarding **kwargs."""
    try:
        factory = _SEGMENTERS[name]
    except KeyError:
        raise ValueError(f"Unknown segmenter '{name}'. Available: {list(_SEGMENTERS)}") from None
    return factory(**kwargs)
segmenters/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (799 Bytes). View file
 
segmenters/__pycache__/base_segmenter.cpython-312.pyc ADDED
Binary file (811 Bytes). View file
 
segmenters/__pycache__/pca_segmenter.cpython-312.pyc ADDED
Binary file (6.98 kB). View file
 
segmenters/__pycache__/sam3.cpython-312.pyc ADDED
Binary file (4.61 kB). View file
 
segmenters/base_segmenter.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+
4
class BaseSegmenter:
    """Abstract interface for segmentation models producing foreground masks."""

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """Return a boolean mask where True marks the foreground object.

        Args:
            image: input image array.

        Raises:
            NotImplementedError: subclasses must override this method.
        """
        raise NotImplementedError
+ raise NotImplementedError
14
+
segmenters/pca_segmenter.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ import torch
4
+ import cv2
5
+ from PIL import Image
6
+ import torchvision.transforms as T
7
+ from timm.data import resolve_data_config
8
+
9
+ from backbones import get_backbone
10
+ from segmenters import BaseSegmenter
11
+
12
+
13
class PCASegmenter(BaseSegmenter):
    """Unsupervised foreground segmenter based on ViT patch features.

    Patches are scored by their projection onto the first principal
    component of the (mean-centered) patch embeddings and thresholded
    into a foreground mask, which is then cleaned with morphology and
    resized back to the input resolution.
    """

    def __init__(
        self,
        backbone_name: str = "dinov3_base",
        device: str | None = None,
        threshold: float = 2.5,
        kernel_size: int = 5,
        border: float = 0.2,
    ):
        # threshold: cutoff on the PC1 projection; border: fraction of the
        # patch grid treated as margin for the sign-flip heuristic below.
        super().__init__()
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model = get_backbone(backbone_name).to(self.device)
        self.model.eval()


        cfg = resolve_data_config({}, model=self.model)
        _, img_size, _ = cfg["input_size"]
        # NOTE(review): the fallback lambda returns {} rather than the ""
        # default, so `arch` may be a dict when pretrained_cfg is absent;
        # the isinstance(str) check below guards that case — confirm intent.
        arch = getattr(getattr(self.model, "pretrained_cfg", {}), "get", lambda k, d=None: {})(  # type: ignore[arg-type]
            "architecture", ""
        )
        if isinstance(arch, str) and "dinov3" in arch:
            # Use a larger input for DINOv3 to get a finer patch grid.
            img_size = max(img_size, 512)

        self.img_size = img_size
        interp = cfg.get("interpolation", "bicubic")
        self.transform = T.Compose(
            [
                T.Resize((self.img_size, self.img_size), interpolation=getattr(T.InterpolationMode, interp.upper(), T.InterpolationMode.BICUBIC)),
                T.ToTensor(),
                T.Normalize(mean=cfg.get("mean", (0.485, 0.456, 0.406)), std=cfg.get("std", (0.229, 0.224, 0.225))),
            ]
        )
        self.threshold = threshold
        self.border = border
        # Structuring element for the dilation/closing cleanup.
        self.kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """Return a boolean (H, W) foreground mask for *image* via PC1 scoring."""
        h0, w0 = image.shape[:2]
        pil = Image.fromarray(image.astype(np.uint8))
        x = self.transform(pil).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            out = self.model.forward_features(x)

        # DINO-style dicts expose patch tokens under "x_norm_patchtokens".
        tokens = out.get("x_norm_patchtokens") if isinstance(out, dict) else out

        if tokens is None and isinstance(out, dict):
            tokens = out.get("x")
        if tokens is not None and tokens.ndim == 4:
            # CNN-like (B, C, Hf, Wf) output -> token sequence (B, Hf*Wf, C).
            B, C, Hf, Wf = tokens.shape
            tokens = tokens.permute(0, 2, 3, 1).reshape(B, Hf * Wf, C)

        # Approximate a square grid, then prefer the model's exact grid size
        # when it matches the token count.
        gh_dyn = int(np.sqrt(tokens.shape[1]))
        gw_dyn = max(1, tokens.shape[1] // max(1, gh_dyn))
        gh, gw = gh_dyn, gw_dyn

        if hasattr(self.model, "patch_embed") and hasattr(self.model.patch_embed, "grid_size"):
            gh0, gw0 = self.model.patch_embed.grid_size
            if gh0 * gw0 == tokens.shape[1]:
                gh, gw = gh0, gw0
        # Keep only the trailing patch tokens (drops CLS/register prefixes).
        n_patches = gh * gw
        tokens = tokens[:, -n_patches:, :]

        feats = tokens.squeeze(0).detach().cpu().numpy().astype(np.float32)
        # Mean-center so the SVD yields principal components.
        feats -= feats.mean(0, keepdims=True)
        u, s, vh = np.linalg.svd(feats, full_matrices=False)
        pc1 = vh[0]
        scores = feats @ pc1
        mask = scores > self.threshold
        m_grid = mask.reshape(gh, gw)
        # PC1's sign is arbitrary: if the interior of the grid (excluding a
        # border margin) is mostly background, flip to the negative side.
        bh = int(gh * self.border)
        bw = int(gw * self.border)
        inner = m_grid[bh : gh - bh, bw : gw - bw]
        if inner.size > 0 and inner.mean() <= 0.35:
            mask = scores < -self.threshold
            m_grid = mask.reshape(gh, gw)
        mask = m_grid.astype(np.uint8)
        # Morphological cleanup: grow the mask slightly, then close holes.
        mask = cv2.dilate(mask, self.kernel, iterations=1)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, self.kernel)
        mask = cv2.resize(mask, (w0, h0), interpolation=cv2.INTER_NEAREST)
        return mask.astype(bool)
segmenters/sam2.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import Sam2Processor, Sam2Model
6
+
7
+ from segmenters import BaseSegmenter
8
+
9
+
10
class SAM2Segmenter(BaseSegmenter):
    """
    SAM2 wrapper.

    - Uses Sam2Model (e.g. `facebook/sam2.1-hiera-large`).
    - Segments (approximately) all objects in the image by prompting
      with a full-image bounding box and returns a single boolean mask
      given by the union of all predicted masks.
    """

    def __init__(
        self,
        text_prompt: str | None = None,
        model_name: str = "facebook/sam2.1-hiera-large",
        device: str = "cuda",
        mask_threshold: float = 0.5,
    ) -> None:
        """
        Args:
            text_prompt: kept for compatibility with SAM3Segmenter, but unused.
            model_name: HF repo id for the SAM2 model, e.g. "facebook/sam2.1-hiera-large".
            device: "cuda" or "cpu".
            mask_threshold: pixel threshold for masks (after SAM2 post-processing).
        """
        super().__init__()

        if torch.cuda.is_available() and device.startswith("cuda"):
            self.device = torch.device(device)
        else:
            self.device = torch.device("cpu")

        # BUG FIX: mask_threshold was accepted but never stored, so
        # get_object_mask crashed with AttributeError on self.mask_threshold.
        self.mask_threshold = mask_threshold

        # Load SAM2 model + processor
        self.model = Sam2Model.from_pretrained(model_name).to(self.device)
        self.model.eval()
        self.processor = Sam2Processor.from_pretrained(model_name)

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """
        Run SAM2 and return a single foreground mask.

        - Convert image to PIL.
        - Use a single bounding box covering the whole image as prompt.
        - Run SAM2, post-process masks to image resolution.
        - Threshold and union all masks into one boolean (H, W) array.

        """
        # Ensure PIL image
        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
        else:
            pil_image = image

        W, H = pil_image.size  # PIL: (W, H)

        # Full image bounding box: [x_min, y_min, x_max, y_max]
        input_boxes = [[[0, 0, W, H]]]

        # Build inputs for SAM2
        inputs = self.processor(
            images=pil_image,
            input_boxes=input_boxes,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            # multimask_output=False → one mask per box
            outputs = self.model(**inputs, multimask_output=False)

        # Post-process masks to original resolution
        masks = self.processor.post_process_masks(
            outputs.pred_masks.cpu(),  # (B, num_masks, H', W')
            inputs["original_sizes"],
        )[0]

        # Shapes can be:
        # - (num_masks, H, W)
        # - or (1, num_masks, H, W) depending on version
        if masks.ndim == 4:
            # (B, num_masks, H, W) -> (num_masks, H, W) for B=1
            masks = masks[0]

        if masks.ndim == 2:
            # Single mask: (H, W)
            full_mask = (masks > self.mask_threshold).numpy().astype(bool)
            return full_mask

        if masks.ndim != 3:
            # Failsafe: if something weird happens, keep everything
            return np.ones((H, W), dtype=bool)

        # masks: (num_masks, H, W)
        masks_bin = masks > self.mask_threshold
        combined = masks_bin.any(dim=0)  # (H, W)
        full_mask = combined.numpy().astype(bool)

        return full_mask
segmenters/sam3.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from transformers import Sam3Processor, Sam3Model
7
+
8
+ from segmenters import BaseSegmenter
9
+
10
+
11
class SAM3Segmenter(BaseSegmenter):
    """
    SAM3 wrapper using a text prompt of the object type.

    The prompt is run through SAM3's open-vocabulary instance segmentation
    and the union of all surviving instance masks is returned as the
    foreground.
    """

    def __init__(
        self,
        text_prompt: str,
        model_name: str = "facebook/sam3",
        device: str = "cuda",
        score_threshold: float = 0.5,
        mask_threshold: float = 0.5 ):
        """
        Args:
            text_prompt: description of the object(s) to segment.
            model_name: HF repo id for the SAM3 model.
            device: "cuda" or "cpu".
            score_threshold: min detection score to keep an instance.
            mask_threshold: pixel threshold for masks.
        """
        super().__init__()

        # Fall back to CPU when CUDA is unavailable.
        if torch.cuda.is_available() and device.startswith("cuda"):
            self.device = torch.device(device)
        else:
            self.device = torch.device("cpu")

        # Preprocess the text prompt so e.g. "metal_nut" becomes "metal nut".
        self.text_prompt = text_prompt.replace("_", " ")
        self.score_threshold = score_threshold
        self.mask_threshold = mask_threshold

        # Load the SAM3 model and its processor.
        self.model = Sam3Model.from_pretrained(model_name).to(self.device)
        self.model.eval()
        self.processor = Sam3Processor.from_pretrained(model_name)

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """
        Run SAM3 and return a single boolean (H, W) foreground mask.
        """
        # Accept either a numpy array or an existing PIL image.
        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
        else:
            pil_image = image

        # Preprocess the image together with the text prompt.
        inputs = self.processor(
            images=pil_image,
            text=self.text_prompt,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Post-process instance segmentation back to original resolution.
        target_sizes = inputs.get("original_sizes").tolist()
        results = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=self.score_threshold,
            mask_threshold=self.mask_threshold,
            target_sizes=target_sizes,
        )[0]

        masks = results.get("masks", None)
        scores = results.get("scores", None)

        # If SAM finds nothing at all, keep the whole image as foreground.
        if masks is None or masks.numel() == 0:
            H, W = pil_image.size[1], pil_image.size[0]
            return np.ones((H, W), dtype=bool)

        # Drop low-confidence instances; fall back to full foreground
        # when nothing survives the score threshold.
        if scores is not None:
            keep = scores >= self.score_threshold
            if keep.sum() == 0:
                H, W = pil_image.size[1], pil_image.size[0]
                return np.ones((H, W), dtype=bool)
            masks = masks[keep]

        # Binarize per-instance masks and union them into one foreground mask.
        masks_bin = (masks > self.mask_threshold)
        combined = masks_bin.any(dim=0)
        full_mask = combined.cpu().numpy().astype(bool)

        return full_mask
utils/__pycache__/visualize.cpython-312.pyc ADDED
Binary file (5.12 kB). View file
 
utils/visualize.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+
6
# Plotting-only helper: renders results, returns nothing.
def visualize_prediction(
    image: np.ndarray,
    anomaly_map: np.ndarray,
    image_score: float,
    threshold_percentile: float = 95.0,
    title: str | None = None,
) -> None:
    """Show the input image, an anomaly heatmap overlay, and a
    percentile-thresholded binary overlay side by side.

    Args:
        image: (H, W, 3) RGB image.
        anomaly_map: (H, W) per-pixel anomaly scores.
        image_score: scalar image-level score (used for the default title).
        threshold_percentile: percentile of the normalized map used to
            binarize the third panel.
        title: optional figure title; defaults to the image score.
    """
    # Min-max normalize the anomaly map into [0, 1] for display.
    heat = anomaly_map.astype(np.float32)
    heat -= heat.min()
    peak = heat.max()
    if peak > 0:
        heat /= peak

    cutoff = np.percentile(heat, threshold_percentile)
    hot_pixels = heat >= cutoff

    fig, (ax_img, ax_heat, ax_bin) = plt.subplots(1, 3, figsize=(14, 4))

    ax_img.imshow(image)
    ax_img.set_title("Input image")
    ax_img.axis("off")

    ax_heat.imshow(image)
    overlay = ax_heat.imshow(heat, cmap="jet", alpha=0.5)
    ax_heat.set_title("Anomaly heatmap")
    ax_heat.axis("off")
    fig.colorbar(overlay, ax=ax_heat, fraction=0.046, pad=0.04)

    ax_bin.imshow(image)
    ax_bin.imshow(hot_pixels, cmap="gray", alpha=0.5)
    ax_bin.set_title(f"Thresholded (>{threshold_percentile:.0f}%)")
    ax_bin.axis("off")

    if title is None:
        title = f"Image anomaly score: {image_score:.3f}"
    fig.suptitle(title)

    plt.tight_layout()
    plt.show()
53
def visualize_segmentation(
    image: np.ndarray,
    full_mask: np.ndarray,
    grid_size: tuple[int, int] | None = None,
    title: str | None = None,
) -> None:
    """Show a SAM segmentation mask next to its source image.

    Panels: input image, raw binary mask, mask-over-image overlay, and
    (when ``grid_size`` is given) the mask round-tripped through the patch
    grid so the effective patch-level mask can be inspected.

    Args:
        image: (H, W, 3) RGB uint8 image.
        full_mask: (H, W) bool or 0/1 mask.
        grid_size: optional (H_patches, W_patches) for the fourth panel.
        title: optional figure title.
    """
    fg = full_mask.astype(bool)
    height, width = fg.shape

    # One extra panel when the patch-level view is requested.
    panel_count = 4 if grid_size is not None else 3
    fig, axes = plt.subplots(1, panel_count, figsize=(4 * panel_count, 4))

    # Panel 1: untouched input.
    axes[0].imshow(image)
    axes[0].set_title("Input image")
    axes[0].axis("off")

    # Panel 2: the mask on its own.
    axes[1].imshow(fg, cmap="gray")
    axes[1].set_title("SAM2 mask (full-res)")
    axes[1].axis("off")

    # Panel 3: mask blended over the image.
    axes[2].imshow(image)
    axes[2].imshow(fg, cmap="Reds", alpha=0.4)
    axes[2].set_title("Mask overlay")
    axes[2].axis("off")

    # Panel 4 (optional): downsample to the patch grid, then upsample back,
    # both with nearest-neighbor so the mask stays binary.
    if grid_size is not None:
        gh, gw = grid_size
        coarse = cv2.resize(
            fg.astype(np.uint8), (gw, gh), interpolation=cv2.INTER_NEAREST
        ).astype(bool)
        restored = cv2.resize(
            coarse.astype(np.uint8),
            (width, height),
            interpolation=cv2.INTER_NEAREST,
        ).astype(bool)

        axes[3].imshow(image)
        axes[3].imshow(restored, cmap="Blues", alpha=0.4)
        axes[3].set_title("Patch-level mask (after downsample)")
        axes[3].axis("off")

    if title is not None:
        fig.suptitle(title)

    plt.tight_layout()
    plt.show()