| """ |
PatchCore Anomaly Detection — Gradio demo.

Loads per-category models from results/{category}/patchcore_model.pt and runs
CPU inference. metrics.json only stores AUROC/PRO metrics, so decision
thresholds are read from thresholds.json next to this file (one
{"threshold": <float>} entry per category). Categories without an entry fall
back to a threshold of 0.5 — calibrate them after inspecting the per-category
score distributions.
| """ |
|
|
| import json |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import matplotlib.cm as cm |
| import torch |
| import torchvision.transforms as T |
| from PIL import Image |
|
|
|
|
# Optional per-category decision thresholds calibrated offline and stored next
# to this script, e.g. {"bottle": {"threshold": 0.73}, ...}.
THRESHOLDS_FILE = Path(__file__).parent / "thresholds.json"
_thresholds = (
    json.loads(THRESHOLDS_FILE.read_text(encoding="utf-8"))
    if THRESHOLDS_FILE.exists()
    else {}
)

# Threshold used for any category without a calibrated entry in thresholds.json.
FALLBACK_THRESHOLD = 0.5


def load_threshold(category: str) -> float:
    """Return the anomaly-score threshold for *category*.

    The value comes from thresholds.json. Entries may be either
    ``{"threshold": <number>}`` objects or bare numbers. A category with no
    entry, or whose entry has no usable ``"threshold"`` value, gets
    ``FALLBACK_THRESHOLD``.

    Args:
        category: Product category name (a key of thresholds.json).

    Returns:
        The threshold as a Python float.
    """
    entry = _thresholds.get(category)
    if isinstance(entry, dict):
        # Tolerate entries that only carry metrics and no "threshold" key.
        entry = entry.get("threshold")
    return FALLBACK_THRESHOLD if entry is None else float(entry)
|
|
# Repository layout — every path is resolved relative to this script.
REPO_ROOT = Path(__file__).parent
RESULTS_DIR = REPO_ROOT.joinpath("results")
DATA_DIR = REPO_ROOT.joinpath("anomaly_ds")

# The `patchcore` module lives in <repo>/src; prepend that directory to the
# import path so the import below resolves to it first.
sys.path[:0] = [str(REPO_ROOT.joinpath("src"))]
from patchcore import PatchCore  # noqa: E402  (needs the sys.path tweak above)
|
|
|
|
| |
| |
| |
|
|
def discover_categories(results_dir: "Path | None" = None) -> list[str]:
    """Return the sorted names of categories that have a trained model.

    A category is any immediate sub-directory of *results_dir* that contains a
    ``patchcore_model.pt`` file.

    Args:
        results_dir: Directory to scan; defaults to ``RESULTS_DIR``.

    Returns:
        Sorted category names. The list is empty when the directory does not
        exist, rather than ``FileNotFoundError`` being raised.
    """
    root = RESULTS_DIR if results_dir is None else Path(results_dir)
    if not root.is_dir():
        return []
    return sorted(
        entry.name
        for entry in root.iterdir()
        if entry.is_dir() and (entry / "patchcore_model.pt").exists()
    )
|
|
|
|
# Categories with a trained model on disk, discovered once at import time.
AVAILABLE_CATEGORIES = discover_categories()

# Lazily filled cache: category name -> loaded PatchCore instance.
_model_cache: dict = dict()


def load_model(category: str) -> PatchCore:
    """Return the PatchCore model for *category*, loading it on first use.

    Weights are read from results/<category>/patchcore_model.pt into a
    CPU-only PatchCore (``faiss_gpu=False``). The instance is then stored in
    ``_model_cache``, so later calls for the same category reuse it.
    """
    model = _model_cache.get(category)
    if model is None:
        weights_path = RESULTS_DIR / category / "patchcore_model.pt"
        model = PatchCore(device="cpu", faiss_gpu=False)
        model.load(str(weights_path))
        _model_cache[category] = model
    return model
|
|
|
|
| |
| |
| |
| |
|
|
# Geometry shared by the model input and the heatmap overlay: scale the short
# side to 256 px, then take the central 224x224 crop. Both pipelines use the
# same constants so the anomaly map lines up with the image shown to the user.
_RESIZE_SIZE = 256
_CROP_SIZE = 224

# Model input pipeline: the geometry above, then ImageNet mean/std normalisation.
_transform = T.Compose([
    T.Resize(_RESIZE_SIZE),
    T.CenterCrop(_CROP_SIZE),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])

# Display pipeline: the same resize + crop (PIL in, PIL out), so the overlay
# shows exactly the region the model scored.
_display_geometry = T.Compose([
    T.Resize(_RESIZE_SIZE),
    T.CenterCrop(_CROP_SIZE),
])


def overlay_heatmap(image: Image.Image, anomaly_map: np.ndarray, alpha: float = 0.45) -> Image.Image:
    """Blend a jet-coloured anomaly map over the region of *image* the model saw.

    Args:
        image: Original input image, in any PIL mode (it is converted to RGB).
        anomaly_map: Anomaly scores for the cropped model input. Singleton
            dimensions (e.g. a leading batch axis) are squeezed away.
            Array-likes accepted by ``np.asarray`` are also allowed.
        alpha: Heatmap opacity passed to ``Image.blend`` (0 = image only).

    Returns:
        A 224x224 RGB image.

    Raises:
        ValueError: If the squeezed anomaly map is not two-dimensional.
    """
    # Crop the display image exactly like the model input. Squashing the full
    # frame to 224x224 would misalign the heatmap with the content it scores.
    base = _display_geometry(image.convert("RGB"))

    scores = np.squeeze(np.asarray(anomaly_map, dtype=np.float64))
    if scores.ndim != 2:
        raise ValueError(
            f"anomaly_map must be 2-D after squeezing, got shape {scores.shape}"
        )

    # Min-max normalise to [0, 1] for the colormap; the epsilon avoids 0/0 on a flat map.
    norm = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
    heat_rgb = (cm.jet(norm)[:, :, :3] * 255).astype(np.uint8)
    heat = Image.fromarray(heat_rgb).resize(base.size)
    return Image.blend(base, heat, alpha)
|
|
|
|
| |
| |
| |
|
|
def predict(category: str, image: "Image.Image"):
    """Score *image* with the *category* model and build the three UI outputs.

    Args:
        category: Product category selected in the dropdown.
        image: Uploaded or webcam PIL image, or None when nothing was provided.

    Returns:
        ``(score, verdict, heatmap)``:
            - the image-level anomaly score as a Python float;
            - ``"✅ NORMAL"`` when score < threshold, otherwise ``"❌ ANOMALY"``;
            - the heatmap overlay image.
        When the image or the category is missing, returns
        ``(None, <message>, None)`` instead of raising.
    """
    if image is None:
        return None, "No image provided.", None
    if not category:
        # A missing/empty category (e.g. a cleared dropdown or a direct API
        # call) would otherwise fail deep inside load_model.
        return None, "No category selected.", None

    model = load_model(category)
    threshold = load_threshold(category)

    # Preprocess and add a leading batch dimension.
    tensor = _transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        raw_score, anomaly_map = model.predict(tensor)

    # Convert once so the threshold comparison and the returned value agree,
    # even if model.predict returns a one-element tensor/array.
    image_score = float(raw_score)
    verdict = "✅ NORMAL" if image_score < threshold else "❌ ANOMALY"
    heatmap = overlay_heatmap(image, anomaly_map)
    return image_score, verdict, heatmap
|
|
|
|
| |
| |
| |
|
|
# Image file extensions collected from the examples directory (case-insensitive).
_EXAMPLE_SUFFIXES = (".png", ".jpg", ".jpeg")


def build_examples(
    examples_dir: "Path | None" = None,
    categories: "list[str] | None" = None,
) -> list:
    """Collect ``[category, image_path]`` pairs for the Gradio examples gallery.

    Files are expected to be named ``<category>_<suffix>.<ext>``. The category
    is everything before the last underscore, so ``metal_nut_001.png`` maps to
    ``metal_nut``. Only files whose category is in *categories* are kept.

    Args:
        examples_dir: Directory to scan; defaults to ``REPO_ROOT / "examples"``.
        categories: Allowed category names; defaults to ``AVAILABLE_CATEGORIES``.

    Returns:
        ``[category, str(path)]`` pairs ordered by path. The list is empty when
        the directory does not exist.
    """
    folder = REPO_ROOT / "examples" if examples_dir is None else Path(examples_dir)
    if not folder.is_dir():
        return []
    allowed = set(AVAILABLE_CATEGORIES if categories is None else categories)

    image_paths = sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in _EXAMPLE_SUFFIXES
    )
    examples = []
    for img_path in image_paths:
        cat = img_path.stem.rsplit("_", 1)[0]
        if cat in allowed:
            examples.append([cat, str(img_path)])
    return examples
|
|
|
|
# Example gallery entries, collected once at import time.
EXAMPLES = build_examples()


import gradio as gr


with gr.Blocks(title="PatchCore — Industrial Anomaly Detection") as demo:
    gr.Markdown("## PatchCore — Industrial Anomaly Detection")
    gr.Markdown(
        "Select a product category, upload an image (or use your camera), "
        "and see the anomaly score and a pixel-level heatmap.\n\n"
        f"**Available categories ({len(AVAILABLE_CATEGORIES)}):** "
        + ", ".join(AVAILABLE_CATEGORIES)
    )
    if not AVAILABLE_CATEGORIES:
        # Surface the missing models in the UI instead of crashing at import.
        gr.Markdown(
            "⚠️ **No trained models found.** Expected "
            "`results/<category>/patchcore_model.pt` next to this app."
        )

    with gr.Row():
        with gr.Column(scale=1):
            category_dd = gr.Dropdown(
                choices=AVAILABLE_CATEGORIES,
                # Default to the first category. Indexing an empty list would
                # raise IndexError at import time.
                value=AVAILABLE_CATEGORIES[0] if AVAILABLE_CATEGORIES else None,
                label="Product category",
            )
            image_in = gr.Image(
                type="pil",
                label="Input image",
                sources=["upload", "webcam"],
            )
            run_btn = gr.Button("Detect Anomalies", variant="primary")

        with gr.Column(scale=1):
            score_out = gr.Number(label="Anomaly score")
            verdict_out = gr.Textbox(label="Verdict")
            heatmap_out = gr.Image(type="pil", label="Anomaly heatmap overlay")

    # Only render the examples gallery when there is something to show.
    if EXAMPLES:
        gr.Examples(
            examples=EXAMPLES,
            inputs=[category_dd, image_in],
            label="Try an example",
        )

    run_btn.click(
        fn=predict,
        inputs=[category_dd, image_in],
        outputs=[score_out, verdict_out, heatmap_out],
    )


if __name__ == "__main__":
    # NOTE(review): passing `theme` to launch() matches Gradio 6+; older
    # releases only accept it on gr.Blocks(...) — confirm the pinned version.
    demo.launch(theme=gr.themes.Soft())
|
|