# Uploaded by ash12321: "Upload Deep SVDD anomaly detection model" (commit 0e9dcc1, verified)
"""
Deep SVDD Anomaly Detection Model
Trained on CIFAR-10, CIFAR-100, and STL-10
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import pickle
import json
from pathlib import Path
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an additive skip connection.

    The skip path is the identity when the block keeps the input shape.
    Otherwise it is a strided 1x1 convolution plus BatchNorm that projects the
    input to ``out_ch`` channels at the new spatial resolution. The sum of
    both paths is passed through a final ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        stride: Stride of the first 3x3 convolution and of the projection.
    """

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()
        # Attribute names and creation order must stay stable: they define the
        # state_dict keys of saved checkpoints and the parameter-init RNG order.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        changes_shape = stride != 1 or in_ch != out_ch
        if changes_shape:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            # An empty Sequential returns its input unchanged (identity skip).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        return F.relu(main_path + self.shortcut(x))
class DeepSVDDEncoder(nn.Module):
    """ResNet-style encoder mapping a 3-channel image to a latent vector.

    A 7x7 stride-2 stem is followed by four residual stages, each of which
    also halves the spatial size (overall stride 32). The final 512-channel
    feature map is flattened and linearly projected to ``latent_dim``.
    ``fc`` expects exactly ``512 * 4 * 4`` features, so only 128x128 inputs
    produce a compatible shape.

    Args:
        latent_dim: Size of the output embedding.
    """

    def __init__(self, latent_dim: int = 512):
        super().__init__()
        # Names and construction order are kept fixed: they determine the
        # state_dict keys and the parameter-initialisation RNG sequence.
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 128, stride=2)
        self.layer2 = self._make_layer(128, 256, stride=2)
        self.layer3 = self._make_layer(256, 512, stride=2)
        self.layer4 = self._make_layer(512, 512, stride=2)
        # No bias on the projection. NOTE(review): presumably this follows the
        # Deep SVDD recommendation against bias terms (they permit a trivial
        # constant mapping onto the center) — confirm with the training code.
        self.fc = nn.Linear(512 * 4 * 4, latent_dim, bias=False)

    def _make_layer(self, in_ch: int, out_ch: int, stride: int = 1):
        """Build one stage: a (possibly downsampling) residual block followed
        by a shape-preserving residual block."""
        entry_block = ResidualBlock(in_ch, out_ch, stride)
        exit_block = ResidualBlock(out_ch, out_ch, 1)
        return nn.Sequential(entry_block, exit_block)

    def forward(self, x):
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
class DeepSVDDAnomalyDetector:
    """
    Deep SVDD Anomaly Detection Model

    An image's anomaly score is the squared Euclidean distance between its
    encoder embedding and the stored center. Scores strictly above the
    active threshold are flagged as anomalies.

    Usage:
        from model import DeepSVDDAnomalyDetector
        detector = DeepSVDDAnomalyDetector.from_pretrained('.')
        score, is_anomaly = detector.predict('image.jpg')
    """

    def __init__(self, model_path, thresholds_path, config_path, device='cuda'):
        """Load the config, the encoder checkpoint and the score thresholds.

        Args:
            model_path: Path to a torch checkpoint dict with the keys
                'latent_dim', 'center', 'radius' and 'encoder_state_dict'.
            thresholds_path: Path to a pickled dict with the keys
                '95th_percentile', '99th_percentile' and 'optimal_f1'.
            config_path: Path to a JSON file. Its contents are stored as
                ``self.config``; this class does not otherwise use them.
            device: Requested torch device. A CUDA request falls back to CPU
                when CUDA is unavailable; any other device is used as given.
        """
        self.device = self._resolve_device(device)

        # Load config
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        # Load model
        checkpoint = torch.load(model_path, map_location=self.device)
        self.latent_dim = checkpoint['latent_dim']
        self.center = checkpoint['center'].to(self.device)
        self.radius = checkpoint['radius'].item()
        self.encoder = DeepSVDDEncoder(self.latent_dim).to(self.device)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.eval()  # inference only: BatchNorm uses running stats

        # Load thresholds.
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load a thresholds file obtained from a trusted source.
        with open(thresholds_path, 'rb') as f:
            thresholds = pickle.load(f)
        self.threshold_95 = thresholds['95th_percentile']
        self.threshold_99 = thresholds['99th_percentile']
        self.threshold_optimal = thresholds['optimal_f1']
        self.threshold = self.threshold_optimal  # default is the 'optimal' threshold

        # Image preprocessing: resize to the 128x128 input size that
        # DeepSVDDEncoder's fc layer is dimensioned for, then normalise with
        # the channel mean/std values commonly used for ImageNet models.
        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _resolve_device(device):
        """Return ``device`` as a torch.device, substituting CPU only when a
        CUDA device was requested but CUDA is unavailable."""
        resolved = torch.device(device)
        if resolved.type == 'cuda' and not torch.cuda.is_available():
            return torch.device('cpu')
        return resolved

    @classmethod
    def from_pretrained(cls, model_path='.', device='cuda'):
        """Load a pretrained detector from a local directory.

        The directory must contain 'deepsvdd_model.pth', 'thresholds.pkl'
        and 'config.json'. Remote repositories (e.g. on the HuggingFace Hub)
        are not fetched here; download them to a local directory first.
        """
        model_path = Path(model_path)
        return cls(
            model_path=model_path / 'deepsvdd_model.pth',
            thresholds_path=model_path / 'thresholds.pkl',
            config_path=model_path / 'config.json',
            device=device
        )

    def set_threshold(self, threshold_type='optimal'):
        """Select the active decision threshold.

        Args:
            threshold_type: 'optimal', '95th' or '99th'.

        Raises:
            ValueError: If ``threshold_type`` is not one of the names above.
                The active threshold is left unchanged.
        """
        if threshold_type == 'optimal':
            self.threshold = self.threshold_optimal
        elif threshold_type == '95th':
            self.threshold = self.threshold_95
        elif threshold_type == '99th':
            self.threshold = self.threshold_99
        else:
            raise ValueError(
                f"Unknown threshold_type {threshold_type!r}; "
                "expected 'optimal', '95th' or '99th'"
            )

    @staticmethod
    def _load_image(image):
        """Return an RGB image for a file path, or the input with PIL images
        coerced to RGB mode."""
        if isinstance(image, (str, Path)):
            # The context manager closes the file; convert() returns a new,
            # fully loaded image that stays valid after the file is closed.
            with Image.open(image) as opened:
                return opened.convert('RGB')
        # PIL images expose their mode as a string. Grayscale/RGBA inputs
        # would otherwise reach the 3-channel Normalize and conv1 with the
        # wrong channel count.
        mode = getattr(image, 'mode', None)
        if isinstance(mode, str) and mode != 'RGB':
            return image.convert('RGB')
        return image

    @torch.no_grad()
    def predict(self, image_path):
        """Score one image and decide whether it is anomalous.

        Args:
            image_path: A file path (str or Path) or an already-loaded image
                accepted by ``self.transform``, such as a PIL image.

        Returns:
            (score, is_anomaly): ``score`` is the squared Euclidean distance
            from the embedding to the center, as a Python float.
            ``is_anomaly`` is a Python bool, True when ``score`` exceeds the
            active threshold.
        """
        image = self._load_image(image_path)
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        embeddings = self.encoder(image_tensor)
        score = torch.sum((embeddings - self.center) ** 2, dim=1).item()
        # Thresholds may be numpy/tensor scalars; normalise to a plain bool.
        is_anomaly = bool(score > self.threshold)
        return score, is_anomaly