"""
PatchCore Anomaly Detection — Gradio demo.
Loads per-category models from results/{category}/patchcore_model.pt and runs
CPU inference. Thresholds are not stored in metrics.json (only AUROC/PRO metrics
are saved); per-category verdict thresholds are read from thresholds.json next to
this file, and any category without an entry falls back to 0.5 (see
load_threshold) — run calibrate_thresholds.py to generate calibrated values.
"""
import json
import sys
from pathlib import Path
import numpy as np
import matplotlib.cm as cm
import torch
import torchvision.transforms as T
from PIL import Image
# Optional JSON file living next to this script. Expected layout (as read by
# load_threshold): {"<category>": {"threshold": <number>, ...}, ...}
THRESHOLDS_FILE = Path(__file__).parent / "thresholds.json"

# Verdict threshold used for any category that has no calibrated entry.
# Run calibrate_thresholds.py to produce per-category values instead.
FALLBACK_THRESHOLD = 0.5

# Parsed once at import time; stays empty when the file does not exist.
_thresholds = (
    json.loads(THRESHOLDS_FILE.read_text(encoding="utf-8"))
    if THRESHOLDS_FILE.exists()
    else {}
)


def load_threshold(category: str) -> float:
    """Return the anomaly-score threshold for *category*.

    Args:
        category: Category name used as the key in thresholds.json.

    Returns:
        The calibrated threshold stored under ``category -> "threshold"``,
        converted to ``float``; ``FALLBACK_THRESHOLD`` when the category has
        no entry.

    Raises:
        KeyError: If the category has an entry without a ``"threshold"`` key.
    """
    if category not in _thresholds:
        return FALLBACK_THRESHOLD
    # JSON may encode the value as an int; normalise to the annotated type.
    return float(_thresholds[category]["threshold"])
# Directory containing this file; every project path below is built from it.
REPO_ROOT = Path(__file__).parent
# Parent of the per-category folders that are searched for patchcore_model.pt.
RESULTS_DIR = REPO_ROOT / "results"
# Presumably the dataset root. NOTE(review): not referenced anywhere else in
# the visible part of this file — confirm whether it is still needed.
DATA_DIR = REPO_ROOT / "anomaly_ds"
# Put src/ at the front of the import path so the project-local `patchcore`
# module resolves. This must run before the import on the next line.
sys.path.insert(0, str(REPO_ROOT / "src"))
from patchcore import PatchCore
# ---------------------------------------------------------------------------
# Category discovery
# ---------------------------------------------------------------------------
def discover_categories() -> list[str]:
    """Return the sorted names of all categories that have a saved model.

    A category is any direct sub-directory of ``RESULTS_DIR`` that contains a
    ``patchcore_model.pt`` file.

    Returns:
        Alphabetically sorted category names. The list is empty when
        ``RESULTS_DIR`` does not exist (or is not a directory) or when no
        sub-directory holds a model file.
    """
    # iterdir() raises FileNotFoundError on a missing directory, which would
    # abort the import of this module before the UI can even be built.
    if not RESULTS_DIR.is_dir():
        return []
    return sorted(
        entry.name
        for entry in RESULTS_DIR.iterdir()
        if entry.is_dir() and (entry / "patchcore_model.pt").exists()
    )
# Evaluated once at import time; the dropdown choices, the category summary
# text and the example filtering in build_examples() all read this list.
AVAILABLE_CATEGORIES = discover_categories()
# ---------------------------------------------------------------------------
# Model cache (loaded on first use, stays in RAM)
# ---------------------------------------------------------------------------
# category name -> loaded PatchCore instance; filled lazily by load_model().
_model_cache: dict = {}


def load_model(category: str) -> PatchCore:
    """Return the PatchCore model for *category*, loading it on first request.

    The model is constructed for CPU with the FAISS GPU path disabled, restored
    from ``RESULTS_DIR/<category>/patchcore_model.pt`` and then kept in
    ``_model_cache`` so later calls for the same category reuse the instance.

    Args:
        category: Name of the results sub-directory holding the model file.

    Returns:
        The cached or freshly loaded ``PatchCore`` instance.
    """
    try:
        return _model_cache[category]
    except KeyError:
        pass  # first request for this category — fall through and load it

    checkpoint = RESULTS_DIR / category / "patchcore_model.pt"
    patchcore = PatchCore(device="cpu", faiss_gpu=False)
    patchcore.load(str(checkpoint))
    _model_cache[category] = patchcore
    return patchcore
# ---------------------------------------------------------------------------
# Preprocessing — must match IMAGE_TRANSFORM in src/dataset.py exactly:
# Resize(256) → CenterCrop(224) → ToTensor → Normalize(ImageNet)
# ---------------------------------------------------------------------------
# Preprocessing applied to every uploaded image in predict() before it is
# handed to the model. Keep the steps and constants in sync with the training
# pipeline named in the section comment above.
_transform = T.Compose([
    # Geometric part: resize with a target of 256, then take the central 224x224 crop.
    T.Resize(256),
    T.CenterCrop(224),
    # PIL image -> tensor.
    T.ToTensor(),
    # Per-channel normalisation; the section comment identifies these as the
    # ImageNet statistics.
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])
# ---------------------------------------------------------------------------
# Heatmap overlay (visualize.py only exposes a private numpy helper;
# this PIL version is self-contained and avoids the Agg backend conflict)
# ---------------------------------------------------------------------------
def overlay_heatmap(image: Image.Image, anomaly_map: np.ndarray, alpha: float = 0.45) -> Image.Image:
    """Blend a jet-coloured anomaly map over the part of *image* the model saw.

    The model input is produced by Resize(256) -> CenterCrop(224), so the
    anomaly map describes only that central crop. The background is therefore
    prepared with the same two geometric steps; squashing the whole image to
    224x224 instead would draw the heat at the wrong location, most visibly on
    non-square inputs.

    Args:
        image: Original input image (any mode; converted to RGB here).
        anomaly_map: 2-D array of per-pixel anomaly scores. The caller
            documents it as 224x224; other sizes are resized to the crop.
        alpha: Blend weight of the heatmap (0 = image only, 1 = heatmap only).

    Returns:
        A 224x224 RGB image with the heatmap blended over the cropped input.
    """
    # Same geometry as the first two steps of the preprocessing pipeline, so
    # background pixels and map pixels refer to the same image location.
    image_rgb = T.CenterCrop(224)(T.Resize(256)(image.convert("RGB")))

    # Min-max normalise to [0, 1]; the epsilon keeps a constant map from
    # dividing by zero.
    lowest = anomaly_map.min()
    highest = anomaly_map.max()
    norm = (anomaly_map - lowest) / (highest - lowest + 1e-8)

    heatmap_rgba = cm.jet(norm)
    # Drop the alpha channel and convert the 0..1 floats to 8-bit RGB.
    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
    heatmap_pil = Image.fromarray(heatmap_rgb).resize(image_rgb.size)
    return Image.blend(image_rgb, heatmap_pil, alpha)
# ---------------------------------------------------------------------------
# Main predict function
# ---------------------------------------------------------------------------
def predict(category: str, image: Image.Image):
    """Score one image for *category* and build its heatmap overlay.

    Args:
        category: Category whose model and threshold should be used.
        image: Uploaded PIL image, or ``None`` when nothing was provided.

    Returns:
        A ``(score, verdict, heatmap)`` tuple: the image-level anomaly score as
        a ``float``, a verdict string, and the overlay image. When *image* is
        ``None`` the tuple is ``(None, "No image provided.", None)``.
    """
    if image is None:
        return None, "No image provided.", None

    detector = load_model(category)
    cutoff = load_threshold(category)

    # Add a leading batch dimension: the model is fed one image at a time.
    batch = _transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        # Per the original note, predict() yields the image-level score and a
        # 224x224 anomaly map as a numpy array.
        score, pixel_map = detector.predict(batch)

    # Scores strictly below the threshold count as normal.
    if score < cutoff:
        verdict = "✅ NORMAL"
    else:
        verdict = "❌ ANOMALY"

    overlay = overlay_heatmap(image, pixel_map)
    return float(score), verdict, overlay
# ---------------------------------------------------------------------------
# Example images (first 4 categories, one normal + one defective each)
# ---------------------------------------------------------------------------
def build_examples() -> list:
    """Collect ``[category, image_path]`` rows for the example gallery.

    Every ``*.png`` directly inside ``REPO_ROOT/examples`` is considered, in
    sorted path order. The category is the file stem with its last
    ``_``-separated part removed ("bottle_good" -> "bottle"); files whose
    category is not in ``AVAILABLE_CATEGORIES`` are skipped.

    Returns:
        A list of two-item lists, or an empty list when the examples directory
        does not exist.
    """
    examples_dir = REPO_ROOT / "examples"
    if not examples_dir.exists():
        return []

    # Pair each image with the category derived from its file name.
    labelled = (
        (png.stem.rsplit("_", 1)[0], png)
        for png in sorted(examples_dir.glob("*.png"))
    )
    return [
        [label, str(png)]
        for label, png in labelled
        if label in AVAILABLE_CATEGORIES
    ]
# Built once at import time and handed to gr.Examples in the UI section below.
EXAMPLES = build_examples()
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
import gradio as gr
# The page must still build when no trained model was discovered: indexing
# AVAILABLE_CATEGORIES[0] on an empty list would raise IndexError at import.
has_models = bool(AVAILABLE_CATEGORIES)

with gr.Blocks(title="PatchCore — Industrial Anomaly Detection") as demo:
    gr.Markdown("## PatchCore — Industrial Anomaly Detection")
    gr.Markdown(
        "Select a product category, upload an image (or use your camera), "
        "and see the anomaly score and a pixel-level heatmap.\n\n"
        f"**Available categories ({len(AVAILABLE_CATEGORIES)}):** "
        + (", ".join(AVAILABLE_CATEGORIES) or "none found")
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            category_dd = gr.Dropdown(
                choices=AVAILABLE_CATEGORIES,
                value=AVAILABLE_CATEGORIES[0] if has_models else None,
                label="Product category",
            )
            image_in = gr.Image(
                type="pil",
                label="Input image",
                sources=["upload", "webcam"],
            )
            # Without a model there is nothing predict() could run, so the
            # button is rendered disabled instead of failing on click.
            run_btn = gr.Button(
                "Detect Anomalies",
                variant="primary",
                interactive=has_models,
            )

        # Right column: the three values returned by predict().
        with gr.Column(scale=1):
            score_out = gr.Number(label="Anomaly score")
            verdict_out = gr.Textbox(label="Verdict")
            heatmap_out = gr.Image(type="pil", label="Anomaly heatmap overlay")

    # Each example row is [category, image path], matching the two inputs.
    gr.Examples(
        examples=EXAMPLES,
        inputs=[category_dd, image_in],
        label="Try an example",
    )

    run_btn.click(
        fn=predict,
        inputs=[category_dd, image_in],
        outputs=[score_out, verdict_out, heatmap_out],
    )
if __name__ == "__main__":
    # Start the app only when this file is executed as a script, so importing
    # the module just builds `demo` without launching anything.
    soft_theme = gr.themes.Soft()
    demo.launch(theme=soft_theme)
|