duchieuvn
init project
2d020e8
import os
import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import glob
from sklearn.neighbors import NearestNeighbors
# ----------------- Config -----------------
# Backbone identifier; should agree with PRETRAINED_WEIGHTS below.
BACKBONE_NAME = "efficientnet_b4"
# Pretrained weights enum passed to the torchvision EfficientNet-B4 constructor.
PRETRAINED_WEIGHTS = models.EfficientNet_B4_Weights.IMAGENET1K_V1
# Attribute/index path, relative to the backbone model object, of the module
# whose forward output is captured as the patch feature map.
FEATURE_LAYER_HOOK = "features[5]"
# Side length in pixels of the square size input images are resized to.
IMAGE_SIZE = 256
# .npy file holding the stored patch embeddings the k-NN index is fitted on.
MEMORY_BANK_PATH = "memory_bank.npy"
class AnomalyDetector:
    """PatchCore-style anomaly detector.

    Bundles a frozen EfficientNet-B4 backbone, a k-NN index fitted on a memory
    bank of patch embeddings, and the logic that turns an image into a
    per-patch anomaly heatmap and an overlay.
    """

    def __init__(
        self,
        memory_bank_path: str,
        n_neighbors: int = 3,
        score_range: "tuple[float, float]" = (0.0, 200.0),
        feature_layer: "str | None" = None,
    ):
        """Load the backbone, transforms and memory bank, then fit the k-NN index.

        Args:
            memory_bank_path: Path to a ``.npy`` file containing a non-empty
                2-D array ``(num_patches, channels)`` whose channel count
                matches the hooked feature map.
            n_neighbors: Number of nearest memory-bank embeddings averaged into
                each patch score. Capped at the memory-bank size, because
                ``kneighbors`` cannot return more neighbours than were fitted.
            score_range: ``(vmin, vmax)`` window of raw scores mapped onto the
                colormap. Scores outside the window are clipped.
            feature_layer: Attribute/index path of the module whose output is
                used as the feature map, e.g. ``"features[5]"``. Defaults to
                ``FEATURE_LAYER_HOOK``.

        Raises:
            ValueError: If ``score_range`` is empty (``vmax <= vmin``), or if
                the memory bank is not a non-empty 2-D array.
        """
        # Validate cheap arguments before doing any heavy loading.
        vmin, vmax = score_range
        if vmax <= vmin:
            raise ValueError(f"score_range must satisfy vmin < vmax, got {score_range}.")
        self.vmin, self.vmax = float(vmin), float(vmax)
        self.feature_layer = FEATURE_LAYER_HOOK if feature_layer is None else feature_layer

        print("Initializing AnomalyDetector...")
        self.model, self.transform = self._get_model_and_transforms()
        print("Model and transforms loaded.")

        memory_bank = np.load(memory_bank_path)
        print(f"Memory bank loaded from {memory_bank_path} with shape {memory_bank.shape}.")
        if memory_bank.ndim != 2 or memory_bank.shape[0] == 0:
            raise ValueError(
                "Memory bank must be a non-empty 2-D array (num_patches, channels), "
                f"got shape {memory_bank.shape}."
            )

        k = min(n_neighbors, memory_bank.shape[0])
        self.knn = NearestNeighbors(n_neighbors=k, algorithm='ball_tree', metric='minkowski', p=2.0)
        self.knn.fit(memory_bank)
        print("k-NN detector fitted.")

    def _get_model_and_transforms(self):
        """Build the frozen backbone and its preprocessing pipeline.

        Returns:
            ``(model, transform)``:
            - ``model`` is EfficientNet-B4 in eval mode with gradients
              disabled.
            - ``transform`` resizes the image to ``IMAGE_SIZE`` square,
              converts it to a tensor, and applies the ImageNet mean/std
              normalisation.
        """
        model = models.efficientnet_b4(weights=PRETRAINED_WEIGHTS)
        model.requires_grad_(False)
        model.eval()
        transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        return model, transform

    @staticmethod
    def _layer_path_to_submodule(layer_path: str) -> str:
        """Convert a path like ``"features[5]"`` into the dotted form ``"features.5"``.

        The dotted form is what ``nn.Module.get_submodule`` accepts. This
        replaces an ``eval`` of the config string, which would have executed
        any Python expression placed there.
        """
        return layer_path.replace("[", ".").replace("]", "").lstrip(".")

    def _extract_features(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Run the backbone on one image and capture an intermediate output.

        Args:
            image_tensor: A single preprocessed image. A batch dimension of
                size 1 is prepended before the forward pass.

        Returns:
            The output of the module at ``self.feature_layer``. For a
            convolutional stage this is a ``(1, C, H, W)`` tensor.

        Raises:
            RuntimeError: If the forward pass never executed the hooked module.
        """
        layer = self.model.get_submodule(self._layer_path_to_submodule(self.feature_layer))
        features = None

        def hook(_module, _inputs, output):
            nonlocal features
            features = output

        handle = layer.register_forward_hook(hook)
        try:
            with torch.no_grad():
                self.model(image_tensor.unsqueeze(0))
        finally:
            # Detach even if the forward pass raises, so a stale hook cannot
            # keep firing on later calls.
            handle.remove()
        if features is None:
            raise RuntimeError(
                f"Feature layer {self.feature_layer!r} was not executed during the forward pass."
            )
        return features

    def _anomaly_map(self, image_tensor: torch.Tensor) -> np.ndarray:
        """Return the ``(H, W)`` grid of per-patch k-NN anomaly scores."""
        feature_map = self._extract_features(image_tensor)
        _, channels, h, w = feature_map.shape
        # (1, C, H, W) -> (H*W, C): one embedding per spatial position, row-major.
        embedding = feature_map.squeeze(0).permute(1, 2, 0).reshape(-1, channels).cpu().numpy()
        distances, _ = self.knn.kneighbors(embedding)
        # Patch score = mean distance to its k nearest memory-bank embeddings.
        return distances.mean(axis=1).reshape(h, w)

    @staticmethod
    def _scale_scores(scores: np.ndarray, vmin: float, vmax: float) -> np.ndarray:
        """Clip ``scores`` to ``[vmin, vmax]`` and map them linearly to uint8 0-255.

        A fixed window is used, rather than per-image min/max, so that
        heatmaps are comparable across images. The cast truncates.
        """
        clipped = np.clip(scores, vmin, vmax)
        return (255 * (clipped - vmin) / (vmax - vmin)).astype(np.uint8)

    def _render(self, img_resized, anomaly_map: np.ndarray):
        """Upsample the score grid and build the RGB heatmap and blended overlay."""
        anomaly_map_resized = cv2.resize(
            anomaly_map, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_LINEAR
        )
        scaled_map = self._scale_scores(anomaly_map_resized, self.vmin, self.vmax)
        # applyColorMap produces BGR; convert to RGB for display.
        heatmap_rgb = cv2.cvtColor(cv2.applyColorMap(scaled_map, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        overlay = cv2.addWeighted(np.array(img_resized), 0.6, heatmap_rgb, 0.4, 0)
        return heatmap_rgb, overlay

    def predict(self, image_path: str):
        """Score one image file.

        Args:
            image_path: Path to the image to analyse.

        Returns:
            ``(img_resized, heatmap_rgb, overlay)``:
            - ``img_resized`` is the RGB PIL image resized to
              ``IMAGE_SIZE`` square.
            - ``heatmap_rgb`` and ``overlay`` are uint8 RGB numpy arrays of
              the same size.

        Raises:
            ValueError: If no path is given, for example when nothing is
                selected in the UI.
        """
        if not image_path:
            raise ValueError("No image selected.")
        # Close the file promptly. convert() returns an independent, loaded image.
        with Image.open(image_path) as image:
            img_rgb = image.convert("RGB")
        img_resized = img_rgb.resize((IMAGE_SIZE, IMAGE_SIZE))
        image_tensor = self.transform(img_rgb)

        anomaly_map = self._anomaly_map(image_tensor)
        heatmap_rgb, overlay = self._render(img_resized, anomaly_map)
        return img_resized, heatmap_rgb, overlay
# Markdown shown under the interface title. The input widget is a dropdown of
# files, not an upload box, and the text below says so.
DESCRIPTION = """
**PatchCore-style Anomaly Detection (EfficientNet-B4, k-NN on patch embeddings)**
- Select a test image from the dropdown (PNG files under the `dataset/` directory).
- App returns the resized original, anomaly heatmap, and overlay.
- This app requires a `memory_bank.npy` file in the working directory (normally the repository root).
"""
# Example inputs offered under the interface: every PNG/JPEG file directly
# inside ./examples. The list is sorted because os.listdir returns entries in
# arbitrary order. Only regular files are kept, so a directory named "x.png"
# is skipped.
examples = []
if os.path.isdir("examples"):
    examples = sorted(
        os.path.join("examples", name)
        for name in os.listdir("examples")
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
        and os.path.isfile(os.path.join("examples", name))
    )
if __name__ == "__main__":
    # Fail fast with a non-zero exit status. SystemExit prints the message to
    # stderr. The bare exit() builtin is injected by `site`, is not
    # guaranteed to exist, and exited with status 0 here.
    if not os.path.exists(MEMORY_BANK_PATH):
        raise SystemExit(f"FATAL: `{MEMORY_BANK_PATH}` not found.")

    # Create a single instance of the detector. This performs the one-time setup.
    detector = AnomalyDetector(MEMORY_BANK_PATH)

    # Sort the choices so the dropdown order is stable. Example paths are
    # appended so that clicking an example selects a value the dropdown offers.
    choices = sorted(glob.glob("dataset/**/*.png", recursive=True))
    known = set(choices)
    choices += [path for path in examples if path not in known]

    demo = gr.Interface(
        fn=detector.predict,
        inputs=gr.Dropdown(
            choices=choices,
            label="Select Test Image",
            info="Select a PNG file from the dataset directory."
        ),
        outputs=[
            gr.Image(type="pil", label="Original (Resized)"),
            gr.Image(type="pil", label="Anomaly Heatmap"),
            gr.Image(type="pil", label="Overlay"),
        ],
        title="Anomaly Detection (EfficientNet-B4 + kNN)",
        description=DESCRIPTION,
        # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
        # it is kept as-is because the installed Gradio version is not pinned here.
        allow_flagging="never",
        examples=examples if examples else None
    )
    demo.queue().launch()