knownmax's picture
Upload folder using huggingface_hub
ed4e653 verified
Raw
History Blame Contribute Delete
6.38 kB
"""
PatchCore Anomaly Detection — Gradio demo.
Loads per-category models from results/{category}/patchcore_model.pt and runs
CPU inference. Per-category verdict thresholds are read from thresholds.json
(calibrated on train/good/ images; see calibrate_thresholds.py). Categories
without a calibrated entry fall back to a threshold of 0.5. metrics.json only
stores AUROC/PRO metrics and is not used for thresholds.
"""
import json
import sys
from pathlib import Path
import numpy as np
import matplotlib.cm as cm
import torch
import torchvision.transforms as T
from PIL import Image
THRESHOLDS_FILE = Path(__file__).parent / "thresholds.json"
# Per-category thresholds (written by calibrate_thresholds.py); empty until
# that file has been generated.
_thresholds = (
    json.loads(THRESHOLDS_FILE.read_text(encoding="utf-8"))
    if THRESHOLDS_FILE.exists()
    else {}
)
# Threshold used for any category without a calibrated entry in thresholds.json.
FALLBACK_THRESHOLD = 0.5


def load_threshold(category: str) -> float:
    """Return the anomaly-score threshold to use for *category*.

    Entries in thresholds.json may be either ``{"threshold": <number>}`` or a
    bare number.  Categories without an entry get ``FALLBACK_THRESHOLD``.

    Raises:
        KeyError: if a dict entry exists but has no ``"threshold"`` key.
    """
    entry = _thresholds.get(category)
    if entry is None:
        return FALLBACK_THRESHOLD  # fallback — run calibrate_thresholds.py to fix this
    if isinstance(entry, dict):
        entry = entry["threshold"]  # calibrated from train/good/
    # JSON integers (e.g. 1) come back as int; honour the float return type.
    return float(entry)
# All paths are resolved relative to this file, not the current working directory.
REPO_ROOT = Path(__file__).parent
RESULTS_DIR = REPO_ROOT / "results"  # holds {category}/patchcore_model.pt
DATA_DIR = REPO_ROOT / "anomaly_ds"  # NOTE(review): not referenced elsewhere in this file — confirm before removing
# Make the project-local src/ package importable; must run before the import below.
sys.path.insert(0, str(REPO_ROOT / "src"))
from patchcore import PatchCore
# ---------------------------------------------------------------------------
# Category discovery
# ---------------------------------------------------------------------------
def discover_categories() -> list[str]:
return sorted(
d.name for d in RESULTS_DIR.iterdir()
if d.is_dir() and (d / "patchcore_model.pt").exists()
)
# Categories offered in the UI. Computed once at import time, so models trained
# after the app starts are not picked up until a restart.
AVAILABLE_CATEGORIES: list[str] = discover_categories()
# ---------------------------------------------------------------------------
# Model cache (loaded on first use, stays in RAM)
# ---------------------------------------------------------------------------
_model_cache: dict = {}


def load_model(category: str) -> PatchCore:
    """Return the CPU PatchCore model for *category*, loading it on first use.

    Loaded models are memoised in ``_model_cache`` and reused for every later
    request for the same category.
    """
    try:
        return _model_cache[category]
    except KeyError:
        pass  # not loaded yet — fall through and load from disk
    checkpoint = RESULTS_DIR / category / "patchcore_model.pt"
    fresh = PatchCore(device="cpu", faiss_gpu=False)
    fresh.load(str(checkpoint))
    _model_cache[category] = fresh
    return fresh
# ---------------------------------------------------------------------------
# Preprocessing — must match IMAGE_TRANSFORM in src/dataset.py exactly:
# Resize(256) → CenterCrop(224) → ToTensor → Normalize(ImageNet)
# ---------------------------------------------------------------------------
# Model-input pipeline: the geometric steps (shorter side resized to 256, then
# the central 224×224 window) followed by tensor conversion and ImageNet
# mean/std normalisation.
_transform = T.Compose(
    [T.Resize(256), T.CenterCrop(224)]
    + [
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ---------------------------------------------------------------------------
# Heatmap overlay (visualize.py only exposes a private numpy helper;
# this PIL version is self-contained and avoids the Agg backend conflict)
# ---------------------------------------------------------------------------
def overlay_heatmap(image: Image.Image, anomaly_map: np.ndarray, alpha: float = 0.45) -> Image.Image:
    """Blend a jet-coloured anomaly map over *image* and return the result.

    The image gets the same geometry as the model input (Resize(256) then
    CenterCrop(224), mirroring ``_transform``). Each heatmap pixel therefore
    lines up with the image region it was scored on. Squashing the whole
    frame to 224×224 instead would shift and rescale the heatmap, especially
    near the borders and for non-square images.

    Args:
        image: Original PIL image in any mode (converted to RGB).
        anomaly_map: Anomaly scores; after squeezing singleton dims it must be
            2-D. It is min-max normalised to [0, 1] before colouring.
        alpha: Heatmap weight in the blend (0 = image only, 1 = heatmap only).

    Returns:
        An RGB image the size of the model's input crop.
    """
    # Convert before resampling so palette/greyscale inputs are interpolated in RGB.
    image_rgb = T.CenterCrop(224)(T.Resize(256)(image.convert("RGB")))
    scores = np.squeeze(np.asarray(anomaly_map, dtype=np.float64))
    lo, hi = scores.min(), scores.max()
    norm = (scores - lo) / (hi - lo + 1e-8)  # epsilon avoids 0/0 on a flat map
    heatmap_rgb = (cm.jet(norm)[:, :, :3] * 255).astype(np.uint8)  # drop alpha channel
    heatmap_pil = Image.fromarray(heatmap_rgb).resize(image_rgb.size)
    return Image.blend(image_rgb, heatmap_pil, alpha)
# ---------------------------------------------------------------------------
# Main predict function
# ---------------------------------------------------------------------------
def predict(category: str, image: Image.Image):
    """Score *image* with the PatchCore model for *category*.

    Returns a ``(score, verdict, heatmap)`` triple: the image-level anomaly
    score as a float, a NORMAL/ANOMALY label, and the heatmap overlay. A
    score below the category threshold is NORMAL. Missing inputs yield
    ``(None, <message>, None)`` instead of raising.
    """
    if image is None:
        return None, "No image provided.", None
    if not category:
        # An empty/None category must not reach load_model(), which builds a
        # filesystem path from it.
        return None, "No category selected.", None
    model = load_model(category)
    threshold = load_threshold(category)
    tensor = _transform(image.convert("RGB")).unsqueeze(0)  # [1, 3, 224, 224]
    with torch.no_grad():
        # predict() returns (image_score: float, anomaly_map: np.ndarray [224,224])
        image_score, anomaly_map = model.predict(tensor)
    score = float(image_score)
    verdict = "✅ NORMAL" if score < threshold else "❌ ANOMALY"
    heatmap = overlay_heatmap(image, anomaly_map)
    return score, verdict, heatmap
# ---------------------------------------------------------------------------
# Example images from examples/, named like "bottle_good.png" (category prefix
# + label); only categories with a trained model are listed.
# ---------------------------------------------------------------------------
def build_examples() -> list:
examples_dir = REPO_ROOT / "examples"
if not examples_dir.exists():
return []
examples = []
for img_path in sorted(examples_dir.glob("*.png")):
cat = img_path.stem.rsplit("_", 1)[0] # "bottle_good" → "bottle"
if cat in AVAILABLE_CATEGORIES:
examples.append([cat, str(img_path)])
return examples
EXAMPLES = build_examples()
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
import gradio as gr

with gr.Blocks(title="PatchCore — Industrial Anomaly Detection") as demo:
    gr.Markdown("## PatchCore — Industrial Anomaly Detection")
    gr.Markdown(
        "Select a product category, upload an image (or use your camera), "
        "and see the anomaly score and a pixel-level heatmap.\n\n"
        f"**Available categories ({len(AVAILABLE_CATEGORIES)}):** "
        + ", ".join(AVAILABLE_CATEGORIES)
    )
    with gr.Row():
        with gr.Column(scale=1):
            category_dd = gr.Dropdown(
                choices=AVAILABLE_CATEGORIES,
                # With no trained models, AVAILABLE_CATEGORIES[0] would raise
                # IndexError at import time; start with no selection instead.
                value=AVAILABLE_CATEGORIES[0] if AVAILABLE_CATEGORIES else None,
                label="Product category",
            )
            image_in = gr.Image(
                type="pil",
                label="Input image",
                sources=["upload", "webcam"],
            )
            run_btn = gr.Button("Detect Anomalies", variant="primary")
        with gr.Column(scale=1):
            score_out = gr.Number(label="Anomaly score")
            verdict_out = gr.Textbox(label="Verdict")
            heatmap_out = gr.Image(type="pil", label="Anomaly heatmap overlay")
    # build_examples() returns [] when examples/ is missing or nothing matches;
    # only render the examples table when there is something to show.
    if EXAMPLES:
        gr.Examples(
            examples=EXAMPLES,
            inputs=[category_dd, image_in],
            label="Try an example",
        )
    run_btn.click(
        fn=predict,
        inputs=[category_dd, image_in],
        outputs=[score_out, verdict_out, heatmap_out],
    )

if __name__ == "__main__":
    # NOTE(review): launch(theme=...) needs a Gradio release whose launch()
    # accepts `theme` (6.x); on 4.x/5.x it belongs on gr.Blocks — confirm the
    # pinned version.
    demo.launch(theme=gr.themes.Soft())