""" PatchCore Anomaly Detection — Gradio demo. Loads per-category models from results/{category}/patchcore_model.pt and runs CPU inference. Thresholds are not stored in metrics.json (only AUROC/PRO metrics are saved), so the verdict threshold defaults to 0.5 — adjust FALLBACK_THRESHOLD below after inspecting per-category score distributions. """ import json import sys from pathlib import Path import numpy as np import matplotlib.cm as cm import torch import torchvision.transforms as T from PIL import Image THRESHOLDS_FILE = Path(__file__).parent / "thresholds.json" _thresholds = json.loads(THRESHOLDS_FILE.read_text()) if THRESHOLDS_FILE.exists() else {} def load_threshold(category: str) -> float: if category in _thresholds: return _thresholds[category]["threshold"] # calibrated from train/good/ return 0.5 # fallback — run calibrate_thresholds.py to fix this REPO_ROOT = Path(__file__).parent RESULTS_DIR = REPO_ROOT / "results" DATA_DIR = REPO_ROOT / "anomaly_ds" sys.path.insert(0, str(REPO_ROOT / "src")) from patchcore import PatchCore # --------------------------------------------------------------------------- # Category discovery # --------------------------------------------------------------------------- def discover_categories() -> list[str]: return sorted( d.name for d in RESULTS_DIR.iterdir() if d.is_dir() and (d / "patchcore_model.pt").exists() ) AVAILABLE_CATEGORIES = discover_categories() # --------------------------------------------------------------------------- # Model cache (loaded on first use, stays in RAM) # --------------------------------------------------------------------------- _model_cache: dict = {} def load_model(category: str) -> PatchCore: if category in _model_cache: return _model_cache[category] model_path = str(RESULTS_DIR / category / "patchcore_model.pt") model = PatchCore(device="cpu", faiss_gpu=False) model.load(model_path) _model_cache[category] = model return model # 
# ---------------------------------------------------------------------------
# Preprocessing — must match IMAGE_TRANSFORM in src/dataset.py exactly:
#   Resize(256) → CenterCrop(224) → ToTensor → Normalize(ImageNet)
# ---------------------------------------------------------------------------
_RESIZE_SIZE = 256
_CROP_SIZE = 224

_transform = T.Compose([
    T.Resize(_RESIZE_SIZE),
    T.CenterCrop(_CROP_SIZE),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Only the geometric part of _transform. The overlay uses it so the display
# image covers exactly the pixels the model saw. A plain square resize would
# include the borders that CenterCrop discards and would distort non-square
# inputs, which misaligns the heatmap.
_display_transform = T.Compose([
    T.Resize(_RESIZE_SIZE),
    T.CenterCrop(_CROP_SIZE),
])


# ---------------------------------------------------------------------------
# Heatmap overlay (visualize.py only exposes a private numpy helper;
# this PIL version is self-contained and avoids the Agg backend conflict)
# ---------------------------------------------------------------------------
def overlay_heatmap(image: Image.Image, anomaly_map: np.ndarray,
                    alpha: float = 0.45) -> Image.Image:
    """Blend a jet-colormapped anomaly map over the model's view of *image*.

    Args:
        image: Input image in any PIL mode. It is converted to RGB and then
            resized and center-cropped exactly like the model input.
        anomaly_map: Per-pixel anomaly scores. Singleton dimensions are
            squeezed away; assumed 2-D (H, W) afterwards, TODO confirm
            against PatchCore.predict.
        alpha: Heatmap weight passed to Image.blend (0 = image only,
            1 = heatmap only).

    Returns:
        RGB image of size (224, 224) with the heatmap blended in.
    """
    # Convert before resizing: PIL forces NEAREST resampling for 'P'/'1' modes.
    base = _display_transform(image.convert("RGB"))

    scores = np.asarray(anomaly_map, dtype=float).squeeze()
    # Per-image min-max scaling shows *relative* hot spots only. Even a
    # normal image gets a full-range colormap.
    norm = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
    heatmap_rgb = (cm.jet(norm)[..., :3] * 255).astype(np.uint8)
    heatmap = Image.fromarray(heatmap_rgb).resize(base.size)
    return Image.blend(base, heatmap, alpha)


# ---------------------------------------------------------------------------
# Main predict function
# ---------------------------------------------------------------------------
def predict(category: str, image: Image.Image):
    """Score *image* with the PatchCore model for *category*.

    Returns:
        A (score, verdict, heatmap) tuple for the three Gradio outputs:
        - score: the image-level anomaly score as a float.
        - verdict: "✅ NORMAL" if the score is below the category threshold,
          otherwise "❌ ANOMALY".
        - heatmap: the anomaly-map overlay.
        If the image or category is missing, returns (None, message, None)
        instead.
    """
    if image is None:
        return None, "No image provided.", None
    if not category:
        return None, "No category selected.", None

    model = load_model(category)
    threshold = load_threshold(category)

    rgb = image.convert("RGB")
    batch = _transform(rgb).unsqueeze(0)  # [1, 3, 224, 224]
    with torch.no_grad():
        # predict() returns (image_score, anomaly_map [224, 224]).
        # NOTE(review): exact return types should be confirmed in src/patchcore.py.
        image_score, anomaly_map = model.predict(batch)

    score = float(image_score)
    verdict = "✅ NORMAL" if score < threshold else "❌ ANOMALY"
    heatmap = overlay_heatmap(rgb, anomaly_map)
    return score, verdict, heatmap
# ---------------------------------------------------------------------------
# Example images: every examples/<category>_<label>.png whose category has a
# trained model, e.g. "bottle_good.png" or "bottle_broken_large.png".
# ---------------------------------------------------------------------------
def _category_of(stem: str) -> "str | None":
    """Map an example file stem to its available category, or None.

    Tries the text before the last "_" first ("bottle_good" → "bottle").
    Labels may themselves contain underscores ("bottle_broken_large"), so if
    that head is not an available category, falls back to the longest
    available category that *stem* starts with, followed by "_".
    """
    head = stem.rsplit("_", 1)[0]
    if head in AVAILABLE_CATEGORIES:
        return head
    prefixed = [c for c in AVAILABLE_CATEGORIES if stem.startswith(c + "_")]
    return max(prefixed, key=len, default=None)


def build_examples() -> list:
    """Return [category, image_path] pairs for gr.Examples, sorted by filename.

    Returns an empty list if the examples/ directory does not exist. PNGs
    that cannot be matched to an available category are skipped.
    """
    examples_dir = REPO_ROOT / "examples"
    if not examples_dir.is_dir():
        return []
    examples = []
    for img_path in sorted(examples_dir.glob("*.png")):
        cat = _category_of(img_path.stem)
        if cat is not None:
            examples.append([cat, str(img_path)])
    return examples


EXAMPLES = build_examples()


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
import gradio as gr  # noqa: E402

# With no trained models there is nothing to preselect. Indexing [0] would
# raise IndexError at import time.
_DEFAULT_CATEGORY = AVAILABLE_CATEGORIES[0] if AVAILABLE_CATEGORIES else None

with gr.Blocks(title="PatchCore — Industrial Anomaly Detection") as demo:
    gr.Markdown("## PatchCore — Industrial Anomaly Detection")
    gr.Markdown(
        "Select a product category, upload an image (or use your camera), "
        "and see the anomaly score and a pixel-level heatmap.\n\n"
        f"**Available categories ({len(AVAILABLE_CATEGORIES)}):** "
        + ", ".join(AVAILABLE_CATEGORIES)
    )

    with gr.Row():
        with gr.Column(scale=1):
            category_dd = gr.Dropdown(
                choices=AVAILABLE_CATEGORIES,
                value=_DEFAULT_CATEGORY,
                label="Product category",
            )
            image_in = gr.Image(
                type="pil",
                label="Input image",
                sources=["upload", "webcam"],
            )
            run_btn = gr.Button("Detect Anomalies", variant="primary")

        with gr.Column(scale=1):
            score_out = gr.Number(label="Anomaly score")
            verdict_out = gr.Textbox(label="Verdict")
            heatmap_out = gr.Image(type="pil", label="Anomaly heatmap overlay")

    # Only render the examples panel when examples/ provided any.
    if EXAMPLES:
        gr.Examples(
            examples=EXAMPLES,
            inputs=[category_dd, image_in],
            label="Try an example",
        )

    run_btn.click(
        fn=predict,
        inputs=[category_dd, image_in],
        outputs=[score_out, verdict_out, heatmap_out],
    )

if __name__ == "__main__":
    # NOTE(review): theme= on launch() is the newer Gradio API. Older
    # releases take it on gr.Blocks(...); confirm against the installed version.
    demo.launch(theme=gr.themes.Soft())