File size: 9,437 Bytes
"""
PatchCore model: fit (feature extraction + coreset indexing) and predict (KNN scoring).

There is no training loop; PatchCore consists of just two steps:
    fit    : one forward pass over normal images → coreset memory bank
    predict: KNN distance query for each test patch → anomaly map + image score
"""
import warnings
from pathlib import Path
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from scipy.ndimage import gaussian_filter
from tqdm import tqdm
from feature_extractor import PatchFeatureExtractor
from coreset import subsample_coreset
class PatchCore:
    """
    PatchCore anomaly detector.

    Args:
        backbone      : timm model name (default: 'wide_resnet101_2')
        coreset_ratio : fraction of patches to keep in memory bank (default: 0.01)
        device        : torch device string (default: 'cuda'; CPU is used instead
                        whenever CUDA is unavailable)
        faiss_gpu     : use faiss GPU index (default: True, fallback to CPU)
        gaussian_sigma: sigma for anomaly map smoothing (default: 4)
    """

    PATCH_GRID = 28  # spatial grid size for 224px input (both axes)
    FEAT_DIM = 1536  # layer2 (512) + layer3 (1024)

    def __init__(
        self,
        backbone: str = "wide_resnet101_2",
        coreset_ratio: float = 0.01,
        device: str = "cuda",
        faiss_gpu: bool = True,
        gaussian_sigma: float = 4.0,
    ):
        self.backbone = backbone
        self.coreset_ratio = coreset_ratio
        # Whatever device was requested, fall back to CPU when CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.faiss_gpu = faiss_gpu
        self.gaussian_sigma = gaussian_sigma
        self.extractor = PatchFeatureExtractor(backbone=backbone, pretrained=True)
        self.extractor = self.extractor.to(self.device)
        # The extractor is only ever used for inference. Switch it to eval mode
        # up front so a load() -> predict() workflow (which never calls fit())
        # does not run mode-dependent layers (e.g. BatchNorm, if the backbone
        # has any) in training mode.
        self.extractor.eval()
        self.memory_bank = None  # torch.Tensor [M, 1536] once fit()/load() has run
        self._faiss_index = None
        self._index_backend = None  # "faiss-gpu" | "faiss-cpu" | "torch" once indexed

    def _require_fitted(self) -> None:
        """Raise RuntimeError unless a memory bank exists (set by fit() or load())."""
        if self.memory_bank is None:
            raise RuntimeError(
                "[PatchCore] Model has no memory bank; call fit() or load() first."
            )

    # ------------------------------------------------------------------
    # Fit
    # ------------------------------------------------------------------
    def fit(self, train_loader) -> None:
        """
        Build the memory bank from normal training images.

        Extracts patch features from every batch, runs coreset subsampling,
        and builds the KNN index.

        Args:
            train_loader: iterable yielding image batches, or lists/tuples whose
                          first element is the image batch.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        print("[PatchCore] Extracting features from training images …")
        all_features = []
        self.extractor.eval()
        with torch.no_grad():
            for batch in tqdm(train_loader, desc=" Feature extraction", leave=False):
                # Accept both bare tensors and (images, labels, ...) style batches.
                imgs = batch[0] if isinstance(batch, (list, tuple)) else batch
                imgs = imgs.to(self.device)
                feats = self.extractor.extract_patch_features(imgs)  # [B*784, 1536]
                # Accumulate on CPU so GPU memory holds only one batch at a time.
                all_features.append(feats.cpu())
        if not all_features:
            raise ValueError(
                "[PatchCore] train_loader yielded no batches; cannot build a memory bank."
            )
        all_features = torch.cat(all_features, dim=0)  # [N_total, 1536]
        print(f"[PatchCore] Total patches before coreset: {len(all_features):,}")
        # Move to GPU for fast coreset computation
        all_features = all_features.to(self.device)
        print(f"[PatchCore] Running coreset subsampling (ratio={self.coreset_ratio}) …")
        self.memory_bank = subsample_coreset(all_features, self.coreset_ratio)
        print(f"[PatchCore] Memory bank size after coreset: {len(self.memory_bank):,} "
              f"({self.memory_bank.element_size() * self.memory_bank.numel() / 1e6:.1f} MB)")
        self._build_faiss_index()

    def _build_faiss_index(self) -> None:
        """
        Build the KNN backend from the memory bank (faiss preferred, torch fallback).

        Sets ``self._index_backend`` to "faiss-gpu", "faiss-cpu" or "torch".
        Any failure importing or populating faiss is reported as a warning,
        and the torch brute-force search is used instead.
        """
        # faiss requires a C-contiguous float32 array.
        bank_np = np.ascontiguousarray(
            self.memory_bank.detach().cpu().numpy(), dtype=np.float32
        )
        try:
            import faiss
            d = bank_np.shape[1]
            index_flat = faiss.IndexFlatL2(d)
            if self.faiss_gpu and torch.cuda.is_available():
                try:
                    res = faiss.StandardGpuResources()
                    self._faiss_index = faiss.index_cpu_to_gpu(res, 0, index_flat)
                    self._index_backend = "faiss-gpu"
                    print("[PatchCore] Using faiss GPU index.")
                except Exception as e:
                    warnings.warn(f"[PatchCore] faiss GPU index failed ({e}); falling back to faiss CPU.")
                    self._faiss_index = index_flat
                    self._index_backend = "faiss-cpu"
            else:
                self._faiss_index = index_flat
                self._index_backend = "faiss-cpu"
            self._faiss_index.add(bank_np)
        except Exception as e:
            # Deliberately broad: faiss can fail at import time in ways other
            # than ImportError (e.g. a NumPy ABI mismatch).
            warnings.warn(
                "[PatchCore] faiss is unavailable/incompatible; "
                f"falling back to torch KNN search. Details: {e}"
            )
            # Keep a contiguous float bank on the configured device for fallback KNN.
            self.memory_bank = self.memory_bank.float().contiguous().to(self.device)
            self._faiss_index = None
            self._index_backend = "torch"

    def _search_knn(self, feats_np: np.ndarray, k: int = 1) -> Tuple[np.ndarray, np.ndarray]:
        """
        Find the ``k`` nearest memory-bank vectors for every query row.

        Args:
            feats_np: [N, D] query features.
            k       : number of neighbours per query.

        Returns:
            (distances, indices): [N, k] squared L2 distances, nearest first, and
            the matching [N, k] memory-bank row indices.
        """
        if self._index_backend in ("faiss-gpu", "faiss-cpu"):
            return self._faiss_index.search(
                np.ascontiguousarray(feats_np, dtype=np.float32), k=k
            )
        # Torch fallback for environments where faiss import fails (e.g., NumPy ABI mismatch).
        query = torch.from_numpy(np.ascontiguousarray(feats_np, dtype=np.float32)).to(self.device)
        bank = self.memory_bank.to(device=self.device, dtype=torch.float32)
        # Chunking over queries bounds peak memory at one [chunk, M] distance
        # matrix. The expansion ||q - b||^2 = ||q||^2 - 2 q.b + ||b||^2 (the
        # same formulation faiss IndexFlatL2 uses) avoids materialising the
        # [chunk, M, D] difference tensor that broadcasting would create.
        chunk_size = 1024
        dist_chunks = []
        idx_chunks = []
        with torch.no_grad():
            bank_sq = bank.pow(2).sum(dim=1)  # [M]
            for start in range(0, query.shape[0], chunk_size):
                q = query[start:start + chunk_size]  # [Q, D]
                d2 = q.pow(2).sum(dim=1, keepdim=True) - 2.0 * (q @ bank.T) + bank_sq[None, :]
                d2.clamp_(min=0.0)  # float cancellation can yield tiny negatives
                vals, idxs = torch.topk(d2, k=k, dim=1, largest=False)
                dist_chunks.append(vals)
                idx_chunks.append(idxs)
        distances = torch.cat(dist_chunks, dim=0).detach().cpu().numpy().astype(np.float32)
        indices = torch.cat(idx_chunks, dim=0).detach().cpu().numpy().astype(np.int64)
        return distances, indices

    # ------------------------------------------------------------------
    # Predict
    # ------------------------------------------------------------------
    @torch.no_grad()
    def predict(self, image_tensor: torch.Tensor) -> Tuple[float, np.ndarray]:
        """
        Compute anomaly score and pixel-level anomaly map for a single image.

        Args:
            image_tensor: [1, 3, 224, 224] normalised image tensor

        Returns:
            image_score : float — max squared-L2 nearest-neighbour patch distance
                          (image-level anomaly score)
            anomaly_map : np.ndarray [224, 224] — upsampled, Gaussian-smoothed
                          map of the squared patch distances

        Raises:
            RuntimeError: if called before fit() or load().
            ValueError  : if the extractor does not return a square patch grid.
        """
        self._require_fitted()
        # Never rely on fit() having been called to set eval mode (load() path).
        self.extractor.eval()
        image_tensor = image_tensor.to(self.device)
        # Extract patch features: [1*784, 1536]
        feats = self.extractor.extract_patch_features(image_tensor)
        feats_np = feats.cpu().numpy().astype(np.float32)
        # KNN distance query (k=1 nearest neighbour)
        distances, _ = self._search_knn(feats_np, k=1)  # [N_patches, 1]
        patch_scores = distances[:, 0]  # [N_patches] squared L2 distances
        # Derive the grid size from the patch count (robust to different
        # feature extractors); round() guards against float sqrt truncation.
        num_patches = len(patch_scores)
        patch_grid = int(round(np.sqrt(num_patches)))
        if patch_grid * patch_grid != num_patches:
            raise ValueError(
                f"Non-square patch grid: {num_patches} patches (expected {patch_grid}²)"
            )
        # Reshape to spatial grid
        score_map = patch_scores.reshape(patch_grid, patch_grid)
        # Upsample to 224×224 via bilinear interpolation
        score_tensor = torch.from_numpy(score_map).unsqueeze(0).unsqueeze(0)  # [1,1,G,G]
        score_upsampled = F.interpolate(
            score_tensor, size=(224, 224), mode="bilinear", align_corners=False
        ).squeeze().numpy()  # [224, 224]
        # Gaussian smoothing
        anomaly_map = gaussian_filter(score_upsampled, sigma=self.gaussian_sigma)
        # Image-level score: max patch distance
        image_score = float(patch_scores.max())
        return image_score, anomaly_map

    # ------------------------------------------------------------------
    # Save / Load
    # ------------------------------------------------------------------
    def save(self, path: str) -> None:
        """
        Serialize the memory bank and config to a .pt file.

        Raises:
            RuntimeError: if there is no memory bank yet (before fit()/load()).
        """
        self._require_fitted()
        torch.save(
            {
                "memory_bank": self.memory_bank.cpu(),
                "backbone": self.backbone,
                "coreset_ratio": self.coreset_ratio,
                "gaussian_sigma": self.gaussian_sigma,
            },
            path,
        )
        print(f"[PatchCore] Model saved to {path}")

    def load(self, path: str) -> None:
        """
        Load the memory bank and config from a .pt file, then rebuild the KNN index.

        The feature extractor is NOT rebuilt. If the checkpoint was created
        with a different backbone than this instance's extractor, a warning
        is emitted because query features would not match the memory bank.
        """
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        ckpt = torch.load(path, map_location="cpu")
        ckpt_backbone = ckpt["backbone"]
        if ckpt_backbone != self.backbone:
            warnings.warn(
                f"[PatchCore] Checkpoint backbone '{ckpt_backbone}' differs from the "
                f"feature extractor's backbone '{self.backbone}'; query features may "
                "not match the memory bank. Construct PatchCore(backbone=...) to match."
            )
        self.memory_bank = ckpt["memory_bank"].to(self.device)
        self.backbone = ckpt_backbone
        self.coreset_ratio = ckpt["coreset_ratio"]
        self.gaussian_sigma = ckpt["gaussian_sigma"]
        self.extractor.eval()
        self._build_faiss_index()
        print(f"[PatchCore] Model loaded from {path} "
              f"(memory bank: {len(self.memory_bank):,} vectors)")
|