|
|
""" |
|
|
Deep SVDD Anomaly Detection Model |
|
|
Trained on CIFAR-10, CIFAR-100, and STL-10 |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import pickle |
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm layers plus a skip connection.

    The skip path is an empty ``nn.Sequential`` (identity) when the input and
    output shapes already agree. When ``stride != 1`` or the channel counts
    differ, it is a 1x1 convolution followed by batch norm so that the skip
    tensor matches the main branch.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        stride: Stride of the first convolution (and of the projection).
    """

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()

        # Main branch. Attribute names and creation order are kept stable so
        # that checkpoint keys and seeded initialisation do not change.
        self.conv1 = nn.Conv2d(
            in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(
            out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_ch)

        # Skip branch: project only when the shapes would otherwise mismatch.
        shapes_match = stride == 1 and in_ch == out_ch
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)

        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        # In-place accumulation of the skip connection onto the main branch.
        hidden.add_(self.shortcut(x))
        return F.relu(hidden)
|
|
|
|
|
|
|
|
class DeepSVDDEncoder(nn.Module):
    """ResNet-style convolutional encoder mapping RGB images to latent vectors.

    Architecture: a 7x7 stride-2 stem convolution with batch norm, four
    stages of two residual blocks each (every stage uses stride 2), and a
    bias-free linear projection to ``latent_dim``.

    The projection is sized for a 512 x 4 x 4 feature map, which is what a
    128 x 128 input produces after the five stride-2 reductions. Feature maps
    of any other spatial size are adaptively average-pooled to 4 x 4 before
    the projection; 128 x 128 inputs skip that step, so their outputs are
    unchanged.

    Args:
        latent_dim: Dimensionality of the output embedding.
    """

    # Spatial side length of the final feature map that ``fc`` is sized for.
    _FEATURE_MAP_SIZE = 4

    def __init__(self, latent_dim: int = 512):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 128, stride=2)
        self.layer2 = self._make_layer(128, 256, stride=2)
        self.layer3 = self._make_layer(256, 512, stride=2)
        self.layer4 = self._make_layer(512, 512, stride=2)
        self.fc = nn.Linear(
            512 * self._FEATURE_MAP_SIZE * self._FEATURE_MAP_SIZE,
            latent_dim,
            bias=False,
        )

    def _make_layer(self, in_ch: int, out_ch: int, stride: int = 1):
        """Build one stage: a (possibly strided) block followed by a stride-1 block."""
        return nn.Sequential(
            ResidualBlock(in_ch, out_ch, stride),
            ResidualBlock(out_ch, out_ch, 1),
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch of shape ``(N, 3, H, W)`` (the stem convolution
                takes 3 input channels).

        Returns:
            Tensor of shape ``(N, latent_dim)``.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Bring non-standard input sizes to the 4 x 4 grid ``fc`` expects.
        # Skipped for the standard case so that path stays exactly as before.
        target = (self._FEATURE_MAP_SIZE, self._FEATURE_MAP_SIZE)
        if tuple(x.shape[-2:]) != target:
            x = F.adaptive_avg_pool2d(x, target)

        # ``flatten`` copies when needed, so non-contiguous feature maps
        # (e.g. channels_last) work; ``view`` would raise on those.
        x = torch.flatten(x, 1)
        return self.fc(x)
|
|
|
|
|
|
|
|
class DeepSVDDAnomalyDetector:
    """
    Deep SVDD Anomaly Detection Model

    Bundles a trained ``DeepSVDDEncoder`` with the center vector and decision
    thresholds stored alongside it. The anomaly score of an image is the
    squared Euclidean distance between its embedding and the center; the
    image is reported as an anomaly when the score exceeds the active
    threshold.

    Usage:
        from model import DeepSVDDAnomalyDetector

        detector = DeepSVDDAnomalyDetector.from_pretrained('.')
        score, is_anomaly = detector.predict('image.jpg')
    """

    def __init__(self, model_path, thresholds_path, config_path, device='cuda'):
        """Load config, encoder weights, center and thresholds from disk.

        Args:
            model_path: Checkpoint readable by ``torch.load`` holding the keys
                ``latent_dim``, ``center``, ``radius`` and
                ``encoder_state_dict``.
            thresholds_path: Pickle file holding a mapping with the keys
                ``95th_percentile``, ``99th_percentile`` and ``optimal_f1``.
            config_path: JSON file; its parsed content is kept in
                ``self.config``.
            device: Requested torch device. It is honored only when CUDA is
                available; otherwise the detector runs on CPU.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        # SECURITY: depending on the installed PyTorch version, ``torch.load``
        # may unpickle arbitrary objects. Only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.latent_dim = checkpoint['latent_dim']
        self.center = checkpoint['center'].to(self.device)
        # Accept the radius as either a tensor or a plain number.
        radius = checkpoint['radius']
        self.radius = radius.item() if torch.is_tensor(radius) else float(radius)

        self.encoder = DeepSVDDEncoder(self.latent_dim).to(self.device)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.eval()

        # SECURITY: ``pickle.load`` executes code embedded in the file. Only
        # load threshold files that come from a trusted source.
        with open(thresholds_path, 'rb') as f:
            thresholds = pickle.load(f)

        self.threshold_95 = thresholds['95th_percentile']
        self.threshold_99 = thresholds['99th_percentile']
        self.threshold_optimal = thresholds['optimal_f1']
        self.threshold = self.threshold_optimal

        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @classmethod
    def from_pretrained(cls, model_path='.', device='cuda'):
        """Load a pretrained detector from a local directory.

        The directory must contain ``deepsvdd_model.pth``, ``thresholds.pkl``
        and ``config.json``. Only local paths are resolved here; nothing is
        downloaded.

        Args:
            model_path: Directory holding the three files.
            device: Forwarded to the constructor.

        Returns:
            A ready-to-use instance of ``cls``.
        """
        model_path = Path(model_path)
        return cls(
            model_path=model_path / 'deepsvdd_model.pth',
            thresholds_path=model_path / 'thresholds.pkl',
            config_path=model_path / 'config.json',
            device=device,
        )

    def set_threshold(self, threshold_type='optimal'):
        """Select the active decision threshold.

        Args:
            threshold_type: One of ``'optimal'``, ``'95th'`` or ``'99th'``.

        Raises:
            ValueError: If ``threshold_type`` is not one of the names above.
                The active threshold is left unchanged in that case.
        """
        if threshold_type == 'optimal':
            self.threshold = self.threshold_optimal
        elif threshold_type == '95th':
            self.threshold = self.threshold_95
        elif threshold_type == '99th':
            self.threshold = self.threshold_99
        else:
            # Previously a typo here was silently ignored, leaving the old
            # threshold active without any indication.
            raise ValueError(
                f"Unknown threshold_type {threshold_type!r}; "
                "expected 'optimal', '95th' or '99th'"
            )

    @torch.no_grad()
    def predict(self, image_path):
        """Score one image and decide whether it is an anomaly.

        Args:
            image_path: Path (``str`` or ``pathlib.Path``) to an image file,
                or an already-loaded image object accepted by the transform
                pipeline (e.g. a PIL image).

        Returns:
            A ``(score, is_anomaly)`` tuple: ``score`` is the squared
            Euclidean distance of the embedding from the center as a float,
            and ``is_anomaly`` is ``True`` when ``score`` exceeds the active
            threshold.
        """
        if isinstance(image_path, (str, Path)):
            # The context manager releases the file handle deterministically;
            # ``convert`` returns an independent, fully loaded image.
            with Image.open(image_path) as opened:
                image = opened.convert('RGB')
        else:
            image = image_path
            # The normalisation and the encoder both assume three channels,
            # so grayscale / RGBA / palette images must be converted too.
            if isinstance(image, Image.Image) and image.mode != 'RGB':
                image = image.convert('RGB')

        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        embeddings = self.encoder(image_tensor)
        score = torch.sum((embeddings - self.center) ** 2, dim=1).item()
        # ``bool`` keeps the flag a plain Python bool even if the stored
        # threshold is a NumPy scalar.
        is_anomaly = bool(score > self.threshold)

        return score, is_anomaly
|
|
|