# File size: 5,005 Bytes
# 0e9dcc1
"""
Deep SVDD Anomaly Detection Model
Trained on CIFAR-10, CIFAR-100, and STL-10
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import pickle
import json
from pathlib import Path
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm stages plus a skip connection.

    Only the first convolution applies ``stride``. The skip path is an empty
    ``nn.Sequential`` (i.e. it returns its input unchanged) unless the block
    changes the stride or the channel count; in that case it becomes a 1x1
    convolution followed by batch norm so both branches have matching shapes
    when they are summed.
    """

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()
        changes_shape = not (stride == 1 and in_ch == out_ch)

        # Main branch.
        self.conv1 = nn.Conv2d(
            in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(
            out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_ch)

        # Skip branch: projection layers only when the shapes would differ.
        projection = []
        if changes_shape:
            projection = [
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))
class DeepSVDDEncoder(nn.Module):
    """Convolutional encoder mapping RGB images to ``latent_dim`` embeddings.

    Architecture: a 7x7 stride-2 stem convolution with batch norm, four
    stages of two ``ResidualBlock``s each (the first block of every stage
    uses stride 2), and a bias-free linear projection.

    A 128x128 input reaches the linear head as a 512 x 4 x 4 feature map,
    which is the size the head's weight matrix is built for. Inputs of other
    resolutions are adaptively average-pooled to that same 4 x 4 grid instead
    of failing with a shape mismatch. For 128x128 inputs the pooling is an
    exact identity, so existing checkpoints and their outputs are unaffected.

    Args:
        latent_dim: Size of the output embedding.
    """

    # Side length of the feature grid the linear head consumes.
    _FEATURE_GRID = 4

    def __init__(self, latent_dim: int = 512):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 128, stride=2)
        self.layer2 = self._make_layer(128, 256, stride=2)
        self.layer3 = self._make_layer(256, 512, stride=2)
        self.layer4 = self._make_layer(512, 512, stride=2)
        self.fc = nn.Linear(
            512 * self._FEATURE_GRID * self._FEATURE_GRID, latent_dim, bias=False
        )

    def _make_layer(self, in_ch: int, out_ch: int, stride: int = 1):
        """Build one stage: a (possibly strided) block followed by a stride-1 block."""
        return nn.Sequential(
            ResidualBlock(in_ch, out_ch, stride),
            ResidualBlock(out_ch, out_ch, 1),
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor with 3 channels in dimension 1 (the stem convolution
                expects 3 input channels).

        Returns:
            Tensor of shape ``(x.size(0), latent_dim)``.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # No-op when the map is already 4x4 (128x128 input); otherwise brings
        # the map to the grid size the linear head was sized for.
        x = F.adaptive_avg_pool2d(x, self._FEATURE_GRID)
        # flatten() also handles non-contiguous tensors, which view() rejects.
        return self.fc(torch.flatten(x, 1))
class DeepSVDDAnomalyDetector:
    """
    Deep SVDD Anomaly Detection Model

    An image's anomaly score is the squared Euclidean distance between its
    encoder embedding and the ``center`` vector stored in the checkpoint.
    The image is reported as an anomaly when the score exceeds the active
    threshold (``optimal_f1`` by default, see ``set_threshold``).

    Usage:
        from model import DeepSVDDAnomalyDetector
        detector = DeepSVDDAnomalyDetector.from_pretrained('.')
        score, is_anomaly = detector.predict('image.jpg')
    """

    # Public threshold names accepted by set_threshold -> attribute holding the value.
    _THRESHOLD_ATTRS = {
        'optimal': 'threshold_optimal',
        '95th': 'threshold_95',
        '99th': 'threshold_99',
    }

    def __init__(self, model_path, thresholds_path, config_path, device='cuda'):
        """Load the encoder, its center/radius, the thresholds and the config.

        Args:
            model_path: Checkpoint readable by ``torch.load`` containing the
                keys ``latent_dim``, ``center``, ``radius`` and
                ``encoder_state_dict``.
            thresholds_path: Pickle file holding a mapping with the keys
                ``95th_percentile``, ``99th_percentile`` and ``optimal_f1``.
            config_path: JSON file; its content is stored on ``self.config``.
            device: Requested device. A CUDA device is replaced by the CPU
                when CUDA is not available; other devices are used as given.
        """
        self.device = self._resolve_device(device)

        # Load config
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        # Load model
        # NOTE(security): depending on the installed torch version and its
        # `weights_only` default, torch.load may unpickle arbitrary objects.
        # Only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.latent_dim = checkpoint['latent_dim']
        self.center = checkpoint['center'].to(self.device)
        radius = checkpoint['radius']
        # Accept both a one-element tensor and a plain number.
        self.radius = radius.item() if torch.is_tensor(radius) else radius

        self.encoder = DeepSVDDEncoder(self.latent_dim).to(self.device)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.eval()

        # Load thresholds
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only use threshold files from a trusted source.
        with open(thresholds_path, 'rb') as f:
            thresholds = pickle.load(f)
        self.threshold_95 = thresholds['95th_percentile']
        self.threshold_99 = thresholds['99th_percentile']
        self.threshold_optimal = thresholds['optimal_f1']
        self.threshold = self.threshold_optimal

        # Image preprocessing: fixed 128x128 resize, tensor conversion and
        # per-channel normalization.
        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _resolve_device(device):
        """Return the device to run on, falling back to CPU only for missing CUDA.

        Previously every request was downgraded to CPU whenever CUDA was
        unavailable, which silently ignored non-CUDA devices.
        """
        requested = torch.device(device)
        if requested.type == 'cuda' and not torch.cuda.is_available():
            return torch.device('cpu')
        return requested

    @classmethod
    def from_pretrained(cls, model_path='.', device='cuda'):
        """Load a pretrained detector from a local directory.

        Args:
            model_path: Directory containing ``deepsvdd_model.pth``,
                ``thresholds.pkl`` and ``config.json``. Nothing is
                downloaded; the files must already exist locally.
            device: Forwarded to the constructor.

        Returns:
            A ready-to-use ``DeepSVDDAnomalyDetector``.
        """
        model_path = Path(model_path)
        return cls(
            model_path=model_path / 'deepsvdd_model.pth',
            thresholds_path=model_path / 'thresholds.pkl',
            config_path=model_path / 'config.json',
            device=device,
        )

    def set_threshold(self, threshold_type='optimal'):
        """Select the active decision threshold.

        Args:
            threshold_type: ``'optimal'``, ``'95th'`` or ``'99th'``.

        Raises:
            ValueError: If ``threshold_type`` is not one of the names above.
                The active threshold is left unchanged in that case.
        """
        try:
            attr = self._THRESHOLD_ATTRS[threshold_type]
        except KeyError:
            valid = ', '.join(repr(name) for name in self._THRESHOLD_ATTRS)
            raise ValueError(
                f"Unknown threshold_type {threshold_type!r}; expected one of: {valid}"
            ) from None
        self.threshold = getattr(self, attr)

    @torch.no_grad()
    def predict(self, image_path):
        """Score one image and decide whether it is an anomaly.

        Args:
            image_path: Path (``str`` or ``Path``) of an image file, or an
                already loaded PIL image. PIL inputs are converted to RGB so
                grayscale/RGBA images match the 3-channel encoder. Any other
                object is handed to the preprocessing transform unchanged.

        Returns:
            Tuple ``(score, is_anomaly)``: ``score`` is the squared distance
            of the embedding to the center as a Python float, and
            ``is_anomaly`` is a Python ``bool`` that is True when ``score``
            is strictly greater than the active threshold.
        """
        if isinstance(image_path, (str, Path)):
            # convert() produces an independent in-memory image, so the file
            # handle can be released right away.
            with Image.open(image_path) as opened:
                image = opened.convert('RGB')
        elif isinstance(image_path, Image.Image):
            image = image_path.convert('RGB')
        else:
            image = image_path

        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        embeddings = self.encoder(image_tensor)
        score = torch.sum((embeddings - self.center) ** 2, dim=1).item()
        # bool() keeps the return type stable even if the stored threshold is
        # a NumPy scalar.
        is_anomaly = bool(score > self.threshold)
        return score, is_anomaly
|