File size: 4,616 Bytes
1834bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

from __future__ import annotations
import os
import cv2
import numpy as np
import torch

from segmenters import PCASegmenter
from segmenters.sam3 import SAM3Segmenter
from backbones import get_backbone
from dataset.dataloader import load_mvtec
from models.model_bank_knn import PatchKNNDetector
from evaluation.anomaly_evaluator import AnomalyEvaluator 

def _is_good_sample(path: str) -> bool:
    """Return True when *path* lies directly inside a directory named ``good``.

    Only the immediate parent directory is inspected, so an unrelated
    occurrence of "good" elsewhere in the path (user name, dataset root, ...)
    cannot flip the label.
    """
    return os.path.basename(os.path.dirname(path)) == "good"


def _ground_truth_mask_path(path: str) -> str:
    """Map ``<cat>/<split>/<defect>/<name>.<ext>`` to its ground-truth mask path.

    The result is ``<cat>/ground_truth/<defect>/<name>_mask.png``. The path is
    rebuilt from its components rather than via string replacement, so a
    "test" or ".png" substring elsewhere in the path is left untouched.
    """
    defect_dir, filename = os.path.split(path)
    split_dir, defect_type = os.path.split(defect_dir)
    category_dir = os.path.dirname(split_dir)
    stem, _ext = os.path.splitext(filename)
    return os.path.join(category_dir, "ground_truth", defect_type, f"{stem}_mask.png")


def _load_gt_mask(path: str, amap):
    """Load the binary ground-truth mask for an anomalous test image.

    The mask is resized (nearest neighbour) to ``amap.shape`` when the shapes
    differ and binarised to {0, 1}. If the mask file is missing or cannot be
    decoded, a warning is printed and an all-zero mask shaped like *amap* is
    returned (best-effort fallback).

    NOTE(review): assumes *amap* is a 2-D numpy array (H, W) -- confirm
    against ``PatchKNNDetector.predict``.
    """
    mask_path = _ground_truth_mask_path(path)
    # cv2.imread returns None (it does not raise) when decoding fails.
    raw = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) if os.path.exists(mask_path) else None
    if raw is None:
        print(f"Warning: no readable ground-truth mask at {mask_path}; using empty mask.")
        return np.zeros_like(amap)
    if raw.shape != amap.shape:
        # cv2.resize expects (width, height); nearest keeps the mask binary.
        raw = cv2.resize(raw, (amap.shape[1], amap.shape[0]), interpolation=cv2.INTER_NEAREST)
    return (raw > 0).astype(int)


def main(
    category: str = "bottle",
    root: str | None = None,
    backbone_name: str = "dinov3_small",
    use_sam3: bool = True,
    use_pca: bool = False,
    pca_backbone_name: str | None = None,
    return_results: bool = False,
    backbone_model=None,
    segmenter_obj=None,
    n_ref: int = 1,
):
    """Fit a ``PatchKNNDetector`` on one MVTec category and print its metrics.

    Args:
        category: Category name passed to ``load_mvtec``; also used as the
            SAM3 text prompt.
        root: Dataset root. Defaults to ``dataset/mvtec_anomaly_detection``.
        backbone_name: Name passed to ``get_backbone`` when ``backbone_model``
            is not supplied.
        use_sam3: Use ``SAM3Segmenter`` when no segmenter object is supplied
            and ``use_pca`` is false.
        use_pca: Use ``PCASegmenter``; takes precedence over ``use_sam3``.
        pca_backbone_name: Backbone for the PCA segmenter; falls back to
            ``backbone_name``.
        return_results: Return the metrics dict instead of ``None``.
        backbone_model: Pre-built backbone to reuse instead of building one.
        segmenter_obj: Pre-built segmenter; when given, ``use_pca`` and
            ``use_sam3`` are ignored.
        n_ref: Forwarded unchanged to ``model.fit``.

    Returns:
        The dict produced by ``AnomalyEvaluator.compute()`` when
        ``return_results`` is true, otherwise ``None``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1. Setup
    root = root or os.path.join("dataset", "mvtec_anomaly_detection")
    train_paths, test_paths = load_mvtec(category=category, root=root)
    # Only every 5th image of each split is used.
    test_paths = test_paths[::5]
    train_paths = train_paths[::5]
    print(f"{category}: {len(train_paths)} train, {len(test_paths)} test images")

    # 2. Initialize evaluator (PRO computation is switched off here)
    evaluator = AnomalyEvaluator(pixel_subsample_rate=0.01, compute_pro=False)

    # 3. Model init
    segmenter = segmenter_obj
    if segmenter is None:
        if use_pca:
            pca_backbone = pca_backbone_name or backbone_name
            segmenter = PCASegmenter(backbone_name=pca_backbone, device=device)
            print(f"Using PCA segmenter (backbone={pca_backbone})")
        elif use_sam3:
            segmenter = SAM3Segmenter(text_prompt=category, device=device)
            print("Using SAM3 segmenter")
        else:
            print("No segmenter selected; using full-image foreground.")

    backbone = backbone_model or get_backbone(backbone_name)

    model = PatchKNNDetector(
        backbone=backbone,
        segmenter=segmenter,
        device=device,
        k_neighbors=1,
    )

    print(f"Fitting model... (backbone={backbone_name}, sam3={use_sam3})")
    model.fit(train_paths, n_ref=n_ref)

    # 4. Evaluation loop
    print(f"Starting evaluation on {len(test_paths)} images...")

    for i, path in enumerate(test_paths):
        image, amap, score = model.predict(path)

        # Label comes from the defect-type directory, not a substring match.
        is_anomaly = 0 if _is_good_sample(path) else 1
        if is_anomaly:
            gt_mask = _load_gt_mask(path, amap)
        else:
            gt_mask = np.zeros_like(amap)

        evaluator.update(image_score=score, gt_label=is_anomaly, anomaly_map=amap, gt_mask=gt_mask)

        if i % 20 == 0:
            print(f"Processed {i}/{len(test_paths)}...")

    # 5. Compute & print results
    results = evaluator.compute()

    print("\n" + "="*40)
    print(f"FINAL RESULTS: {category}")
    print("-" * 40)

    # Image level
    print(f"Image AUROC:   {results['image_auroc']:.4f}")
    print(f"Image F1-Max:  {results['image_f1_max']:.4f}")
    print(f"Image AP:      {results['image_ap']:.4f}")
    print("-" * 40)

    # Pixel level (only reported when the evaluator produced pixel metrics;
    # the PRO score is not printed because compute_pro is off above).
    if 'pixel_auroc' in results:
        print(f"Pixel AUROC:   {results['pixel_auroc']:.4f}")
        print(f"Pixel F1-Max:  {results['pixel_f1_max']:.4f}")
    print("="*40)

    return results if return_results else None

if __name__ == "__main__":
    # Script entry point: every option is taken from environment variables.
    env = os.environ

    # The two switches are parsed differently on purpose-preserving grounds:
    #   USE_SAM3 is ON unless it is one of "0"/"false"/"no" (default "0" -> off)
    #   USE_PCA  is ON only for "1"/"true"/"yes"            (default "1" -> on)
    sam3_flag = env.get("USE_SAM3", "0").lower()
    pca_flag = env.get("USE_PCA", "1").lower()

    main(
        category=env.get("MVTec_CATEGORY", "bottle"),
        root=env.get("MVTec_ROOT"),
        backbone_name=env.get("BACKBONE_NAME", "dinov3_small"),
        use_sam3=sam3_flag not in ("0", "false", "no"),
        use_pca=pca_flag in ("1", "true", "yes"),
        pca_backbone_name=env.get("PCA_BACKBONE"),
        n_ref=int(env.get("N_REF", "1")),
    )