import os
import torch
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from anomalib.models import Patchcore
from anomalib.data import Folder
from pytorch_lightning import Trainer
# --- Paths and runtime configuration ---
CONFIG_PATH = "configs/patchcore_transformers.yaml"
CKPT_PATH = "results/Patchcore/transformers/v7/weights/lightning/model.ckpt"
OUT_MASK_DIR = "api_inference_pred_pipeline"
OUT_FILTERED_DIR = "api_inference_filtered_pipeline"

# Make sure both output folders exist before any inference runs.
for _out_dir in (OUT_MASK_DIR, OUT_FILTERED_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# Prefer the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parse the YAML experiment configuration.
config = OmegaConf.load(CONFIG_PATH)
if __name__ == "__main__":
    # Output folders (also created at module import; repeated here so the
    # guarded body is self-sufficient). exist_ok makes this idempotent.
    os.makedirs(OUT_MASK_DIR, exist_ok=True)
    os.makedirs(OUT_FILTERED_DIR, exist_ok=True)

    # --- Datamodule for prediction (test split) ---
    # Arguments mirror the YAML `data.init_args` section and the Folder
    # datamodule signature.
    data_module = Folder(
        name=config.data.init_args.name,
        root=config.data.init_args.root,
        normal_dir=config.data.init_args.normal_dir,
        abnormal_dir=config.data.init_args.abnormal_dir,
        normal_test_dir=config.data.init_args.normal_test_dir,
        train_batch_size=config.data.init_args.train_batch_size,
        eval_batch_size=config.data.init_args.eval_batch_size,
        num_workers=config.data.init_args.num_workers,
    )
    data_module.setup()

    # --- Model: restore checkpoint weights, switch to eval mode, move to
    # the device selected at module level (the original re-created `device`
    # here; that duplicate is removed). ---
    model = Patchcore.load_from_checkpoint(CKPT_PATH, **config.model.init_args)
    model.eval()
    model = model.to(device)

    # --- Inference loop over the test dataloader ---
    for batch in data_module.test_dataloader():
        img = batch.image.to(device)
        # NOTE(review): only the first path in the batch is used for naming
        # outputs — this assumes eval_batch_size == 1; confirm in the YAML.
        fname = batch.image_path[0]
        with torch.no_grad():
            output = model(img)

        # Anomalib versions differ in return type: newer ones expose an
        # `anomaly_map` attribute, older ones return (score, map, ...) tuples.
        if hasattr(output, "anomaly_map"):
            anomaly_map = output.anomaly_map.squeeze().cpu().numpy()
        elif isinstance(output, (tuple, list)) and len(output) > 1:
            anomaly_map = output[1].squeeze().cpu().numpy()
        else:
            anomaly_map = None

        # Guard clause instead of the original deep if/else nesting.
        if anomaly_map is None:
            print(f"No mask generated for {fname}")
            continue

        # Min-max normalize to 0-255 for visualization; the epsilon guards
        # against a constant map (ptp == 0) dividing by zero.
        norm_map = (
            255 * (anomaly_map - anomaly_map.min()) / (np.ptp(anomaly_map) + 1e-8)
        ).astype(np.uint8)
        # PIL needs a 2D array for a grayscale image: squeeze singleton
        # dims, then take the first plane if a real extra dim remains.
        if norm_map.ndim > 2:
            norm_map = np.squeeze(norm_map)
        if norm_map.ndim > 2:
            norm_map = norm_map[0]
        mask_img = Image.fromarray(norm_map)
        out_name = os.path.splitext(os.path.basename(fname))[0] + "_mask.png"
        mask_img.save(os.path.join(OUT_MASK_DIR, out_name))
        print(f"Saved mask for {fname}")

        # Save the filtered (masked) part of the original transformer image.
        orig_img = Image.open(fname).convert("RGB")
        # Resize the mask to the original image size if they differ.
        if mask_img.size != orig_img.size:
            mask_img_resized = mask_img.resize(orig_img.size, resample=Image.BILINEAR)
        else:
            mask_img_resized = mask_img
        # Binarize at 128, then keep only the anomalous pixels; the 2D
        # boolean mask broadcasts over the RGB channel axis.
        bin_mask = np.array(mask_img_resized) > 128
        orig_np = np.array(orig_img)
        filtered_np = np.zeros_like(orig_np)
        filtered_np[bin_mask] = orig_np[bin_mask]
        filtered_img = Image.fromarray(filtered_np)
        filtered_name = os.path.splitext(os.path.basename(fname))[0] + "_filtered.png"
        filtered_img.save(os.path.join(OUT_FILTERED_DIR, filtered_name))
        print(f"Saved filtered image for {fname}")

    print(f"All masks saved to {OUT_MASK_DIR}")