knownmax's picture
Upload folder using huggingface_hub
ed4e653 verified
Raw
History Blame Contribute Delete
9.44 kB
"""
PatchCore model: fit (feature extraction + coreset indexing) and predict (KNN scoring).
No training loop — PatchCore is purely:
fit : one forward pass over normal images → coreset memory bank
predict: KNN distance query for each test patch → anomaly map + image score
"""
import warnings
from pathlib import Path
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from scipy.ndimage import gaussian_filter
from tqdm import tqdm
from feature_extractor import PatchFeatureExtractor
from coreset import subsample_coreset
class PatchCore:
"""
PatchCore anomaly detector.
Args:
backbone : timm model name (default: 'wide_resnet101_2')
coreset_ratio : fraction of patches to keep in memory bank (default: 0.01)
device : torch device string (default: 'cuda')
faiss_gpu : use faiss GPU index (default: True, fallback to CPU)
gaussian_sigma: sigma for anomaly map smoothing (default: 4)
"""
PATCH_GRID = 28 # spatial grid size for 224px input (both axes)
FEAT_DIM = 1536 # layer2 (512) + layer3 (1024)
def __init__(
self,
backbone: str = "wide_resnet101_2",
coreset_ratio: float = 0.01,
device: str = "cuda",
faiss_gpu: bool = True,
gaussian_sigma: float = 4.0,
):
self.backbone = backbone
self.coreset_ratio = coreset_ratio
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.faiss_gpu = faiss_gpu
self.gaussian_sigma = gaussian_sigma
self.extractor = PatchFeatureExtractor(backbone=backbone, pretrained=True)
self.extractor = self.extractor.to(self.device)
self.memory_bank: torch.Tensor = None # [M, 1536]
self._faiss_index = None
self._index_backend = None
# ------------------------------------------------------------------
# Fit
# ------------------------------------------------------------------
def fit(self, train_loader) -> None:
"""
Extract patch features from all normal training images,
run coreset subsampling, build faiss index.
"""
print("[PatchCore] Extracting features from training images …")
all_features = []
self.extractor.eval()
with torch.no_grad():
for batch in tqdm(train_loader, desc=" Feature extraction", leave=False):
if isinstance(batch, (list, tuple)):
imgs = batch[0]
else:
imgs = batch
imgs = imgs.to(self.device)
feats = self.extractor.extract_patch_features(imgs) # [B*784, 1536]
all_features.append(feats.cpu())
all_features = torch.cat(all_features, dim=0) # [N_total, 1536]
print(f"[PatchCore] Total patches before coreset: {len(all_features):,}")
# Move to GPU for fast coreset computation
all_features = all_features.to(self.device)
print(f"[PatchCore] Running coreset subsampling (ratio={self.coreset_ratio}) …")
self.memory_bank = subsample_coreset(all_features, self.coreset_ratio)
print(f"[PatchCore] Memory bank size after coreset: {len(self.memory_bank):,} "
f"({self.memory_bank.element_size() * self.memory_bank.numel() / 1e6:.1f} MB)")
self._build_faiss_index()
def _build_faiss_index(self) -> None:
"""Build KNN backend from the memory bank (faiss preferred, torch fallback)."""
bank_np = self.memory_bank.cpu().numpy().astype(np.float32)
try:
import faiss
d = bank_np.shape[1]
index_flat = faiss.IndexFlatL2(d)
if self.faiss_gpu and torch.cuda.is_available():
try:
res = faiss.StandardGpuResources()
self._faiss_index = faiss.index_cpu_to_gpu(res, 0, index_flat)
self._index_backend = "faiss-gpu"
print("[PatchCore] Using faiss GPU index.")
except Exception as e:
warnings.warn(f"[PatchCore] faiss GPU index failed ({e}); falling back to faiss CPU.")
self._faiss_index = index_flat
self._index_backend = "faiss-cpu"
else:
self._faiss_index = index_flat
self._index_backend = "faiss-cpu"
self._faiss_index.add(bank_np)
except Exception as e:
warnings.warn(
"[PatchCore] faiss is unavailable/incompatible; "
f"falling back to torch KNN search. Details: {e}"
)
# Keep a contiguous float bank on the configured device for fallback KNN.
self.memory_bank = self.memory_bank.float().contiguous().to(self.device)
self._faiss_index = None
self._index_backend = "torch"
def _search_knn(self, feats_np: np.ndarray, k: int = 1) -> Tuple[np.ndarray, np.ndarray]:
"""Return (squared_l2_distances, indices) for nearest neighbours."""
if self._index_backend in ("faiss-gpu", "faiss-cpu"):
return self._faiss_index.search(feats_np, k=k)
# Torch fallback for environments where faiss import fails (e.g., NumPy ABI mismatch).
query = torch.from_numpy(feats_np).to(self.device, dtype=torch.float32)
bank = self.memory_bank
# Chunking keeps peak memory usage bounded for larger memory banks.
chunk_size = 1024
dist_chunks = []
idx_chunks = []
with torch.no_grad():
for start in range(0, query.shape[0], chunk_size):
q = query[start:start + chunk_size] # [Q, D]
d2 = torch.sum((q[:, None, :] - bank[None, :, :]) ** 2, dim=-1) # [Q, M]
vals, idxs = torch.topk(d2, k=k, dim=1, largest=False)
dist_chunks.append(vals)
idx_chunks.append(idxs)
distances = torch.cat(dist_chunks, dim=0).detach().cpu().numpy().astype(np.float32)
indices = torch.cat(idx_chunks, dim=0).detach().cpu().numpy().astype(np.int64)
return distances, indices
# ------------------------------------------------------------------
# Predict
# ------------------------------------------------------------------
@torch.no_grad()
def predict(self, image_tensor: torch.Tensor) -> Tuple[float, np.ndarray]:
"""
Compute anomaly score and pixel-level anomaly map for a single image.
Args:
image_tensor: [1, 3, 224, 224] normalised image tensor
Returns:
image_score : float — max patch distance (image-level anomaly score)
anomaly_map : np.ndarray [224, 224] — smoothed, upsampled patch distance map
"""
image_tensor = image_tensor.to(self.device)
# Extract patch features: [1*784, 1536]
feats = self.extractor.extract_patch_features(image_tensor)
feats_np = feats.cpu().numpy().astype(np.float32)
# KNN distance query (k=1 nearest neighbour)
distances, _ = self._search_knn(feats_np, k=1) # [N_patches, 1]
patch_scores = distances[:, 0] # [N_patches] squared L2 distances
# Compute actual grid size from number of patches (robust to different feature extractors)
num_patches = len(patch_scores)
patch_grid = int(np.sqrt(num_patches))
assert patch_grid * patch_grid == num_patches, \
f"Non-square patch grid: {num_patches} patches (expected {patch_grid}²)"
# Reshape to spatial grid
score_map = patch_scores.reshape(patch_grid, patch_grid)
# Upsample to 224×224 via bilinear interpolation
score_tensor = torch.from_numpy(score_map).unsqueeze(0).unsqueeze(0) # [1,1,G,G]
score_upsampled = F.interpolate(
score_tensor, size=(224, 224), mode="bilinear", align_corners=False
).squeeze().numpy() # [224, 224]
# Gaussian smoothing
anomaly_map = gaussian_filter(score_upsampled, sigma=self.gaussian_sigma)
# Image-level score: max patch distance
image_score = float(patch_scores.max())
return image_score, anomaly_map
# ------------------------------------------------------------------
# Save / Load
# ------------------------------------------------------------------
def save(self, path: str) -> None:
"""Serialize memory bank and config to a .pt file."""
torch.save(
{
"memory_bank": self.memory_bank.cpu(),
"backbone": self.backbone,
"coreset_ratio": self.coreset_ratio,
"gaussian_sigma": self.gaussian_sigma,
},
path,
)
print(f"[PatchCore] Model saved to {path}")
def load(self, path: str) -> None:
"""Load memory bank and config from a .pt file, rebuild faiss index."""
ckpt = torch.load(path, map_location="cpu")
self.memory_bank = ckpt["memory_bank"].to(self.device)
self.backbone = ckpt["backbone"]
self.coreset_ratio = ckpt["coreset_ratio"]
self.gaussian_sigma = ckpt["gaussian_sigma"]
self._build_faiss_index()
print(f"[PatchCore] Model loaded from {path} "
f"(memory bank: {len(self.memory_bank):,} vectors)")