"""Gradio demo for PatchCore-style anomaly detection.

Patch embeddings are taken from an intermediate layer of a frozen torchvision
backbone. Each patch is scored by its mean Euclidean distance to its nearest
neighbours in a precomputed memory bank (``memory_bank.npy``). The per-patch
scores are rendered as a heatmap and an overlay on the input image.
"""

import glob
import os

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.neighbors import NearestNeighbors

# ----------------- Config -----------------
BACKBONE_NAME = "efficientnet_b4"  # name of the torchvision model builder to load
PRETRAINED_WEIGHTS = models.EfficientNet_B4_Weights.IMAGENET1K_V1
# Path (relative to the model) of the layer whose output is used as the
# patch-embedding map; "features[5]" means ``model.features[5]``.
FEATURE_LAYER_HOOK = "features[5]"
IMAGE_SIZE = 256  # images are resized to IMAGE_SIZE x IMAGE_SIZE before inference
MEMORY_BANK_PATH = "memory_bank.npy"


class AnomalyDetector:
    """
    Encapsulates the anomaly detection model, data, and prediction logic.

    Each image is embedded with a frozen backbone. Every spatial cell of the
    feature map at ``FEATURE_LAYER_HOOK`` is scored by its mean Euclidean
    distance to its nearest neighbours in a memory bank of patch embeddings.
    """

    def __init__(self, memory_bank_path: str, n_neighbors: int = 3,
                 score_range: tuple = (0, 200)):
        """
        Load the model, transforms and memory bank, and fit the k-NN index.

        Args:
            memory_bank_path: Path to a ``.npy`` array of shape
                ``(num_patches, C)``. It presumably must be built with the same
                backbone, layer and ``IMAGE_SIZE`` as used here; verify this
                against the script that produced it.
            n_neighbors: Number of nearest neighbours averaged into each patch
                score. It is capped at the number of memory-bank rows.
            score_range: ``(vmin, vmax)`` anomaly-score range that is mapped onto
                the full colormap. Scores outside it are clipped.

        Raises:
            ValueError: If ``score_range`` is not increasing or the memory bank
                is empty.
        """
        vmin, vmax = score_range
        if vmax <= vmin:
            raise ValueError(f"score_range must satisfy vmin < vmax, got {score_range!r}")
        self.score_range = (vmin, vmax)

        print("Initializing AnomalyDetector...")
        self.model, self.transform = self._get_model_and_transforms()
        # Resolve the hooked layer once rather than on every prediction.
        self._feature_layer = self._resolve_layer(self.model, FEATURE_LAYER_HOOK)
        print("Model and transforms loaded.")

        memory_bank = np.load(memory_bank_path)
        print(f"Memory bank loaded from {memory_bank_path} with shape {memory_bank.shape}.")
        if len(memory_bank) == 0:
            raise ValueError(f"Memory bank at {memory_bank_path} is empty.")

        # kneighbors() raises if asked for more neighbours than fitted samples.
        k = min(n_neighbors, len(memory_bank))
        self.knn = NearestNeighbors(n_neighbors=k, algorithm='ball_tree', metric='minkowski', p=2.0)
        self.knn.fit(memory_bank)
        print("k-NN detector fitted.")

    @staticmethod
    def _resolve_layer(root: torch.nn.Module, path: str) -> torch.nn.Module:
        """
        Return the submodule of *root* named by *path*, e.g. ``"features[5]"``.

        Index syntax is rewritten to the dotted form that
        ``Module.get_submodule`` accepts (``"features.5"``). This avoids
        evaluating the path as Python code.
        """
        return root.get_submodule(path.replace("[", ".").replace("]", ""))

    def _get_model_and_transforms(self):
        """Build the frozen backbone in eval mode and its preprocessing pipeline."""
        model = getattr(models, BACKBONE_NAME)(weights=PRETRAINED_WEIGHTS)
        model.requires_grad_(False)
        model.eval()
        transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            # ImageNet mean/std, matching the pretrained weights.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        return model, transform

    def _extract_features(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """
        Run the backbone on one image and capture the hooked layer's output.

        Args:
            image_tensor: A single preprocessed image ``(3, H, W)``, as produced
                by ``self.transform`` on an RGB image.

        Returns:
            The hooked layer's feature map, ``[1, C, h, w]``.
        """
        features = None

        def hook(_, __, output):
            nonlocal features
            features = output

        handle = self._feature_layer.register_forward_hook(hook)
        try:
            with torch.no_grad():
                self.model(image_tensor.unsqueeze(0))
        finally:
            # Always detach the hook, even if the forward pass raises. This
            # keeps stale hooks from piling up on the shared model.
            handle.remove()
        return features  # [1, C, h, w]

    def predict(self, image_path: str):
        """
        Score an image and return the resized original, a heatmap, and an overlay.

        Args:
            image_path: Path of the image file to analyse.

        Returns:
            ``(img_resized, heatmap_rgb, overlay)``, where:
            - ``img_resized`` is the RGB input resized to
              ``IMAGE_SIZE x IMAGE_SIZE``, as a PIL image.
            - ``heatmap_rgb`` is the colour-mapped anomaly map, as a uint8 RGB
              array of the same size.
            - ``overlay`` is a 60/40 blend of the input and the heatmap.

        Raises:
            ValueError: If no path is given (e.g. nothing selected in the UI).
        """
        if not image_path:
            raise ValueError("No image selected.")

        # 1. Pre-process image. Image.open keeps the file open, so close it as
        # soon as the pixels have been converted.
        with Image.open(image_path) as image:
            img_rgb = image.convert("RGB")
        img_resized = img_rgb.resize((IMAGE_SIZE, IMAGE_SIZE))
        image_tensor = self.transform(img_rgb)

        # 2. Extract patch embeddings: one C-dimensional vector per feature-map cell.
        feature_map = self._extract_features(image_tensor)
        _, channels, h, w = feature_map.shape
        embedding = feature_map.squeeze(0).permute(1, 2, 0).reshape(-1, channels)
        embedding = embedding.cpu().numpy()

        # 3. Patch score = mean distance to its nearest memory-bank neighbours.
        distances, _ = self.knn.kneighbors(embedding)
        anomaly_map = distances.mean(axis=1).reshape(h, w)

        # 4. Upsample to image size and map the fixed score range onto 0-255.
        # A fixed range (not per-image min/max) keeps colours comparable
        # across images.
        anomaly_map_resized = cv2.resize(anomaly_map, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_LINEAR)
        vmin, vmax = self.score_range
        clipped_map = np.clip(anomaly_map_resized, vmin, vmax)
        scaled_map = 255 * (clipped_map - vmin) / (vmax - vmin)

        # 5. Create visualizations. applyColorMap returns BGR, so convert to RGB.
        heatmap_rgb = cv2.applyColorMap(scaled_map.astype(np.uint8), cv2.COLORMAP_JET)
        heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
        overlay = cv2.addWeighted(np.array(img_resized), 0.6, heatmap_rgb, 0.4, 0)
        return img_resized, heatmap_rgb, overlay


DESCRIPTION = """
**PatchCore-style Anomaly Detection (EfficientNet-B4, k-NN on patch embeddings)**
- Upload an image (PNG/JPG).
- App returns the resized original, anomaly heatmap, and overlay.
- This app requires a `memory_bank.npy` file in the repository root.
"""

# Example images shown under the interface, in a stable (sorted) order.
examples = []
if os.path.isdir("examples"):
    examples = [
        os.path.join("examples", f)
        for f in sorted(os.listdir("examples"))
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ]

if __name__ == "__main__":
    if not os.path.exists(MEMORY_BANK_PATH):
        print(f"FATAL: `{MEMORY_BANK_PATH}` not found.")
        # Exit with a non-zero status so launchers/CI see the failure.
        raise SystemExit(1)

    # Create a single instance of the detector. This performs the one-time setup.
    detector = AnomalyDetector(MEMORY_BANK_PATH)

    # The dropdown's value should be one of its choices (newer Gradio versions
    # reject other values). Example files that live outside dataset/ are
    # therefore listed as choices as well.
    choices = sorted(set(glob.glob("dataset/**/*.png", recursive=True)) | set(examples))

    demo = gr.Interface(
        fn=detector.predict,
        inputs=gr.Dropdown(
            choices=choices,
            label="Select Test Image",
            info="Select a PNG file from the dataset directory."
        ),
        outputs=[
            gr.Image(type="pil", label="Original (Resized)"),
            gr.Image(type="pil", label="Anomaly Heatmap"),
            gr.Image(type="pil", label="Overlay"),
        ],
        title="Anomaly Detection (EfficientNet-B4 + kNN)",
        description=DESCRIPTION,
        allow_flagging="never",
        examples=examples if examples else None
    )
    demo.queue().launch()