ano_dect / network /anomaly_detector_model.py
foreversheikh's picture
Upload 12 files
1c4c77a verified
"""This module contains an implementation of anomaly detector for videos."""
from typing import Callable
import torch
from torch import Tensor, nn
class AnomalyDetector(nn.Module):
"""Anomaly detection model for videos."""
def __init__(self, input_dim=4096) -> None:
super().__init__()
self.fc1 = nn.Linear(input_dim, 512)
self.relu1 = nn.ReLU()
self.dropout1 = nn.Dropout(0.6)
self.fc2 = nn.Linear(512, 32)
self.dropout2 = nn.Dropout(0.6)
self.fc3 = nn.Linear(32, 1)
self.sig = nn.Sigmoid()
# In the original keras code they use "glorot_normal"
# As I understand, this is the same as xavier normal in Pytorch
nn.init.xavier_normal_(self.fc1.weight)
nn.init.xavier_normal_(self.fc2.weight)
nn.init.xavier_normal_(self.fc3.weight)
@property
def input_dim(self) -> int:
return self.fc1.weight.shape[1]
def forward(self, x: Tensor) -> Tensor: # pylint: disable=arguments-differ
x = self.dropout1(self.relu1(self.fc1(x)))
x = self.dropout2(self.fc2(x))
x = self.sig(self.fc3(x))
return x
def custom_objective(y_pred: Tensor, y_true: Tensor) -> Tensor:
"""Calculate loss function with regularization for anomaly detection.
Args:
y_pred (Tensor): A tensor containing the predictions of the model.
y_true (Tensor): A tensor containing the ground truth.
Returns:
Tensor: A single dimension tensor containing the calculated loss.
"""
# y_pred (batch_size, 32, 1)
# y_true (batch_size)
lambdas = 8e-5
normal_vids_indices = torch.where(y_true == 0)
anomal_vids_indices = torch.where(y_true == 1)
normal_segments_scores = y_pred[normal_vids_indices].squeeze(-1) # (batch/2, 32, 1)
anomal_segments_scores = y_pred[anomal_vids_indices].squeeze(-1) # (batch/2, 32, 1)
# get the max score for each video
normal_segments_scores_maxes = normal_segments_scores.max(dim=-1)[0]
anomal_segments_scores_maxes = anomal_segments_scores.max(dim=-1)[0]
hinge_loss = 1 - anomal_segments_scores_maxes + normal_segments_scores_maxes
hinge_loss = torch.max(hinge_loss, torch.zeros_like(hinge_loss))
# Smoothness of anomalous video
smoothed_scores = anomal_segments_scores[:, 1:] - anomal_segments_scores[:, :-1]
smoothed_scores_sum_squared = smoothed_scores.pow(2).sum(dim=-1)
# Sparsity of anomalous video
sparsity_loss = anomal_segments_scores.sum(dim=-1)
final_loss = (
hinge_loss + lambdas * smoothed_scores_sum_squared + lambdas * sparsity_loss
).mean()
return final_loss
class RegularizedLoss(torch.nn.Module):
"""Regularizes a loss function."""
def __init__(
self,
model: AnomalyDetector,
original_objective: Callable,
lambdas: float = 0.001,
) -> None:
super().__init__()
self.lambdas = lambdas
self.model = model
self.objective = original_objective
def forward(self, y_pred: Tensor, y_true: Tensor): # pylint: disable=arguments-differ
# loss
# Our loss is defined with respect to l2 regularization, as used in the original keras code
fc1_params = torch.cat(tuple([x.view(-1) for x in self.model.fc1.parameters()]))
fc2_params = torch.cat(tuple([x.view(-1) for x in self.model.fc2.parameters()]))
fc3_params = torch.cat(tuple([x.view(-1) for x in self.model.fc3.parameters()]))
l1_regularization = self.lambdas * torch.norm(fc1_params, p=2)
l2_regularization = self.lambdas * torch.norm(fc2_params, p=2)
l3_regularization = self.lambdas * torch.norm(fc3_params, p=2)
return (
self.objective(y_pred, y_true)
+ l1_regularization
+ l2_regularization
+ l3_regularization
)
# ----------------------------------------------------------------------------------------------------------------------
class AnomalyClassifier(nn.Module):
"""
Multi-class anomaly classifier
Supports 13 categories: Normal + 12 anomaly classes
"""
def __init__(self, input_dim=512, num_classes=13):
super(AnomalyClassifier, self).__init__()
self.fc1 = nn.Linear(input_dim, 256)
self.relu1 = nn.ReLU()
self.dropout1 = nn.Dropout(0.5)
self.fc2 = nn.Linear(256, 64)
self.relu2 = nn.ReLU()
self.dropout2 = nn.Dropout(0.5)
self.fc3 = nn.Linear(64, num_classes) # ✅ 13 outputs
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (B, input_dim) feature vectors
returns: (B, num_classes) logits
"""
x = self.dropout1(self.relu1(self.fc1(x)))
x = self.dropout2(self.relu2(self.fc2(x)))
return self.fc3(x)