Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from omegaconf import OmegaConf | |
| from anomalib.models import Patchcore | |
| from anomalib.data import Folder | |
| from pytorch_lightning import Trainer | |
| # --- Load config --- | |
| CONFIG_PATH = "configs/patchcore_transformers.yaml" | |
| CKPT_PATH = "results/Patchcore/transformers/v7/weights/lightning/model.ckpt" | |
| OUT_MASK_DIR = "api_inference_pred_pipeline" | |
| OUT_FILTERED_DIR = "api_inference_filtered_pipeline" | |
| os.makedirs(OUT_MASK_DIR, exist_ok=True) | |
| os.makedirs(OUT_FILTERED_DIR, exist_ok=True) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # Load config | |
| config = OmegaConf.load(CONFIG_PATH) | |
| if __name__ == "__main__": | |
| os.makedirs(OUT_MASK_DIR, exist_ok=True) | |
| os.makedirs(OUT_FILTERED_DIR, exist_ok=True) | |
| # Setup datamodule for prediction (use test set) | |
| # Use arguments matching the YAML config and Folder datamodule signature | |
| data_module = Folder( | |
| name=config.data.init_args.name, | |
| root=config.data.init_args.root, | |
| normal_dir=config.data.init_args.normal_dir, | |
| abnormal_dir=config.data.init_args.abnormal_dir, | |
| normal_test_dir=config.data.init_args.normal_test_dir, | |
| train_batch_size=config.data.init_args.train_batch_size, | |
| eval_batch_size=config.data.init_args.eval_batch_size, | |
| num_workers=config.data.init_args.num_workers, | |
| ) | |
| data_module.setup() | |
| # Load model | |
| model = Patchcore.load_from_checkpoint(CKPT_PATH, **config.model.init_args) | |
| model.eval() | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model = model.to(device) | |
| # Inference loop | |
| for batch in data_module.test_dataloader(): | |
| img = batch.image.to(device) | |
| fname = batch.image_path[0] | |
| with torch.no_grad(): | |
| output = model(img) | |
| # PatchCore returns (anomaly_score, anomaly_map, ...) | |
| if hasattr(output, 'anomaly_map'): | |
| anomaly_map = output.anomaly_map.squeeze().cpu().numpy() | |
| elif isinstance(output, (tuple, list)) and len(output) > 1: | |
| anomaly_map = output[1].squeeze().cpu().numpy() | |
| else: | |
| anomaly_map = None | |
| if anomaly_map is not None: | |
| # Normalize to 0-255 for visualization | |
| norm_map = (255 * (anomaly_map - anomaly_map.min()) / (np.ptp(anomaly_map) + 1e-8)).astype(np.uint8) | |
| # Ensure norm_map is 2D for PIL | |
| if norm_map.ndim > 2: | |
| norm_map = np.squeeze(norm_map) | |
| if norm_map.ndim > 2: | |
| norm_map = norm_map[0] | |
| mask_img = Image.fromarray(norm_map) | |
| out_name = os.path.splitext(os.path.basename(fname))[0] + "_mask.png" | |
| mask_img.save(os.path.join(OUT_MASK_DIR, out_name)) | |
| print(f"Saved mask for {fname}") | |
| # Save filtered (masked) part of the original transformer image | |
| orig_img = Image.open(fname).convert("RGB") | |
| # Resize mask to match original image size if needed | |
| if mask_img.size != orig_img.size: | |
| mask_img_resized = mask_img.resize(orig_img.size, resample=Image.BILINEAR) | |
| else: | |
| mask_img_resized = mask_img | |
| # Binarize mask (threshold at 128) | |
| bin_mask = np.array(mask_img_resized) > 128 | |
| # Apply mask to original image | |
| orig_np = np.array(orig_img) | |
| filtered_np = np.zeros_like(orig_np) | |
| filtered_np[bin_mask] = orig_np[bin_mask] | |
| filtered_img = Image.fromarray(filtered_np) | |
| filtered_name = os.path.splitext(os.path.basename(fname))[0] + "_filtered.png" | |
| filtered_img.save(os.path.join(OUT_FILTERED_DIR, filtered_name)) | |
| print(f"Saved filtered image for {fname}") | |
| else: | |
| print(f"No mask generated for {fname}") | |
| print(f"All masks saved to {OUT_MASK_DIR}") | |