File size: 4,977 Bytes
1c4c77a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""This module contains an implementation of anomaly detector for videos."""

from typing import Callable

import torch
from torch import Tensor, nn


class AnomalyDetector(nn.Module):
    """Scores video feature vectors with an anomaly probability in (0, 1)."""

    def __init__(self, input_dim: int = 4096) -> None:
        super().__init__()
        # Three-layer MLP head: input_dim -> 512 -> 32 -> 1.
        # fc2 deliberately has no activation, mirroring the keras original.
        self.fc1 = nn.Linear(input_dim, 512)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.6)

        self.fc2 = nn.Linear(512, 32)
        self.dropout2 = nn.Dropout(0.6)

        self.fc3 = nn.Linear(32, 1)
        self.sig = nn.Sigmoid()

        # The original keras code used "glorot_normal", which corresponds
        # to xavier normal initialization in PyTorch.
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)
        nn.init.xavier_normal_(self.fc3.weight)

    @property
    def input_dim(self) -> int:
        # Feature dimension expected by the first linear layer.
        return self.fc1.in_features

    def forward(self, x: Tensor) -> Tensor:  # pylint: disable=arguments-differ
        """Map features of shape (..., input_dim) to scores of shape (..., 1)."""
        hidden = self.dropout1(self.relu1(self.fc1(x)))
        compressed = self.dropout2(self.fc2(hidden))
        return self.sig(self.fc3(compressed))


def custom_objective(y_pred: Tensor, y_true: Tensor) -> Tensor:
    """Calculate the MIL ranking loss with smoothness and sparsity terms.

    Args:
        y_pred (Tensor): Segment scores of shape (batch_size, 32, 1).
        y_true (Tensor): Per-video labels of shape (batch_size,),
            where 0 marks a normal video and 1 an anomalous one.

    Returns:
        Tensor: A scalar tensor holding the calculated loss.
    """
    reg_weight = 8e-5  # weight of the smoothness and sparsity terms

    # Split the batch by label and drop the trailing singleton dim,
    # leaving scores of shape (n_videos, 32).
    normal_scores = y_pred[y_true == 0].squeeze(-1)
    anomal_scores = y_pred[y_true == 1].squeeze(-1)

    # Highest segment score within each video.
    top_normal = normal_scores.max(dim=-1)[0]
    top_anomal = anomal_scores.max(dim=-1)[0]

    # Ranking hinge: the best anomalous segment should outscore the
    # best normal segment by a margin of 1.
    ranking_term = (1 - top_anomal + top_normal).clamp(min=0)

    # Smoothness: penalize jumps between neighbouring segment scores
    # of anomalous videos.
    neighbour_diffs = anomal_scores[:, 1:] - anomal_scores[:, :-1]
    smoothness_term = neighbour_diffs.pow(2).sum(dim=-1)

    # Sparsity: anomalies should occupy only a few segments.
    sparsity_term = anomal_scores.sum(dim=-1)

    return (
        ranking_term + reg_weight * smoothness_term + reg_weight * sparsity_term
    ).mean()


class RegularizedLoss(torch.nn.Module):
    """Adds L2 weight penalties for the detector's linear layers to a loss."""

    def __init__(
        self,
        model: AnomalyDetector,
        original_objective: Callable,
        lambdas: float = 0.001,
    ) -> None:
        super().__init__()
        self.lambdas = lambdas
        self.model = model
        self.objective = original_objective

    def forward(self, y_pred: Tensor, y_true: Tensor):  # pylint: disable=arguments-differ
        """Return objective(y_pred, y_true) plus weighted L2 norms of fc1-fc3.

        The original keras code applied an l2 regularizer to each dense
        layer; here each layer's parameters are flattened into one vector
        and its Euclidean norm is scaled by ``self.lambdas``.
        """
        penalties = [
            self.lambdas
            * torch.norm(
                torch.cat(tuple(p.view(-1) for p in layer.parameters())), p=2
            )
            for layer in (self.model.fc1, self.model.fc2, self.model.fc3)
        ]

        return (
            self.objective(y_pred, y_true)
            + penalties[0]
            + penalties[1]
            + penalties[2]
        )




# ----------------------------------------------------------------------------------------------------------------------
class AnomalyClassifier(nn.Module):
    """Multi-class anomaly classifier.

    Maps per-segment feature vectors to class logits. By default it
    supports 13 categories: Normal plus 12 anomaly classes.
    """

    def __init__(self, input_dim: int = 512, num_classes: int = 13) -> None:
        # Modern zero-argument super(), consistent with the rest of the file.
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(256, 64)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)

        # Final projection: one logit per class (13 by default).
        self.fc3 = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, input_dim) feature vectors to (B, num_classes) logits.

        Args:
            x (torch.Tensor): Batch of feature vectors, shape (B, input_dim).

        Returns:
            torch.Tensor: Unnormalized class logits, shape (B, num_classes).
        """
        x = self.dropout1(self.relu1(self.fc1(x)))
        x = self.dropout2(self.relu2(self.fc2(x)))
        return self.fc3(x)