File size: 6,381 Bytes
ed4e653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
PatchCore Anomaly Detection — Gradio demo.

Loads per-category models from results/{category}/patchcore_model.pt and runs
CPU inference. Thresholds are not stored in metrics.json (only AUROC/PRO metrics
are saved), so the verdict threshold defaults to 0.5 — adjust FALLBACK_THRESHOLD
below after inspecting per-category score distributions.
"""

import json
import sys
from pathlib import Path

import numpy as np
import matplotlib.cm as cm
import torch
import torchvision.transforms as T
from PIL import Image


THRESHOLDS_FILE = Path(__file__).parent / "thresholds.json"

# Decision threshold used for any category without a calibrated entry in
# thresholds.json. NOTE(review): 0.5 is a placeholder, not a calibrated value —
# run calibrate_thresholds.py to replace it per category.
FALLBACK_THRESHOLD = 0.5

# Loaded once at import time. JSON is UTF-8 by specification, so read it
# explicitly as UTF-8 instead of relying on the platform's locale encoding.
_thresholds = (
    json.loads(THRESHOLDS_FILE.read_text(encoding="utf-8"))
    if THRESHOLDS_FILE.exists()
    else {}
)


def load_threshold(
    category: str,
    thresholds: "dict | None" = None,
    fallback: float = FALLBACK_THRESHOLD,
) -> float:
    """Return the anomaly-score decision threshold for *category*.

    Args:
        category: Product category name.
        thresholds: Mapping ``{category: {"threshold": value, ...}}``;
            defaults to the contents of thresholds.json loaded at import.
        fallback: Value returned when *category* has no entry.

    Returns:
        The calibrated threshold (calibrated from train/good/ scores by
        calibrate_thresholds.py) as a float, or *fallback* when the category
        is missing from the table.

    Raises:
        KeyError / TypeError: if the category's entry exists but has no
        ``"threshold"`` field.
    """
    table = _thresholds if thresholds is None else thresholds
    if category not in table:
        return float(fallback)  # fallback — run calibrate_thresholds.py to fix this
    return float(table[category]["threshold"])

REPO_ROOT   = Path(__file__).parent  # directory containing this script
RESULTS_DIR = REPO_ROOT / "results"  # holds results/{category}/patchcore_model.pt
DATA_DIR    = REPO_ROOT / "anomaly_ds"  # NOTE(review): not referenced in the code shown here — confirm before removing

# Put the project's src/ directory first on the import path so the local
# `patchcore` module (which provides PatchCore) can be imported below.
sys.path.insert(0, str(REPO_ROOT / "src"))
from patchcore import PatchCore


# ---------------------------------------------------------------------------
# Category discovery
# ---------------------------------------------------------------------------

def discover_categories(results_dir: "Path | None" = None) -> list[str]:
    """Return the sorted names of categories that have a trained model on disk.

    A category is any immediate sub-directory of *results_dir* that contains
    a ``patchcore_model.pt`` file.

    Args:
        results_dir: Directory to scan; defaults to ``RESULTS_DIR``.

    Returns:
        Sorted category names. An empty list is returned when *results_dir*
        is missing or is not a directory, instead of raising
        ``FileNotFoundError`` at import time.
    """
    root = RESULTS_DIR if results_dir is None else Path(results_dir)
    if not root.is_dir():
        return []
    return sorted(
        d.name for d in root.iterdir()
        if d.is_dir() and (d / "patchcore_model.pt").exists()
    )


# Scanned once at import time: models added to results/ after the app starts
# will not show up until it is restarted.
AVAILABLE_CATEGORIES = discover_categories()


# ---------------------------------------------------------------------------
# Model cache (loaded on first use, stays in RAM)
# ---------------------------------------------------------------------------

_model_cache: dict = {}  # category name -> loaded PatchCore instance


def load_model(category: str) -> PatchCore:
    """Return the CPU PatchCore model for *category*, loading it on first use.

    The checkpoint is read from ``RESULTS_DIR/<category>/patchcore_model.pt``.
    A model is cached in ``_model_cache`` only after ``load`` succeeds, so
    later calls for the same category reuse the in-memory instance.
    """
    cached = _model_cache.get(category)
    if cached is not None:
        return cached

    checkpoint = RESULTS_DIR / category / "patchcore_model.pt"
    detector = PatchCore(device="cpu", faiss_gpu=False)
    detector.load(str(checkpoint))

    _model_cache[category] = detector
    return detector


# ---------------------------------------------------------------------------
# Preprocessing — must match IMAGE_TRANSFORM in src/dataset.py exactly:
#   Resize(256) → CenterCrop(224) → ToTensor → Normalize(ImageNet)
# ---------------------------------------------------------------------------

# ImageNet channel statistics used by the backbone's normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

_transform = T.Compose(
    [
        T.Resize(256),       # shorter side -> 256 px, aspect ratio kept
        T.CenterCrop(224),   # central 224x224 window
        T.ToTensor(),
        T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)


# ---------------------------------------------------------------------------
# Heatmap overlay (visualize.py only exposes a private numpy helper;
# this PIL version is self-contained and avoids the Agg backend conflict)
# ---------------------------------------------------------------------------

def overlay_heatmap(
    image: Image.Image,
    anomaly_map: np.ndarray,
    alpha: float = 0.45,
    size: int = 224,
) -> Image.Image:
    """Blend a jet-coloured anomaly map over *image*.

    Args:
        image: Image to draw on. It is converted to RGB and resized to
            ``size x size``. For the overlay to line up, it should already
            show the same field of view as the model input.
        anomaly_map: Per-pixel scores. Singleton dimensions (e.g. a leading
            batch axis) are squeezed away; the rest must be 2-D.
        alpha: Heatmap weight in the blend (0 = image only, 1 = heatmap only).
        size: Side length of the square output image.

    Returns:
        An RGB PIL image of ``size x size`` pixels.

    Raises:
        ValueError: if *anomaly_map* is not 2-D after squeezing.

    The map is min-max normalised per call, so colours show anomaly relative
    to the rest of this image, not the absolute score.
    """
    # Convert before resizing: Pillow forces nearest-neighbour resampling for
    # "1" and "P" mode images, which would give a blocky background.
    base = image.convert("RGB").resize((size, size))

    scores = np.squeeze(np.asarray(anomaly_map, dtype=np.float64))
    if scores.ndim != 2:
        raise ValueError(
            f"anomaly_map must be 2-D after squeezing, got shape {scores.shape}"
        )

    lo, hi = scores.min(), scores.max()
    norm = (scores - lo) / (hi - lo + 1e-8)  # epsilon guards a constant map
    heatmap_rgb = (cm.jet(norm)[:, :, :3] * 255).astype(np.uint8)  # drop alpha
    heatmap = Image.fromarray(heatmap_rgb).resize(base.size)
    return Image.blend(base, heatmap, alpha)


# ---------------------------------------------------------------------------
# Main predict function
# ---------------------------------------------------------------------------

def predict(category: str, image: Image.Image):
    """Score one image with the PatchCore model for *category*.

    Args:
        category: Category selected in the UI; must be one of
            ``AVAILABLE_CATEGORIES``.
        image: Uploaded or webcam image, or ``None`` when the input is empty.

    Returns:
        ``(score, verdict, heatmap)`` for the three output components:
        - score: the image-level anomaly score;
        - verdict: NORMAL/ANOMALY from comparing the score with the category
          threshold;
        - heatmap: the anomaly-map overlay.
        When the inputs are unusable, score and heatmap are ``None`` and the
        verdict explains why.
    """
    if image is None:
        return None, "No image provided.", None
    if not category:
        return None, "No category selected.", None
    # The category string is used to build a model file path in load_model();
    # only accept names that were discovered on disk at startup.
    if category not in AVAILABLE_CATEGORIES:
        return None, f"Unknown category: {category!r}", None

    model = load_model(category)
    threshold = load_threshold(category)

    rgb = image.convert("RGB")
    tensor = _transform(rgb).unsqueeze(0)  # [1, 3, 224, 224]

    with torch.no_grad():
        # predict() returns (image_score: float, anomaly_map: np.ndarray [224,224])
        image_score, anomaly_map = model.predict(tensor)
    score = float(image_score)

    verdict = "✅  NORMAL" if score < threshold else "❌  ANOMALY"

    # The anomaly map covers the Resize(256) -> CenterCrop(224) view the model
    # was given, not the full uploaded frame. Apply the same geometric steps
    # (and only those) to the display image so hotspots line up; squashing the
    # whole frame to 224x224 would shift and stretch them.
    geometry = T.Compose(
        [t for t in _transform.transforms if isinstance(t, (T.Resize, T.CenterCrop))]
    )
    heatmap = overlay_heatmap(geometry(rgb), anomaly_map)

    return score, verdict, heatmap


# ---------------------------------------------------------------------------
# Example images (first 4 categories, one normal + one defective each)
# ---------------------------------------------------------------------------

def _example_category(stem: str, categories) -> "str | None":
    """Return the longest category that *stem* names (``<cat>`` or ``<cat>_<label>``), else None."""
    matches = [c for c in categories if stem == c or stem.startswith(c + "_")]
    return max(matches, key=len, default=None)


def build_examples(examples_dir=None, categories=None) -> list:
    """Collect ``[category, image_path]`` rows for the ``gr.Examples`` gallery.

    Example files are named ``<category>_<label>.png`` (e.g.
    ``bottle_good.png``). Each file is matched against the known category
    names by prefix rather than by splitting on the last underscore. That way
    labels containing underscores (``metal_nut_scratch_large.png``) still map
    to ``metal_nut``.

    Args:
        examples_dir: Directory to scan; defaults to ``REPO_ROOT / "examples"``.
        categories: Known category names; defaults to ``AVAILABLE_CATEGORIES``.

    Returns:
        Rows sorted by file path. Files whose name matches no known category
        are skipped. Returns an empty list when the directory does not exist.
    """
    root = REPO_ROOT / "examples" if examples_dir is None else Path(examples_dir)
    known = AVAILABLE_CATEGORIES if categories is None else categories
    if not root.is_dir():
        return []

    examples = []
    for img_path in sorted(root.glob("*.png")):
        cat = _example_category(img_path.stem, known)
        if cat is not None:
            examples.append([cat, str(img_path)])
    return examples


# Built once at import time from the examples/ directory; rows are
# [category, image_path] pairs matching the gr.Examples inputs below.
EXAMPLES = build_examples()


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

import gradio as gr

with gr.Blocks(title="PatchCore — Industrial Anomaly Detection") as demo:
    gr.Markdown("## PatchCore — Industrial Anomaly Detection")
    gr.Markdown(
        "Select a product category, upload an image (or use your camera), "
        "and see the anomaly score and a pixel-level heatmap.\n\n"
        f"**Available categories ({len(AVAILABLE_CATEGORIES)}):** "
        + ", ".join(AVAILABLE_CATEGORIES)
    )
    if not AVAILABLE_CATEGORIES:
        # Tell the user why the dropdown is empty instead of crashing while
        # the page is built (AVAILABLE_CATEGORIES[0] would raise IndexError).
        gr.Markdown(
            "⚠️ No trained models found — expected "
            "`results/<category>/patchcore_model.pt`."
        )

    with gr.Row():
        with gr.Column(scale=1):
            category_dd = gr.Dropdown(
                choices=AVAILABLE_CATEGORIES,
                # Preselect the first category when there is one.
                value=AVAILABLE_CATEGORIES[0] if AVAILABLE_CATEGORIES else None,
                label="Product category",
            )
            image_in = gr.Image(
                type="pil",
                label="Input image",
                sources=["upload", "webcam"],
            )
            run_btn = gr.Button("Detect Anomalies", variant="primary")

        with gr.Column(scale=1):
            score_out   = gr.Number(label="Anomaly score")
            verdict_out = gr.Textbox(label="Verdict")
            heatmap_out = gr.Image(type="pil", label="Anomaly heatmap overlay")

    # Only render the gallery when build_examples() found something to show.
    if EXAMPLES:
        gr.Examples(
            examples=EXAMPLES,
            inputs=[category_dd, image_in],
            label="Try an example",
        )

    run_btn.click(
        fn=predict,
        inputs=[category_dd, image_in],
        outputs=[score_out, verdict_out, heatmap_out],
    )

if __name__ == "__main__":
    # NOTE(review): `theme=` is accepted by launch() only in recent Gradio
    # releases; older versions take it on gr.Blocks(...) instead — confirm
    # against the pinned Gradio version.
    demo.launch(theme=gr.themes.Soft())