foreversheikh committed on
Commit
1c4c77a
·
verified ·
1 Parent(s): 17ee76b

Upload 12 files

Browse files
network/MFNET.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Author: Yunpeng Chen."""
2
+
3
+ import logging
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
class BN_AC_CONV3D(nn.Module):
    """Pre-activation 3D convolution block: BatchNorm3d -> ReLU -> Conv3d.

    Normalization and activation run *before* the convolution
    (pre-activation ordering, as used throughout MFNet).
    """

    def __init__(
        self,
        num_in,
        num_filter,
        kernel=(1, 1, 1),
        pad=(0, 0, 0),
        stride=(1, 1, 1),
        g=1,
        bias=False,
    ):
        super().__init__()
        self.bn = nn.BatchNorm3d(num_in)
        self.relu = nn.ReLU(inplace=True)
        # Grouped convolution (`groups=g`); bias is off by default because
        # the following block's BatchNorm would cancel it anyway.
        self.conv = nn.Conv3d(
            num_in,
            num_filter,
            kernel_size=kernel,
            padding=pad,
            stride=stride,
            groups=g,
            bias=bias,
        )

    def forward(self, x):
        """Apply BN -> ReLU -> Conv3d and return the result."""
        return self.conv(self.relu(self.bn(x)))
38
+
39
+
40
class MF_UNIT(nn.Module):
    """Multi-Fiber unit: an inner bottleneck residual path plus a main
    grouped-convolution path, with an optional 1x1x1 adapter shortcut.

    Args:
        num_in: input channel count.
        num_mid: channel count of the main path's first convolution.
        num_out: output channel count.
        g: number of groups for the grouped convolutions.
        stride: stride of the main path's first convolution (and of the
            adapter shortcut when ``first_block`` is True).
        first_block: when True, this unit changes channel count / stride and
            therefore builds the ``conv_w1`` adapter for the shortcut.
        use_3d: when True, the main convolution is temporal (kernel 3 on the
            time axis); otherwise purely spatial (kernel 1).
    """

    def __init__(
        self,
        num_in,
        num_mid,
        num_out,
        g=1,
        stride=(1, 1, 1),
        first_block=False,
        use_3d=True,
    ):
        super().__init__()
        # Inner bottleneck width is a quarter of the mid width.
        num_ix = int(num_mid / 4)
        # Temporal kernel/padding: (3, 1) for 3D units, (1, 0) for 2D-only.
        kt, pt = (3, 1) if use_3d else (1, 0)
        # prepare input: 1x1x1 squeeze-and-restore pair used as an inner
        # residual refinement of x before the main path.
        self.conv_i1 = BN_AC_CONV3D(
            num_in=num_in, num_filter=num_ix, kernel=(1, 1, 1), pad=(0, 0, 0)
        )
        self.conv_i2 = BN_AC_CONV3D(
            num_in=num_ix, num_filter=num_in, kernel=(1, 1, 1), pad=(0, 0, 0)
        )
        # main part: grouped (kt,3,3) convolution followed by a projection.
        self.conv_m1 = BN_AC_CONV3D(
            num_in=num_in,
            num_filter=num_mid,
            kernel=(kt, 3, 3),
            pad=(pt, 1, 1),
            stride=stride,
            g=g,
        )
        if first_block:
            # First unit of a stage: plain 1x1x1 projection (no grouping).
            self.conv_m2 = BN_AC_CONV3D(
                num_in=num_mid, num_filter=num_out, kernel=(1, 1, 1), pad=(0, 0, 0)
            )
        else:
            # Subsequent units keep a grouped spatial 1x3x3 projection.
            self.conv_m2 = BN_AC_CONV3D(
                num_in=num_mid, num_filter=num_out, kernel=(1, 3, 3), pad=(0, 1, 1), g=g
            )
        # adapter: only built for the first block, where the shortcut must
        # match the new channel count / stride. forward() detects its
        # presence via hasattr, so it must not exist otherwise.
        if first_block:
            self.conv_w1 = BN_AC_CONV3D(
                num_in=num_in,
                num_filter=num_out,
                kernel=(1, 1, 1),
                pad=(0, 0, 0),
                stride=stride,
            )

    def forward(self, x):
        """Run the inner refinement, the main path, and the (possibly
        adapted) shortcut, returning their sum."""
        h = self.conv_i1(x)
        # Inner residual: refine x with the squeeze/restore pair.
        x_in = x + self.conv_i2(h)

        h = self.conv_m1(x_in)
        h = self.conv_m2(h)

        # conv_w1 exists only when first_block was True in __init__.
        if hasattr(self, "conv_w1"):
            x = self.conv_w1(x)

        return h + x
99
+
100
+
101
class MFNET_3D(nn.Module):
    """Original code: https://github.com/cypw/PyTorch-MFNet.

    Multi-Fiber 3D network used here as a feature extractor: the final
    classifier layer is commented out and forward() returns pooled,
    flattened features instead of class scores.
    """

    def __init__(
        self,
        **_kwargs,  # accepted and ignored for call-site compatibility
    ):
        super().__init__()

        # Number of fibers (conv groups) shared by all MF units.
        groups = 16
        # Number of MF units per stage (stages 2-5).
        k_sec = {2: 3, 3: 4, 4: 6, 5: 3}

        # conv1 - x224 (x16): stem halves the spatial resolution.
        conv1_num_out = 16
        self.conv1 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        nn.Conv3d(
                            3,
                            conv1_num_out,
                            kernel_size=(3, 5, 5),
                            padding=(1, 2, 2),
                            stride=(1, 2, 2),
                            bias=False,
                        ),
                    ),
                    ("bn", nn.BatchNorm3d(conv1_num_out)),
                    ("relu", nn.ReLU(inplace=True)),
                ]
            )
        )
        self.maxpool = nn.MaxPool3d(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)
        )

        # conv2 - x56 (x8): first unit halves the temporal dimension.
        num_mid = 96
        conv2_num_out = 96
        self.conv2 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "B%02d" % i,
                        MF_UNIT(
                            num_in=conv1_num_out if i == 1 else conv2_num_out,
                            num_mid=num_mid,
                            num_out=conv2_num_out,
                            stride=(2, 1, 1) if i == 1 else (1, 1, 1),
                            g=groups,
                            first_block=(i == 1),
                        ),
                    )
                    for i in range(1, k_sec[2] + 1)
                ]
            )
        )

        # conv3 - x28 (x8): width doubles, spatial resolution halves.
        num_mid *= 2
        conv3_num_out = 2 * conv2_num_out
        self.conv3 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "B%02d" % i,
                        MF_UNIT(
                            num_in=conv2_num_out if i == 1 else conv3_num_out,
                            num_mid=num_mid,
                            num_out=conv3_num_out,
                            stride=(1, 2, 2) if i == 1 else (1, 1, 1),
                            g=groups,
                            first_block=(i == 1),
                        ),
                    )
                    for i in range(1, k_sec[3] + 1)
                ]
            )
        )

        # conv4 - x14 (x8): width doubles again.
        num_mid *= 2
        conv4_num_out = 2 * conv3_num_out
        self.conv4 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "B%02d" % i,
                        MF_UNIT(
                            num_in=conv3_num_out if i == 1 else conv4_num_out,
                            num_mid=num_mid,
                            num_out=conv4_num_out,
                            stride=(1, 2, 2) if i == 1 else (1, 1, 1),
                            g=groups,
                            first_block=(i == 1),
                        ),
                    )
                    for i in range(1, k_sec[4] + 1)
                ]
            )
        )

        # conv5 - x7 (x8): final stage of MF units.
        num_mid *= 2
        conv5_num_out = 2 * conv4_num_out
        self.conv5 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "B%02d" % i,
                        MF_UNIT(
                            num_in=conv4_num_out if i == 1 else conv5_num_out,
                            num_mid=num_mid,
                            num_out=conv5_num_out,
                            stride=(1, 2, 2) if i == 1 else (1, 1, 1),
                            g=groups,
                            first_block=(i == 1),
                        ),
                    )
                    for i in range(1, k_sec[5] + 1)
                ]
            )
        )

        # final: trailing BN+ReLU (needed because MF units are
        # pre-activation, so the last conv output is un-normalized).
        self.tail = nn.Sequential(
            OrderedDict(
                [("bn", nn.BatchNorm3d(conv5_num_out)), ("relu", nn.ReLU(inplace=True))]
            )
        )

        # NOTE(review): AvgPool3d(1, 7, 7) assumes a 7x7 spatial map at this
        # point, i.e. a 224x224 input — confirm against callers.
        self.globalpool = nn.Sequential(
            OrderedDict(
                [
                    ("avg", nn.AvgPool3d(kernel_size=(1, 7, 7), stride=(1, 1, 1))),
                    ("dropout", nn.Dropout(p=0.5)),  # only for fine-tuning
                ]
            )
        )
        # Classifier intentionally disabled: the network is used as a
        # feature extractor.
        # self.classifier = nn.Linear(conv5_num_out, num_classes)

    def forward(self, x):
        """Extract pooled features from a clip tensor.

        Args:
            x: input of shape (batch, 3, frames, H, W); the stem comments
               assume 16 frames at 224x224.

        Returns:
            Flattened feature tensor of shape (batch, features).
        """
        # assert x.shape[2] == 16

        h = self.conv1(x)  # x224 -> x112
        h = self.maxpool(h)  # x112 -> x56

        h = self.conv2(h)  # x56 -> x56
        h = self.conv3(h)  # x56 -> x28
        h = self.conv4(h)  # x28 -> x14
        h = self.conv5(h)  # x14 -> x7

        h = self.tail(h)
        h = self.globalpool(h)

        h = h.view(h.shape[0], -1)
        # h = self.classifier(h)
        # h = h.view(h.shape[0], -1)
        return h

    def load_state(self, state_dict):
        # customized partialy load function: copies only parameters whose
        # name AND shape match this model, ignoring the rest.
        # NOTE(review): torch.load unpickles the checkpoint; only load
        # trusted files. The argument is a file path despite its name.
        checkpoint = torch.load(state_dict, map_location=torch.device("cpu"))
        state_dict = checkpoint["state_dict"]
        net_state_keys = list(self.state_dict().keys())
        for name, param in state_dict.items():
            # Strip DataParallel's "module." prefix from checkpoint keys.
            name = name.replace("module.", "")
            if name in self.state_dict().keys():
                dst_param_shape = self.state_dict()[name].shape
                if param.shape == dst_param_shape:
                    # state_dict() tensors alias the model's storage, so
                    # copy_ writes through to the live parameters.
                    self.state_dict()[name].copy_(param.view(dst_param_shape))
                    net_state_keys.remove(name)
        # indicating missed keys
        if net_state_keys:
            logging.warning(f">> Failed to load: {net_state_keys}")

        return self
network/TorchUtils.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Written by Eitan Kosman."""
2
+
3
+ import logging
4
+ import os
5
+ import time
6
+ from typing import List, Optional, Union
7
+
8
+ import torch
9
+ from torch import Tensor, nn
10
+ from torch.optim import Optimizer
11
+ from torch.utils.data import DataLoader
12
+
13
+ from utils.callbacks import Callback
14
+ from utils.types import Device
15
+ import torch
16
+
17
+ from network.anomaly_detector_model import AnomalyDetector
18
+
19
+ # Use safe_globals context
20
+
21
+
22
+
23
def get_torch_device() -> Device:
    """
    Retrieves the device to run torch models, with preferability to GPU (denoted as cuda by torch)
    Returns: Device to run the models
    """
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
29
+
30
+
31
def load_model(model_path: str) -> nn.Module:
    """Loads a Pytorch model (CPU compatible, PyTorch >=2.6).

    Args:
        model_path: path to a whole-model checkpoint saved with torch.save.

    Returns:
        The deserialized nn.Module, mapped to CPU.
    """
    logging.info(f"Load the model from: {model_path}")

    # Imported here (not at module top) to avoid import-order issues.
    from network.anomaly_detector_model import AnomalyDetector

    # Wrap torch.load with safe_globals and weights_only=False.
    # PyTorch >= 2.6 defaults to weights_only=True; this checkpoint stores a
    # full pickled module, so AnomalyDetector must be allow-listed.
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load trusted checkpoints.
    with torch.serialization.safe_globals([AnomalyDetector]):
        model = torch.load(model_path, map_location="cpu", weights_only=False)

    logging.info(model)
    return model
43
+
44
+
45
+
46
class TorchModel(nn.Module):
    """Wrapper class for a torch model to make it comfortable to train and load
    models."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Target device chosen once at construction (GPU if available).
        self.device = get_torch_device()
        # Global step counter across epochs, passed to callbacks.
        self.iteration = 0
        self.model = model
        # Whether self.model is wrapped in nn.DataParallel.
        self.is_data_parallel = False
        self.callbacks = []

    def register_callback(self, callback_fn: Callback) -> None:
        """
        Register a callback to be called after each evaluation run
        Args:
            callback_fn: a callable that accepts 2 inputs (output, target)
                - output is the model's output
                - target is the values of the target variable
        """
        self.callbacks.append(callback_fn)

    def data_parallel(self):
        """Transfers the model to data parallel mode."""
        self.is_data_parallel = True
        # NOTE(review): device_ids=[0, 1] hard-codes two GPUs — confirm the
        # deployment host actually has both.
        if not isinstance(self.model, torch.nn.DataParallel):
            self.model = torch.nn.DataParallel(self.model, device_ids=[0, 1])

        return self

    @classmethod
    def load_model(cls, model_path: str):
        """
        Loads a pickled model
        Args:
            model_path: path to the pickled model

        Returns: TorchModel class instance wrapping the provided model
        """
        return cls(load_model(model_path))

    def notify_callbacks(self, notification, *args, **kwargs) -> None:
        """Calls all callbacks registered with this class.

        Args:
            notification: The type of notification to be called.
        """
        for callback in self.callbacks:
            try:
                # Dispatch by method name; callbacks that don't implement
                # the notification are logged, not fatal.
                method = getattr(callback, notification)
                method(*args, **kwargs)
            except (AttributeError, TypeError) as e:
                logging.error(
                    f"callback {callback.__class__.__name__} doesn't fully implement the required interface {e}"  # pylint: disable=line-too-long
                )

    def fit(
        self,
        train_iter: DataLoader,
        criterion: nn.Module,
        optimizer: Optimizer,
        eval_iter: Optional[DataLoader] = None,
        epochs: int = 10,
        network_model_path_base: Optional[str] = None,
        save_every: Optional[int] = None,
        evaluate_every: Optional[int] = None,
    ) -> None:
        """
        Train the wrapped model, with optional periodic checkpointing and
        evaluation.

        Args:
            train_iter: iterator for training
            criterion: loss function
            optimizer: optimizer for the algorithm
            eval_iter: iterator for evaluation
            epochs: amount of epochs
            network_model_path_base: where to save the models
            save_every: saving model checkpoints every specified amount of epochs
            evaluate_every: perform evaluation every specified amount of epochs.
                            If the evaluation is expensive, you probably want to
                            choose a high value for this
        """
        criterion = criterion.to(self.device)
        self.notify_callbacks("on_training_start", epochs)

        for epoch in range(epochs):
            train_loss = self.do_epoch(
                criterion=criterion,
                optimizer=optimizer,
                data_iter=train_iter,
                epoch=epoch,
            )

            # Checkpointing: epoch 0 is also saved (0 % save_every == 0).
            if save_every and network_model_path_base and epoch % save_every == 0:
                logging.info(f"Save the model after epoch {epoch}")
                self.save(os.path.join(network_model_path_base, f"epoch_{epoch}.pt"))

            val_loss = None
            if eval_iter and evaluate_every and epoch % evaluate_every == 0:
                logging.info(f"Evaluating after epoch {epoch}")
                val_loss = self.evaluate(
                    criterion=criterion,
                    data_iter=eval_iter,
                )

            self.notify_callbacks("on_training_iteration_end", train_loss, val_loss)

        self.notify_callbacks("on_training_end", self.model)
        # Save the last model anyway...
        # NOTE(review): `epoch` is the loop variable; with epochs == 0 this
        # raises NameError — confirm fit() is never called with 0 epochs.
        if network_model_path_base:
            self.save(os.path.join(network_model_path_base, f"epoch_{epoch + 1}.pt"))

    def evaluate(self, criterion: nn.Module, data_iter: DataLoader) -> float:
        """
        Evaluates the model
        Args:
            criterion: Loss function for calculating the evaluation
            data_iter: torch data iterator

        Returns:
            Average loss over all batches of data_iter.
        """
        self.eval()
        self.notify_callbacks("on_evaluation_start", len(data_iter))
        total_loss = 0

        # No gradients needed during evaluation.
        with torch.no_grad():
            for iteration, (batch, targets) in enumerate(data_iter):
                batch = self.data_to_device(batch, self.device)
                targets = self.data_to_device(targets, self.device)

                outputs = self.model(batch)
                loss = criterion(outputs, targets)

                self.notify_callbacks(
                    "on_evaluation_step",
                    iteration,
                    outputs.detach().cpu(),
                    targets.detach().cpu(),
                    loss.item(),
                )

                total_loss += loss.item()

        loss = total_loss / len(data_iter)
        self.notify_callbacks("on_evaluation_end")
        return loss

    def do_epoch(
        self,
        criterion: nn.Module,
        optimizer: Optimizer,
        data_iter: DataLoader,
        epoch: int,
    ) -> float:
        """Perform a whole epoch.

        Args:
            criterion (nn.Module): Loss function to be used.
            optimizer (Optimizer): Optimizer to use for minimizing the loss function.
            data_iter (DataLoader): Loader for data samples used for training the model.
            epoch (int): The epoch number.

        Returns:
            float: Average training loss calculated during the epoch.
        """
        total_loss = 0
        total_time = 0.0
        self.train()
        self.notify_callbacks("on_epoch_start", epoch, len(data_iter))
        for iteration, (batch, targets) in enumerate(data_iter):
            # NOTE(review): self.iteration is incremented both here and at
            # the end of the loop body, so it advances by 2 per batch —
            # looks unintentional; confirm before relying on its value.
            self.iteration += 1
            start_time = time.time()
            batch = self.data_to_device(batch, self.device)
            targets = self.data_to_device(targets, self.device)

            outputs = self.model(batch)

            loss = criterion(outputs, targets)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            end_time = time.time()

            # Accumulated wall-clock time per epoch (not reported here).
            total_time += end_time - start_time

            self.notify_callbacks(
                "on_epoch_step",
                self.iteration,
                iteration,
                loss.item(),
            )
            self.iteration += 1

        loss = total_loss / len(data_iter)

        self.notify_callbacks("on_epoch_end", loss)
        return loss

    def data_to_device(
        self, data: Union[Tensor, List[Tensor]], device: Device
    ) -> Union[Tensor, List[Tensor]]:
        """
        Transfers a tensor data to a device
        Args:
            data: torch tensor, or a list/tuple of tensors (moved item-wise)
            device: target device
        """
        if isinstance(data, list):
            data = [d.to(device) for d in data]
        elif isinstance(data, tuple):
            data = tuple([d.to(device) for d in data])
        else:
            data = data.to(device)

        return data

    def save(self, model_path: str) -> None:
        """Saves the model to the given path.

        If currently using data parallel, the method
        will save the original model and not the data parallel instance of it
        Args:
            model_path: target path to save the model to
        """
        if self.is_data_parallel:
            # .module unwraps the DataParallel container.
            torch.save(self.model.module, model_path)
        else:
            torch.save(self.model, model_path)

    def get_model(self) -> nn.Module:
        # Return the underlying module, unwrapping DataParallel if needed.
        if self.is_data_parallel:
            return self.model.module

        return self.model

    def forward(self, *args, **kwargs):
        # Delegate directly to the wrapped model.
        return self.model(*args, **kwargs)
network/__init__.py ADDED
File without changes
network/__pycache__/MFNET.cpython-311.pyc ADDED
Binary file (10.3 kB). View file
 
network/__pycache__/TorchUtils.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
network/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (176 Bytes). View file
 
network/__pycache__/anomaly_detector_model.cpython-311.pyc ADDED
Binary file (9.39 kB). View file
 
network/__pycache__/c3d.cpython-311.pyc ADDED
Binary file (6.81 kB). View file
 
network/__pycache__/resnet.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
network/anomaly_detector_model.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains an implementation of anomaly detector for videos."""
2
+
3
+ from typing import Callable
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+
9
class AnomalyDetector(nn.Module):
    """Anomaly detection model for videos.

    A three-layer MLP (input_dim -> 512 -> 32 -> 1) ending in a sigmoid,
    producing a per-segment anomaly score in [0, 1].
    """

    def __init__(self, input_dim=4096) -> None:
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 512)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.6)

        self.fc2 = nn.Linear(512, 32)
        self.dropout2 = nn.Dropout(0.6)

        self.fc3 = nn.Linear(32, 1)
        self.sig = nn.Sigmoid()

        # The original Keras code uses "glorot_normal"; its PyTorch
        # equivalent is Xavier normal initialization.
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_normal_(layer.weight)

    @property
    def input_dim(self) -> int:
        # Expected feature dimension, read off the first layer's weight.
        return self.fc1.weight.shape[1]

    def forward(self, x: Tensor) -> Tensor:  # pylint: disable=arguments-differ
        """Score feature vectors; returns sigmoid activations in [0, 1]."""
        hidden = self.dropout1(self.relu1(self.fc1(x)))
        hidden = self.dropout2(self.fc2(hidden))
        return self.sig(self.fc3(hidden))
39
+
40
+
41
def custom_objective(y_pred: Tensor, y_true: Tensor) -> Tensor:
    """Calculate loss function with regularization for anomaly detection.

    Args:
        y_pred (Tensor): A tensor containing the predictions of the model.
        y_true (Tensor): A tensor containing the ground truth.

    Returns:
        Tensor: A single dimension tensor containing the calculated loss.
    """
    # y_pred (batch_size, 32, 1)
    # y_true (batch_size)
    lambdas = 8e-5

    # Split the batch into normal (label 0) and anomalous (label 1) videos.
    normal_scores = y_pred[torch.where(y_true == 0)].squeeze(-1)  # (batch/2, 32)
    anomal_scores = y_pred[torch.where(y_true == 1)].squeeze(-1)  # (batch/2, 32)

    # Top-scoring segment of each video drives the ranking loss.
    top_normal = normal_scores.max(dim=-1)[0]
    top_anomal = anomal_scores.max(dim=-1)[0]

    # Hinge ranking loss: the best anomalous segment should outscore the
    # best normal segment by a margin of 1.
    hinge_loss = torch.clamp(1 - top_anomal + top_normal, min=0)

    # Smoothness of anomalous video: penalize large jumps between
    # consecutive segment scores.
    neighbor_diffs = anomal_scores[:, 1:] - anomal_scores[:, :-1]
    smoothness = neighbor_diffs.pow(2).sum(dim=-1)

    # Sparsity of anomalous video: only few segments should score high.
    sparsity = anomal_scores.sum(dim=-1)

    return (hinge_loss + lambdas * smoothness + lambdas * sparsity).mean()
79
+
80
+
81
class RegularizedLoss(torch.nn.Module):
    """Regularizes a loss function.

    Wraps an objective and adds an L2 penalty on the three fully-connected
    layers of the given model, matching the l2 regularizer used in the
    original Keras implementation.
    """

    def __init__(
        self,
        model: AnomalyDetector,
        original_objective: Callable,
        lambdas: float = 0.001,
    ) -> None:
        super().__init__()
        self.lambdas = lambdas
        self.model = model
        self.objective = original_objective

    def forward(self, y_pred: Tensor, y_true: Tensor):  # pylint: disable=arguments-differ
        # L2 norm over each layer's flattened parameters (weights + biases),
        # scaled by lambdas, summed across fc1..fc3.
        penalty = torch.zeros(())
        for layer in (self.model.fc1, self.model.fc2, self.model.fc3):
            flat = torch.cat(tuple(p.view(-1) for p in layer.parameters()))
            penalty = penalty + self.lambdas * torch.norm(flat, p=2)

        return self.objective(y_pred, y_true) + penalty
112
+
113
+
114
+
115
+
116
+ # ----------------------------------------------------------------------------------------------------------------------
117
class AnomalyClassifier(nn.Module):
    """
    Multi-class anomaly classifier
    Supports 13 categories: Normal + 12 anomaly classes
    """

    def __init__(self, input_dim=512, num_classes=13):
        # Zero-argument super() (Python 3 form), consistent with the other
        # modules in this file (AnomalyDetector, RegularizedLoss).
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(256, 64)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(64, num_classes)  # ✅ 13 outputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim) feature vectors
        returns: (B, num_classes) logits
        """
        x = self.dropout1(self.relu1(self.fc1(x)))
        x = self.dropout2(self.relu2(self.fc2(x)))
        return self.fc3(x)
network/c3d.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains an implementation of the C3D model for video
2
+ processing."""
3
+
4
+ import itertools
5
+
6
+ import torch
7
+ from torch import Tensor, nn
8
+
9
+
10
class C3D(nn.Module):
    """The C3D network.

    Used here as a feature extractor: fc7/fc8 of the original architecture
    are omitted and forward() returns the 4096-d fc6 activations.
    """

    def __init__(self, pretrained=None):
        super().__init__()

        # Path to a pretrained checkpoint (or falsy to skip loading).
        self.pretrained = pretrained

        self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        # pool1 keeps the temporal dimension (stride 1 in time).
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        # Spatial padding so a 112x112 input yields the 8192 fc6 fan-in.
        self.pool5 = nn.MaxPool3d(
            kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1)
        )

        self.fc6 = nn.Linear(8192, 4096)
        self.relu = nn.ReLU()
        self.__init_weight()

        if pretrained:
            self.__load_pretrained_weights()

    def forward(self, x: Tensor):
        """Extract fc6 features from a clip of shape (B, 3, 16, 112, 112)."""
        x = self.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool3(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))
        x = self.pool4(x)
        x = self.relu(self.conv5a(x))
        x = self.relu(self.conv5b(x))
        x = self.pool5(x)
        # x = x.view(-1, 8192)
        # Flatten per sample; batch-size-safe variant of the line above.
        x = x.view(x.size(0), -1)  # changed
        x = self.relu(self.fc6(x))

        return x

    def __load_pretrained_weights(self):
        """Initialiaze network.

        Copies only the conv1..conv5b and fc6 weights from the checkpoint;
        fc7/fc8 are intentionally skipped, anything else is reported.
        """
        corresp_name = [
            # Conv1
            "conv1.weight",
            "conv1.bias",
            # Conv2
            "conv2.weight",
            "conv2.bias",
            # Conv3a
            "conv3a.weight",
            "conv3a.bias",
            # Conv3b
            "conv3b.weight",
            "conv3b.bias",
            # Conv4a
            "conv4a.weight",
            "conv4a.bias",
            # Conv4b
            "conv4b.weight",
            "conv4b.bias",
            # Conv5a
            "conv5a.weight",
            "conv5a.bias",
            # Conv5b
            "conv5b.weight",
            "conv5b.bias",
            # fc6
            "fc6.weight",
            "fc6.bias",
        ]

        # Classifier-head keys present in the checkpoint but absent here.
        ignored_weights = [
            f"{layer}.{type_}"
            for layer, type_ in itertools.product(["fc7", "fc8"], ["bias", "weight"])
        ]

        # NOTE(review): torch.load unpickles the file — only load trusted
        # checkpoints.
        p_dict = torch.load(self.pretrained)
        s_dict = self.state_dict()
        for name in p_dict:
            if name not in corresp_name:
                if name in ignored_weights:
                    continue
                print("no corresponding::", name)
                continue
            s_dict[name] = p_dict[name]
        self.load_state_dict(s_dict)

    def __init_weight(self):
        """Initialize weights of the model."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                # No BatchNorm3d layers are registered above; kept for
                # safety if the architecture is extended.
                m.weight.data.fill_(1)
                m.bias.data.zero_()
121
+
122
+
123
if __name__ == "__main__":
    # Smoke test: one 16-frame 112x112 RGB clip through an untrained C3D.
    inputs = torch.ones((1, 3, 16, 112, 112))
    net = C3D(pretrained=False)

    outputs = net.forward(inputs)
    # fc6 has 4096 outputs, so this should print torch.Size([1, 4096]).
    print(outputs.size())
129
+
network/resnet.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains an implementation of the ResNet model for video
2
+ processing."""
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+
11
def get_inplanes():
    """Return the base channel widths of the four ResNet stages."""
    return [64, 128, 256, 512]
13
+
14
+
15
def conv3x3x3(in_planes, out_planes, stride=1):
    """3x3x3 convolution with padding 1 and no bias (BN follows it)."""
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
19
+
20
+
21
def conv1x1x1(in_planes, out_planes, stride=1):
    """Pointwise (1x1x1) convolution with no bias, used for projections."""
    projection = nn.Conv3d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
    )
    return projection
23
+
24
+
25
class BasicBlock(nn.Module):
    """Two 3x3x3 convolutions with a residual shortcut (post-activation)."""

    # Output channels equal `planes` — no channel expansion.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()

        self.conv1 = conv3x3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        # Optional projection applied to the shortcut when the shape changes.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or the supplied projection.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out = out + shortcut
        return self.relu(out)
56
+
57
+
58
class Bottleneck(nn.Module):
    """1x1x1 reduce -> 3x3x3 -> 1x1x1 expand, with a residual shortcut."""

    # Output channels are `planes * 4`.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()

        self.conv1 = conv1x1x1(in_planes, planes)
        self.bn1 = nn.BatchNorm3d(planes)
        # The stride lives on the middle 3x3x3 convolution.
        self.conv2 = conv3x3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = conv1x1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection applied to the shortcut when the shape changes.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or the supplied projection.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out = out + shortcut
        return self.relu(out)
95
+
96
+
97
class ResNet(nn.Module):
    """3D ResNet backbone used as a feature extractor (no classifier head).

    Args:
        block: residual block class (BasicBlock or Bottleneck).
        layers: number of blocks in each of the four stages.
        block_inplanes: base channel widths per stage (see get_inplanes()).
        n_input_channels: input channels (3 for RGB clips).
        conv1_t_size: temporal kernel size of the stem convolution.
        conv1_t_stride: temporal stride of the stem convolution.
        no_max_pool: skip the stem max-pool when True.
        shortcut_type: "A" = zero-padded pooled shortcut, otherwise a
            1x1x1 conv + BN projection.
        widen_factor: multiplies every stage width.
        n_classes: retained for interface compatibility; the fc head is
            commented out, so it is unused.
    """

    def __init__(
        self,
        block,
        layers,
        block_inplanes,
        n_input_channels=3,
        conv1_t_size=7,
        conv1_t_stride=1,
        no_max_pool=False,
        shortcut_type="B",
        widen_factor=1.0,
        n_classes=1039,
    ):
        super().__init__()

        block_inplanes = [int(x * widen_factor) for x in block_inplanes]

        # Running channel count, advanced by _make_layer per stage.
        self.in_planes = block_inplanes[0]
        self.no_max_pool = no_max_pool

        self.conv1 = nn.Conv3d(
            n_input_channels,
            self.in_planes,
            kernel_size=(conv1_t_size, 7, 7),
            stride=(conv1_t_stride, 2, 2),
            padding=(conv1_t_size // 2, 3, 3),
            bias=False,
        )
        self.bn1 = nn.BatchNorm3d(self.in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(
            block, block_inplanes[0], layers[0], shortcut_type
        )
        self.layer2 = self._make_layer(
            block, block_inplanes[1], layers[1], shortcut_type, stride=2
        )
        self.layer3 = self._make_layer(
            block, block_inplanes[2], layers[2], shortcut_type, stride=2
        )
        self.layer4 = self._make_layer(
            block, block_inplanes[3], layers[3], shortcut_type, stride=2
        )

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # Classifier intentionally disabled — forward() returns features.
        # self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes)

        # He initialization for convolutions, unit/zero for BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _downsample_basic_block(self, x, planes, stride):
        """Type-A shortcut: strided average pool plus zero-padded channels
        (parameter-free alternative to a conv projection)."""
        out = F.avg_pool3d(x, kernel_size=1, stride=stride)
        zero_pads = torch.zeros(
            out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4)
        )
        if isinstance(out.data, torch.cuda.FloatTensor):
            zero_pads = zero_pads.cuda()

        out = torch.cat([out.data, zero_pads], dim=1)

        return out

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        """Build one stage: a (possibly downsampling) first block followed
        by `blocks - 1` identity blocks."""
        downsample = None
        # A projection is needed whenever shape or channel count changes.
        if stride != 1 or self.in_planes != planes * block.expansion:
            if shortcut_type == "A":
                downsample = partial(
                    self._downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride,
                )
            else:
                downsample = nn.Sequential(
                    conv1x1x1(self.in_planes, planes * block.expansion, stride),
                    nn.BatchNorm3d(planes * block.expansion),
                )

        layers = []
        layers.append(
            block(
                in_planes=self.in_planes,
                planes=planes,
                stride=stride,
                downsample=downsample,
            )
        )
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return globally pooled, flattened features of shape (B, C)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if not self.no_max_pool:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        # x = self.fc(x)

        return x
212
+
213
+
214
def generate_model(model_depth, **kwargs):
    """Build a 3D ResNet of the requested depth.

    Args:
        model_depth: one of 10, 18, 34, 50, 101, 152, 200.
        **kwargs: forwarded to the ResNet constructor.

    Returns:
        A ResNet instance with the standard block/layer configuration.
    """
    assert model_depth in [10, 18, 34, 50, 101, 152, 200]

    # (block type, per-stage block counts) for each supported depth.
    configs = {
        10: (BasicBlock, [1, 1, 1, 1]),
        18: (BasicBlock, [2, 2, 2, 2]),
        34: (BasicBlock, [3, 4, 6, 3]),
        50: (Bottleneck, [3, 4, 6, 3]),
        101: (Bottleneck, [3, 4, 23, 3]),
        152: (Bottleneck, [3, 8, 36, 3]),
        200: (Bottleneck, [3, 24, 36, 3]),
    }
    block, layers = configs[model_depth]

    return ResNet(block, layers, get_inplanes(), **kwargs)