anomaly-detection-api / scripts /patchcore_api_inference.py
Senum2001
Deploy Anomaly Detection API
9cf599c
import os
import torch
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from anomalib.models import Patchcore
from anomalib.data import Folder
from pytorch_lightning import Trainer
# --- Paths and runtime settings ---
# All paths are relative, so they resolve against the current working directory.
CONFIG_PATH = "configs/patchcore_transformers.yaml"  # YAML read for its `data` and `model` sections
CKPT_PATH = "results/Patchcore/transformers/v7/weights/lightning/model.ckpt"  # PatchCore checkpoint to load
OUT_MASK_DIR = "api_inference_pred_pipeline"  # receives "<stem>_mask.png" anomaly-map images
OUT_FILTERED_DIR = "api_inference_filtered_pipeline"  # receives "<stem>_filtered.png" masked originals
# Output directories are created by the `__main__` section, not here, so that
# merely importing this module does not create directories as a side effect.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loaded eagerly because the `__main__` section reads `config.data` / `config.model`.
config = OmegaConf.load(CONFIG_PATH)
def extract_anomaly_maps(output):
    """Pull the batched anomaly maps out of a model forward output.

    Accepts either an object exposing an ``anomaly_map`` attribute, or a
    tuple/list whose second element is the anomaly map (the
    "(anomaly_score, anomaly_map, ...)" layout this script assumed).

    Args:
        output: Whatever ``model(img)`` returned.

    Returns:
        The anomaly maps as a NumPy array with any batch dimension still
        present, or ``None`` when no anomaly map is available.
    """
    if hasattr(output, "anomaly_map"):
        maps = output.anomaly_map
    elif isinstance(output, (tuple, list)) and len(output) > 1:
        maps = output[1]
    else:
        maps = None
    if maps is None:
        # Covers an `anomaly_map` attribute that exists but is unset, which
        # previously crashed on `.squeeze()`.
        return None
    return maps.detach().cpu().numpy()


def split_anomaly_maps(maps, count):
    """Split a batch of anomaly maps into ``count`` 2-D maps, one per image.

    The trailing two axes are treated as (height, width). Leading axes
    (batch, singleton channel) are folded together, so ``(B, 1, H, W)``,
    ``(B, H, W)`` and, for ``count == 1``, ``(H, W)`` are all accepted.

    Raises:
        ValueError: if ``maps`` has fewer than two axes or does not contain
            exactly ``count`` maps of the trailing ``(H, W)`` shape.
    """
    maps = np.asarray(maps)
    if maps.ndim < 2:
        raise ValueError(f"expected at least a 2-D anomaly map, got shape {maps.shape}")
    return list(maps.reshape(count, *maps.shape[-2:]))


def anomaly_map_to_uint8(anomaly_map):
    """Min-max rescale one anomaly map to 0..255 ``uint8`` for visualization.

    A constant map becomes all zeros (the ``1e-8`` guards the division).

    NOTE(review): scaling is relative to this image only, so any non-constant
    map is stretched to the full range and its highest-scoring pixels always
    pass the ``> 128`` threshold used for the filtered image, even for a
    normal image. Consider an absolute threshold if that matters.
    """
    anomaly_map = np.asarray(anomaly_map)
    return (255 * (anomaly_map - anomaly_map.min()) / (np.ptp(anomaly_map) + 1e-8)).astype(np.uint8)


def save_prediction(fname, anomaly_map):
    """Save the mask PNG and the masked ("filtered") original image for one input.

    Args:
        fname: Path of the original test image; its stem names both outputs.
        anomaly_map: 2-D anomaly map for that image.
    """
    stem = os.path.splitext(os.path.basename(fname))[0]

    mask_img = Image.fromarray(anomaly_map_to_uint8(anomaly_map))
    mask_img.save(os.path.join(OUT_MASK_DIR, stem + "_mask.png"))
    print(f"Saved mask for {fname}")

    # Close the source file promptly; convert() returns an independent copy.
    with Image.open(fname) as src:
        orig_img = src.convert("RGB")
    # The anomaly map may be at a different resolution than the original
    # (presumably the datamodule resizes inputs — confirm), so match sizes.
    if mask_img.size != orig_img.size:
        mask_img = mask_img.resize(orig_img.size, resample=Image.BILINEAR)
    bin_mask = np.array(mask_img) > 128
    orig_np = np.array(orig_img)
    filtered_np = np.zeros_like(orig_np)
    # Keep only pixels above the threshold; everything else stays black.
    filtered_np[bin_mask] = orig_np[bin_mask]
    Image.fromarray(filtered_np).save(os.path.join(OUT_FILTERED_DIR, stem + "_filtered.png"))
    print(f"Saved filtered image for {fname}")


def main():
    """Run PatchCore over the configured test split and write masks and filtered images."""
    os.makedirs(OUT_MASK_DIR, exist_ok=True)
    os.makedirs(OUT_FILTERED_DIR, exist_ok=True)

    # Datamodule for prediction (test split), built from the YAML's data.init_args.
    data_module = Folder(
        name=config.data.init_args.name,
        root=config.data.init_args.root,
        normal_dir=config.data.init_args.normal_dir,
        abnormal_dir=config.data.init_args.abnormal_dir,
        normal_test_dir=config.data.init_args.normal_test_dir,
        train_batch_size=config.data.init_args.train_batch_size,
        eval_batch_size=config.data.init_args.eval_batch_size,
        num_workers=config.data.init_args.num_workers,
    )
    data_module.setup()

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model = Patchcore.load_from_checkpoint(CKPT_PATH, map_location=device, **config.model.init_args)
    model.eval()
    model = model.to(device)

    for batch in data_module.test_dataloader():
        paths = list(batch.image_path)
        with torch.no_grad():
            output = model(batch.image.to(device))
        maps = extract_anomaly_maps(output)
        if maps is None:
            for fname in paths:
                print(f"No mask generated for {fname}")
            continue
        # Handle every image in the batch (not just the first), each map
        # normalized on its own rather than against the whole batch.
        for fname, anomaly_map in zip(paths, split_anomaly_maps(maps, len(paths))):
            save_prediction(fname, anomaly_map)
    print(f"All masks saved to {OUT_MASK_DIR}")


if __name__ == "__main__":
    main()