# Spaces: Running
import glob
import os
import re

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.neighbors import NearestNeighbors
# ----------------- Config -----------------
# --- Backbone ---
# Name of the backbone architecture. Not referenced anywhere in the code
# visible here; the model constructor is called directly.
BACKBONE_NAME: str = "efficientnet_b4"
# Weights enum handed to `models.efficientnet_b4(weights=...)`.
PRETRAINED_WEIGHTS = models.EfficientNet_B4_Weights.IMAGENET1K_V1
# Attribute/index path, relative to the model, of the layer whose output is
# captured with a forward hook and used as the patch feature map.
FEATURE_LAYER_HOOK: str = "features[5]"

# --- Data ---
# Side length (pixels) of the square that inputs are resized to; the anomaly
# map is upsampled back to this size for display.
IMAGE_SIZE: int = 256
# `.npy` file holding the array the k-NN index is fitted on at start-up.
MEMORY_BANK_PATH: str = "memory_bank.npy"
class AnomalyDetector:
    """PatchCore-style anomaly detector: k-NN over CNN patch embeddings.

    An EfficientNet-B4 backbone (gradients disabled, eval mode) produces a
    feature map at the layer named by ``FEATURE_LAYER_HOOK``. Every spatial
    location of that map is treated as one patch embedding and scored by its
    mean Euclidean distance to the 3 nearest rows of the memory bank.
    """

    # Grammar accepted for layer paths: an attribute name followed by any mix
    # of ``.attr`` and ``[int]`` steps, e.g. ``features[5]`` or ``a.b[0][1]``.
    _LAYER_PATH_RE = re.compile(r"[A-Za-z_]\w*(?:\.[A-Za-z_]\w*|\[-?\d+\])*")
    _LAYER_STEP_RE = re.compile(r"([A-Za-z_]\w*)|\[(-?\d+)\]")

    def __init__(self, memory_bank_path: str):
        """Load the model, the preprocessing pipeline and the memory bank.

        Args:
            memory_bank_path: Path to a ``.npy`` file readable by ``np.load``;
                the loaded array is what the k-NN index is fitted on.
        """
        print("Initializing AnomalyDetector...")
        self.model, self.transform = self._get_model_and_transforms()
        print("Model and transforms loaded.")

        memory_bank = np.load(memory_bank_path)
        print(f"Memory bank loaded from {memory_bank_path} with shape {memory_bank.shape}.")

        # Minkowski with p=2 is plain Euclidean distance.
        self.knn = NearestNeighbors(n_neighbors=3, algorithm='ball_tree', metric='minkowski', p=2.0)
        self.knn.fit(memory_bank)
        print("k-NN detector fitted.")

    @staticmethod
    def _resolve_layer(root, path: str):
        """Walk an attribute/index path such as ``"features[5]"`` from *root*.

        This replaces ``eval`` on the configured string: only attribute
        lookups and integer indexing are performed, so the path can never
        execute arbitrary code.

        Args:
            root: Object the path is relative to (the model).
            path: Attribute name followed by ``.attr`` / ``[int]`` steps.

        Returns:
            The object the path leads to.

        Raises:
            ValueError: If *path* does not match the accepted grammar.
            AttributeError, IndexError: If a step does not exist on the
                object being walked.
        """
        if not isinstance(path, str) or AnomalyDetector._LAYER_PATH_RE.fullmatch(path) is None:
            raise ValueError(f"Unsupported layer path: {path!r}")
        node = root
        for attr_name, index in AnomalyDetector._LAYER_STEP_RE.findall(path):
            node = getattr(node, attr_name) if attr_name else node[int(index)]
        return node

    @staticmethod
    def _scale_scores(score_map: np.ndarray, vmin: float, vmax: float) -> np.ndarray:
        """Clip *score_map* to ``[vmin, vmax]`` and rescale it to uint8 0-255.

        Raises:
            ValueError: If ``vmax`` is not strictly greater than ``vmin``
                (the rescale would divide by zero or invert the range).
        """
        if not vmax > vmin:
            raise ValueError(f"vmax ({vmax}) must be greater than vmin ({vmin})")
        clipped = np.clip(score_map, vmin, vmax)
        return (255 * (clipped - vmin) / (vmax - vmin)).astype(np.uint8)

    def _get_model_and_transforms(self):
        """Build the backbone and the matching preprocessing pipeline.

        Returns:
            ``(model, transform)``: the EfficientNet-B4 model with gradients
            disabled and in eval mode, and a transform that resizes to
            ``IMAGE_SIZE`` x ``IMAGE_SIZE``, converts to a tensor and
            normalizes each channel with the mean/std constants below.
        """
        model = models.efficientnet_b4(weights=PRETRAINED_WEIGHTS)
        for p in model.parameters():
            p.requires_grad = False
        model.eval()

        transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        return model, transform

    def _extract_features(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Run one image through the model and return the hooked layer output.

        Args:
            image_tensor: A single preprocessed image without batch
                dimension; a batch axis is added before the forward pass.

        Returns:
            The output of the ``FEATURE_LAYER_HOOK`` layer, ``[1, C, H, W]``.

        Raises:
            RuntimeError: If the hooked layer produced no output during the
                forward pass.
        """
        captured = []

        def hook(_module, _inputs, output):
            captured.append(output)

        layer = self._resolve_layer(self.model, FEATURE_LAYER_HOOK)
        handle = layer.register_forward_hook(hook)
        try:
            with torch.no_grad():
                self.model(image_tensor.unsqueeze(0))
        finally:
            # Always detach the hook; otherwise a failed forward pass would
            # leave it registered and hooks would pile up across requests.
            handle.remove()

        if not captured:
            raise RuntimeError(
                f"Layer {FEATURE_LAYER_HOOK!r} produced no output during the forward pass."
            )
        return captured[-1]  # [1, C, H, W]

    def predict(self, image_path: str, *, vmin: float = 0.0, vmax: float = 200.0):
        """Score an image file and build its anomaly visualizations.

        Args:
            image_path: Path of the image to analyse.
            vmin: Score mapped to the low end of the colormap; lower scores
                are clipped to it.
            vmax: Score mapped to the high end of the colormap; higher scores
                are clipped to it. A fixed range keeps colors comparable
                between images.

        Returns:
            ``(original, heatmap, overlay)``: the RGB input resized to
            ``IMAGE_SIZE`` (PIL image), the JET-colored anomaly map (RGB
            uint8 array) and a 60/40 blend of the two (RGB uint8 array).

        Raises:
            ValueError: If *image_path* is empty/``None`` or the score range
                is empty.
        """
        if not image_path:
            raise ValueError("No image selected.")

        # 1. Pre-process image. `convert` returns a fully loaded copy, so the
        #    underlying file can be closed straight away.
        with Image.open(image_path) as image:
            img_rgb = image.convert("RGB")
        img_resized = img_rgb.resize((IMAGE_SIZE, IMAGE_SIZE))
        image_tensor = self.transform(img_rgb)

        # 2. Extract patch embeddings: one C-dimensional row per spatial cell.
        feature_map = self._extract_features(image_tensor)
        channels, h, w = feature_map.shape[1], feature_map.shape[2], feature_map.shape[3]
        embedding = feature_map.squeeze(0).permute(1, 2, 0).reshape(-1, channels)
        embedding = embedding.detach().cpu().numpy()

        # 3. Get anomaly scores from k-NN (mean distance to the neighbours).
        distances, _ = self.knn.kneighbors(embedding)
        patch_scores = np.mean(distances, axis=1)
        anomaly_map = patch_scores.reshape(h, w)

        # 4. Upsample the raw score map and map it onto 0-255 with a fixed range.
        anomaly_map_resized = cv2.resize(
            anomaly_map, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_LINEAR
        )
        scaled_map = self._scale_scores(anomaly_map_resized, vmin, vmax)

        # 5. Create visualizations. OpenCV colormaps are BGR; convert to RGB
        #    so the heatmap matches the PIL-derived original.
        heatmap_rgb = cv2.applyColorMap(scaled_map, cv2.COLORMAP_JET)
        heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
        overlay = cv2.addWeighted(np.array(img_resized), 0.6, heatmap_rgb, 0.4, 0)
        return img_resized, heatmap_rgb, overlay
# Markdown blurb shown under the interface title. The input widget is a
# dropdown of dataset PNGs, so the text describes selecting rather than
# uploading an image.
DESCRIPTION = """
**PatchCore-style Anomaly Detection (EfficientNet-B4, k-NN on patch embeddings)**
- Select a test image (PNG) from the dropdown.
- App returns the resized original, anomaly heatmap, and overlay.
- This app requires a `memory_bank.npy` file in the repository root.
"""
# Example images bundled with the app: every PNG/JPEG directly inside
# ./examples. Sorted because os.listdir returns entries in arbitrary order,
# which would otherwise reshuffle the example gallery between runs.
examples = []
if os.path.isdir("examples"):
    examples = sorted(
        os.path.join("examples", name)
        for name in os.listdir("examples")
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
    )
if __name__ == "__main__":
    if not os.path.exists(MEMORY_BANK_PATH):
        print(f"FATAL: `{MEMORY_BANK_PATH}` not found.")
        # Exit with a failure status. The bare `exit()` helper is injected by
        # the `site` module for interactive use and, called without an
        # argument, reports success (status 0) to the launcher.
        raise SystemExit(1)

    # Create a single instance of the detector. This performs the one-time
    # setup (model load, memory bank load, k-NN fit) before the UI starts.
    detector = AnomalyDetector(MEMORY_BANK_PATH)

    # NOTE(review): `examples` holds paths under ./examples, which are not
    # among the dropdown choices below — confirm the installed Gradio version
    # accepts example values outside `choices`.
    demo = gr.Interface(
        fn=detector.predict,
        inputs=gr.Dropdown(
            # Sorted so the list is stable and easy to scan; glob order is
            # filesystem-dependent.
            choices=sorted(glob.glob("dataset/**/*.png", recursive=True)),
            label="Select Test Image",
            info="Select a PNG file from the dataset directory."
        ),
        outputs=[
            gr.Image(type="pil", label="Original (Resized)"),
            gr.Image(type="pil", label="Anomaly Heatmap"),
            gr.Image(type="pil", label="Overlay"),
        ],
        title="Anomaly Detection (EfficientNet-B4 + kNN)",
        description=DESCRIPTION,
        allow_flagging="never",
        examples=examples if examples else None
    )
    demo.queue().launch()