MSherbinii committed
Commit 0f5deb2 · verified · 1 Parent(s): 6852e98

Add IPAD model implementation

IPAD/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # IPAD Model Package
+ from .model.video_swin_transformer import VST
+ from .model.memory_module import MemModule
+
+ __all__ = ['VST', 'MemModule']
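With the relative imports above, the package exposes its two public classes at the top level. A minimal sketch of the intended entry point, assuming the directory containing `IPAD` is on the Python path:

    from IPAD import VST, MemModule

    model = VST(mem_dim=2000, shrink_thres=0.0025)
    memory = MemModule(mem_dim=2000, fea_dim=768)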
IPAD/model/VST_block.py ADDED
@@ -0,0 +1,568 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+ import numpy as np
+ from timm.models.layers import DropPath, trunc_normal_
+
+ from functools import reduce, lru_cache
+ from operator import mul
+ from einops import rearrange
+
+ import logging
+
+
+ class Mlp(nn.Module):
+     """ Multilayer perceptron."""
+
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ def window_partition(x, window_size):
+     """
+     Args:
+         x: (B, D, H, W, C)
+         window_size (tuple[int]): window size
+     Returns:
+         windows: (B*num_windows, Wd*Wh*Ww, C)
+     """
+     B, D, H, W, C = x.shape
+     x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], C)
+     windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
+     return windows
+
+
+ def window_reverse(windows, window_size, B, D, H, W):
+     """
+     Args:
+         windows: (B*num_windows, Wd*Wh*Ww, C)
+         window_size (tuple[int]): Window size
+         H (int): Height of image
+         W (int): Width of image
+     Returns:
+         x: (B, D, H, W, C)
+     """
+     x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1)
+     x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)
+     return x
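window_partition and window_reverse are exact inverses when the input is already a multiple of the window size, so a quick shape check makes the layout concrete. A sketch with these two helpers in scope and an illustrative clip size:

    import torch

    x = torch.randn(2, 4, 14, 14, 96)        # (B, D, H, W, C), multiples of the window
    win = (2, 7, 7)
    w = window_partition(x, win)              # (2*2*2*2, 2*7*7, 96) = (16, 98, 96)
    y = window_reverse(w, win, 2, 4, 14, 14)  # back to (2, 4, 14, 14, 96)
    assert torch.equal(x, y)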
+
+
+ def get_window_size(x_size, window_size, shift_size=None):
+     use_window_size = list(window_size)
+     if shift_size is not None:
+         use_shift_size = list(shift_size)
+     for i in range(len(x_size)):
+         if x_size[i] <= window_size[i]:
+             use_window_size[i] = x_size[i]
+             if shift_size is not None:
+                 use_shift_size[i] = 0
+
+     if shift_size is None:
+         return tuple(use_window_size)
+     else:
+         return tuple(use_window_size), tuple(use_shift_size)
+
+
+ class WindowAttention3D(nn.Module):
+     """ Window based multi-head self attention (W-MSA) module with relative position bias.
+     It supports both shifted and non-shifted windows.
+     Args:
+         dim (int): Number of input channels.
+         window_size (tuple[int]): The temporal length, height and width of the window.
+         num_heads (int): Number of attention heads.
+         qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: False
+         qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+         attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+         proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+     """
+
+     def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.dim = dim
+         self.window_size = window_size  # Wd, Wh, Ww
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         # define a parameter table of relative position bias
+         self.relative_position_bias_table = nn.Parameter(
+             torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads))  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
+
+         # get pair-wise relative position index for each token inside the window
+         coords_d = torch.arange(self.window_size[0])
+         coords_h = torch.arange(self.window_size[1])
+         coords_w = torch.arange(self.window_size[2])
+         coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))  # 3, Wd, Wh, Ww
+         coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
+         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
+         relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
+         relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+         relative_coords[:, :, 1] += self.window_size[1] - 1
+         relative_coords[:, :, 2] += self.window_size[2] - 1
+
+         relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
+         relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
+         relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
+         self.register_buffer("relative_position_index", relative_position_index)
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+         trunc_normal_(self.relative_position_bias_table, std=.02)
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, x, mask=None):
+         """ Forward function.
+         Args:
+             x: input features with shape of (num_windows*B, N, C)
+             mask: (0/-inf) mask with shape of (num_windows, N, N) or None
+         """
+         B_, N, C = x.shape
+         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]  # B_, nH, N, C
+
+         q = q * self.scale
+         attn = q @ k.transpose(-2, -1)
+
+         relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape(
+             N, N, -1)  # Wd*Wh*Ww, Wd*Wh*Ww, nH
+         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wd*Wh*Ww, Wd*Wh*Ww
+         attn = attn + relative_position_bias.unsqueeze(0)  # B_, nH, N, N
+
+         if mask is not None:
+             nW = mask.shape[0]
+             attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+             attn = attn.view(-1, self.num_heads, N, N)
+             attn = self.softmax(attn)
+         else:
+             attn = self.softmax(attn)
+
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
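The attention module operates on flattened windows, so its input is (num_windows*B, N, C) with N = Wd*Wh*Ww. A minimal sketch with illustrative sizes:

    import torch

    attn = WindowAttention3D(dim=96, window_size=(2, 7, 7), num_heads=3)
    tokens = torch.randn(16, 2 * 7 * 7, 96)  # 16 windows of 98 tokens each
    out = attn(tokens)                        # same shape: (16, 98, 96)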
+
+
+ class SwinTransformerBlock3D(nn.Module):
+     """ Swin Transformer Block.
+     Args:
+         dim (int): Number of input channels.
+         num_heads (int): Number of attention heads.
+         window_size (tuple[int]): Window size.
+         shift_size (tuple[int]): Shift size for SW-MSA.
+         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+         qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+         qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+         drop (float, optional): Dropout rate. Default: 0.0
+         attn_drop (float, optional): Attention dropout rate. Default: 0.0
+         drop_path (float, optional): Stochastic depth rate. Default: 0.0
+         act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+     """
+
+     def __init__(self, dim, num_heads, window_size=(2, 7, 7), shift_size=(0, 0, 0),
+                  mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                  act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
+         super().__init__()
+         self.dim = dim
+         self.num_heads = num_heads
+         self.window_size = window_size
+         self.shift_size = shift_size
+         self.mlp_ratio = mlp_ratio
+         self.use_checkpoint = use_checkpoint
+
+         assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must be in [0, window_size)"
+         assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must be in [0, window_size)"
+         assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must be in [0, window_size)"
+
+         self.norm1 = norm_layer(dim)
+         self.attn = WindowAttention3D(
+             dim, window_size=self.window_size, num_heads=num_heads,
+             qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward_part1(self, x, mask_matrix):
+         B, D, H, W, C = x.shape
+         window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
+
+         x = self.norm1(x)
+         # pad feature maps to multiples of window size
+         pad_l = pad_t = pad_d0 = 0
+         pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0]
+         pad_b = (window_size[1] - H % window_size[1]) % window_size[1]
+         pad_r = (window_size[2] - W % window_size[2]) % window_size[2]
+         x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
+         _, Dp, Hp, Wp, _ = x.shape
+         # cyclic shift
+         if any(i > 0 for i in shift_size):
+             shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
+             attn_mask = mask_matrix
+         else:
+             shifted_x = x
+             attn_mask = None
+         # partition windows
+         x_windows = window_partition(shifted_x, window_size)  # B*nW, Wd*Wh*Ww, C
+         # W-MSA/SW-MSA
+         attn_windows = self.attn(x_windows, mask=attn_mask)  # B*nW, Wd*Wh*Ww, C
+         # merge windows
+         attn_windows = attn_windows.view(-1, *(window_size + (C,)))
+         shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp)  # B D' H' W' C
+         # reverse cyclic shift
+         if any(i > 0 for i in shift_size):
+             x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
+         else:
+             x = shifted_x
+
+         if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
+             x = x[:, :D, :H, :W, :].contiguous()
+         return x
+
+     def forward_part2(self, x):
+         return self.drop_path(self.mlp(self.norm2(x)))
+
+     def forward(self, x, mask_matrix):
+         """ Forward function.
+         Args:
+             x: Input feature, tensor size (B, D, H, W, C).
+             mask_matrix: Attention mask for cyclic shift.
+         """
+         shortcut = x
+         if self.use_checkpoint:
+             x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
+         else:
+             x = self.forward_part1(x, mask_matrix)
+         x = shortcut + self.drop_path(x)
+
+         if self.use_checkpoint:
+             x = x + checkpoint.checkpoint(self.forward_part2, x)
+         else:
+             x = x + self.forward_part2(x)
+
+         return x
+
+
+ class PatchMerging(nn.Module):
+     """ Patch Merging Layer
+     Args:
+         dim (int): Number of input channels.
+         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+     """
+     def __init__(self, dim, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.dim = dim
+         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+         self.norm = norm_layer(4 * dim)
+
+     def forward(self, x):
+         """ Forward function.
+         Args:
+             x: Input feature, tensor size (B, D, H, W, C).
+         """
+         B, D, H, W, C = x.shape
+
+         # padding
+         pad_input = (H % 2 == 1) or (W % 2 == 1)
+         if pad_input:
+             x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+         x0 = x[:, :, 0::2, 0::2, :]  # B D H/2 W/2 C
+         x1 = x[:, :, 1::2, 0::2, :]  # B D H/2 W/2 C
+         x2 = x[:, :, 0::2, 1::2, :]  # B D H/2 W/2 C
+         x3 = x[:, :, 1::2, 1::2, :]  # B D H/2 W/2 C
+         x = torch.cat([x0, x1, x2, x3], -1)  # B D H/2 W/2 4*C
+
+         x = self.norm(x)
+         x = self.reduction(x)
+
+         return x
+
+
+ # cache each stage's attention mask
+ @lru_cache()
+ def compute_mask(D, H, W, window_size, shift_size, device):
+     img_mask = torch.zeros((1, D, H, W, 1), device=device)  # 1 Dp Hp Wp 1
+     cnt = 0
+     for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
+         for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
+             for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
+                 img_mask[:, d, h, w, :] = cnt
+                 cnt += 1
+     mask_windows = window_partition(img_mask, window_size)  # nW, ws[0]*ws[1]*ws[2], 1
+     mask_windows = mask_windows.squeeze(-1)  # nW, ws[0]*ws[1]*ws[2]
+     attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+     attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+     return attn_mask
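compute_mask labels each shifted region with a distinct id and turns id mismatches into -100 additive biases, one mask per window position. A quick sketch with illustrative, already-padded sizes:

    import torch

    mask = compute_mask(D=4, H=14, W=14, window_size=(2, 7, 7),
                        shift_size=(1, 3, 3), device=torch.device('cpu'))
    print(mask.shape)  # (8, 98, 98): nW windows, 98 tokens each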
+
+
+ class BasicLayer(nn.Module):
+     """ A basic Swin Transformer layer for one stage.
+     Args:
+         dim (int): Number of feature channels.
+         depth (int): Depth of this stage.
+         num_heads (int): Number of attention heads.
+         window_size (tuple[int]): Local window size. Default: (1,7,7).
+         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+         qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: False
+         qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+         drop (float, optional): Dropout rate. Default: 0.0
+         attn_drop (float, optional): Attention dropout rate. Default: 0.0
+         drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+         downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+     """
+
+     def __init__(self,
+                  dim,
+                  depth,
+                  num_heads,
+                  window_size=(1, 7, 7),
+                  mlp_ratio=4.,
+                  qkv_bias=False,
+                  qk_scale=None,
+                  drop=0.,
+                  attn_drop=0.,
+                  drop_path=0.,
+                  norm_layer=nn.LayerNorm,
+                  downsample=None,
+                  use_checkpoint=False):
+         super().__init__()
+         self.window_size = window_size
+         self.shift_size = tuple(i // 2 for i in window_size)
+         self.depth = depth
+         self.use_checkpoint = use_checkpoint
+
+         # build blocks
+         self.blocks = nn.ModuleList([
+             SwinTransformerBlock3D(
+                 dim=dim,
+                 num_heads=num_heads,
+                 window_size=window_size,
+                 shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size,
+                 mlp_ratio=mlp_ratio,
+                 qkv_bias=qkv_bias,
+                 qk_scale=qk_scale,
+                 drop=drop,
+                 attn_drop=attn_drop,
+                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                 norm_layer=norm_layer,
+                 use_checkpoint=use_checkpoint,
+             )
+             for i in range(depth)])
+
+         self.downsample = downsample
+         if self.downsample is not None:
+             self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+
+     def forward(self, x):
+         """ Forward function.
+         Args:
+             x: Input feature, tensor size (B, C, D, H, W).
+         """
+         # calculate attention mask for SW-MSA
+         B, C, D, H, W = x.shape
+         window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
+         x = rearrange(x, 'b c d h w -> b d h w c')
+         Dp = int(np.ceil(D / window_size[0])) * window_size[0]
+         Hp = int(np.ceil(H / window_size[1])) * window_size[1]
+         Wp = int(np.ceil(W / window_size[2])) * window_size[2]
+         attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device)
+         for blk in self.blocks:
+             x = blk(x, attn_mask)
+         x = x.view(B, D, H, W, -1)
+
+         if self.downsample is not None:
+             x = self.downsample(x)
+         x = rearrange(x, 'b d h w c -> b c d h w')
+         return x
+
+
+ class PatchEmbed3D(nn.Module):
+     """ Video to Patch Embedding.
+     Args:
+         patch_size (tuple[int]): Patch token size. Default: (2,4,4).
+         in_chans (int): Number of input video channels. Default: 3.
+         embed_dim (int): Number of linear projection output channels. Default: 96.
+         norm_layer (nn.Module, optional): Normalization layer. Default: None
+     """
+     def __init__(self, patch_size=(2, 4, 4), in_chans=3, embed_dim=96, norm_layer=None):
+         super().__init__()
+         self.patch_size = patch_size
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+         if norm_layer is not None:
+             self.norm = norm_layer(embed_dim)
+         else:
+             self.norm = None
+
+     def forward(self, x):
+         """Forward function."""
+         # padding
+         _, _, D, H, W = x.size()
+         if W % self.patch_size[2] != 0:
+             x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
+         if H % self.patch_size[1] != 0:
+             x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
+         if D % self.patch_size[0] != 0:
+             x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
+
+         x = self.proj(x)  # B C D Wh Ww
+         if self.norm is not None:
+             D, Wh, Ww = x.size(2), x.size(3), x.size(4)
+             x = x.flatten(2).transpose(1, 2)
+             x = self.norm(x)
+             x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
+
+         return x
+
+
+ class SwinTransformer3D(nn.Module):
+     """ Swin Transformer backbone.
+     A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+     https://arxiv.org/pdf/2103.14030
+     Args:
+         patch_size (int | tuple(int)): Patch size. Default: (4,4,4).
+         in_chans (int): Number of input image channels. Default: 3.
+         embed_dim (int): Number of linear projection output channels. Default: 96.
+         depths (tuple[int]): Depths of each Swin Transformer stage.
+         num_heads (tuple[int]): Number of attention heads of each stage.
+         window_size (tuple[int]): Window size. Default: (2,7,7).
+         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+         qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
+         qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+         drop_rate (float): Dropout rate.
+         attn_drop_rate (float): Attention dropout rate. Default: 0.
+         drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+         norm_layer: Normalization layer. Default: nn.LayerNorm.
+         patch_norm (bool): If True, add normalization after patch embedding. Default: False.
+         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+             -1 means not freezing any parameters.
+     """
+
+     def __init__(self,
+                  pretrained=None,
+                  pretrained2d=True,
+                  patch_size=(4, 4, 4),
+                  in_chans=3,
+                  embed_dim=96,
+                  depths=[2, 2, 6, 2],
+                  num_heads=[3, 6, 12, 24],
+                  window_size=(2, 7, 7),
+                  mlp_ratio=4.,
+                  qkv_bias=True,
+                  qk_scale=None,
+                  drop_rate=0.,
+                  attn_drop_rate=0.,
+                  drop_path_rate=0.2,
+                  norm_layer=nn.LayerNorm,
+                  patch_norm=False,
+                  frozen_stages=-1,
+                  use_checkpoint=False):
+         super().__init__()
+
+         self.pretrained = pretrained
+         self.pretrained2d = pretrained2d
+         self.num_layers = len(depths)
+         self.embed_dim = embed_dim
+         self.patch_norm = patch_norm
+         self.frozen_stages = frozen_stages
+         self.window_size = window_size
+         self.patch_size = patch_size
+
+         # split image into non-overlapping patches
+         self.patch_embed = PatchEmbed3D(
+             patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+             norm_layer=norm_layer if self.patch_norm else None)
+
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         # stochastic depth
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+         # build layers
+         self.layers = nn.ModuleList()
+         for i_layer in range(self.num_layers):
+             layer = BasicLayer(
+                 dim=int(embed_dim * 2**i_layer),
+                 depth=depths[i_layer],
+                 num_heads=num_heads[i_layer],
+                 window_size=window_size,
+                 mlp_ratio=mlp_ratio,
+                 qkv_bias=qkv_bias,
+                 qk_scale=qk_scale,
+                 drop=drop_rate,
+                 attn_drop=attn_drop_rate,
+                 drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                 norm_layer=norm_layer,
+                 downsample=PatchMerging if i_layer < self.num_layers - 1 else None,
+                 use_checkpoint=use_checkpoint)
+             self.layers.append(layer)
+
+         self.num_features = int(embed_dim * 2**(self.num_layers - 1))
+
+         # add a norm layer for each output
+         self.norm = norm_layer(self.num_features)
+
+         self._freeze_stages()
+
+     def _freeze_stages(self):
+         if self.frozen_stages >= 0:
+             self.patch_embed.eval()
+             for param in self.patch_embed.parameters():
+                 param.requires_grad = False
+
+         if self.frozen_stages >= 1:
+             self.pos_drop.eval()
+             for i in range(0, self.frozen_stages):
+                 m = self.layers[i]
+                 m.eval()
+                 for param in m.parameters():
+                     param.requires_grad = False
+
+     def forward(self, x):
+         """Forward function."""
+         x = self.patch_embed(x)
+
+         x = self.pos_drop(x)
+
+         for layer in self.layers:
+             x = layer(x.contiguous())
+         x = rearrange(x, 'n c d h w -> n d h w c')
+         x = self.norm(x)
+         x = rearrange(x, 'n d h w c -> n c d h w')
+         return x
+
+     # def train(self, mode=True):
+     #     """Convert the model into training mode while keeping frozen layers frozen."""
+     #     super(SwinTransformer3D, self).train(mode)
+     #     self._freeze_stages()
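With the defaults above (patch size (4,4,4), four stages, a PatchMerging after each of the first three), a 16-frame 256x256 clip comes out temporally /4 and spatially /32 with 8x the embedding channels. A quick shape sketch:

    import torch

    backbone = SwinTransformer3D().eval()
    clip = torch.randn(1, 3, 16, 256, 256)  # (B, C, D, H, W)
    with torch.no_grad():
        feat = backbone(clip)
    print(feat.shape)  # torch.Size([1, 768, 4, 8, 8])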
IPAD/model/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from __future__ import absolute_import, print_function
+
+ from .memory_module import *
+ from .memae_3dconv import *
+ from .entropy_loss import *
IPAD/model/autoencoder.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ from .reconstruction_model import Reconstruction3DEncoder, Reconstruction3DDecoder
+
+
+ class convAE(torch.nn.Module):
+     def __init__(self):  # for reconstruction
+         super(convAE, self).__init__()
+
+         self.reconstruction = True
+
+         # self.encoder = Reconstruction3DEncoder(chnum_in=1)  # black and white
+         # self.decoder = Reconstruction3DDecoder(chnum_in=1)  # black and white
+         self.encoder = Reconstruction3DEncoder(chnum_in=3)  # RGB
+         self.decoder = Reconstruction3DDecoder(chnum_in=3)  # RGB
+
+     def forward(self, x):
+         fea = self.encoder(x)
+         output = self.decoder(fea.clone())
+         return output
IPAD/model/entropy_loss.py ADDED
@@ -0,0 +1,50 @@
+ from __future__ import absolute_import, print_function
+ import torch
+ from torch import nn
+
+
+ def feature_map_permute(input):
+     s = input.data.shape
+     l = len(s)
+
+     # permute the feature channel to the last dim:
+     # NxCxDxHxW --> NxDxHxWxC
+     if l == 2:
+         x = input  # NxC
+     elif l == 3:
+         x = input.permute(0, 2, 1)
+     elif l == 4:
+         x = input.permute(0, 2, 3, 1)
+     elif l == 5:
+         x = input.permute(0, 2, 3, 4, 1)
+     else:
+         raise ValueError('wrong feature map size: expected 2-5 dims, got {}'.format(l))
+     x = x.contiguous()
+     # NxDxHxWxC --> (NxDxHxW) x C
+     x = x.view(-1, s[1])
+     return x
+
+
+ class EntropyLoss(nn.Module):
+     def __init__(self, eps=1e-12):
+         super(EntropyLoss, self).__init__()
+         self.eps = eps
+
+     def forward(self, x):
+         # per-row entropy of the attention weights, averaged over rows
+         b = x * torch.log(x + self.eps)
+         b = -1.0 * b.sum(dim=1)
+         b = b.mean()
+         return b
+
+
+ class EntropyLossEncap(nn.Module):
+     def __init__(self, eps=1e-12):
+         super(EntropyLossEncap, self).__init__()
+         self.eps = eps
+         self.entropy_loss = EntropyLoss(eps)
+
+     def forward(self, input):
+         score = feature_map_permute(input)
+         ent_loss_val = self.entropy_loss(score)
+         return ent_loss_val
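EntropyLossEncap flattens an (N, M, ...) attention map to rows of M weights and penalizes diffuse (high-entropy) memory addressing. A minimal sketch, with EntropyLossEncap in scope and a softmax-normalized map:

    import torch
    import torch.nn.functional as F

    att = F.softmax(torch.randn(2, 2000, 4, 8, 8), dim=1)  # attention over M=2000 memory slots
    loss = EntropyLossEncap()(att)
    print(loss)  # scalar; up to log(2000) for uniform weights, 0 for one-hot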
IPAD/model/memae_3dconv.py ADDED
@@ -0,0 +1,53 @@
+ from __future__ import absolute_import, print_function
+ import torch
+ from torch import nn
+
+ from .memory_module import MemModule
+
+
+ class AutoEncoderCov3DMem(nn.Module):
+     def __init__(self, chnum_in, mem_dim, shrink_thres=0.0025):
+         super(AutoEncoderCov3DMem, self).__init__()
+         print('AutoEncoderCov3DMem')
+         self.chnum_in = chnum_in
+         feature_num = 128
+         feature_num_2 = 96
+         feature_num_x2 = 256
+         self.encoder = nn.Sequential(
+             nn.Conv3d(self.chnum_in, feature_num_2, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_2),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(feature_num_2, feature_num, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(feature_num, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_x2),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(feature_num_x2, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_x2),
+             nn.LeakyReLU(0.2, inplace=True)
+         )
+         self.mem_rep = MemModule(mem_dim=mem_dim, fea_dim=feature_num_x2, shrink_thres=shrink_thres)
+         self.decoder = nn.Sequential(
+             nn.ConvTranspose3d(feature_num_x2, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
+                                output_padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_x2),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose3d(feature_num_x2, feature_num, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
+                                output_padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose3d(feature_num, feature_num_2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
+                                output_padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_2),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose3d(feature_num_2, self.chnum_in, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1),
+                                output_padding=(0, 1, 1))
+         )
+
+     def forward(self, x):
+         f = self.encoder(x)
+         res_mem = self.mem_rep(f)  # no period prior here; MemModule skips the period boost
+         f = res_mem['output']
+         att = res_mem['att']
+         output = self.decoder(f)
+         return {'output': output, 'att': att}
IPAD/model/memory_module.py ADDED
@@ -0,0 +1,112 @@
+ from __future__ import absolute_import, print_function
+ import torch
+ from torch import nn
+ import math
+ from torch.nn.parameter import Parameter
+ from torch.nn import functional as F
+ import numpy as np
+
+
+ class MemoryUnit(nn.Module):
+     def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
+         super(MemoryUnit, self).__init__()
+         self.mem_dim = mem_dim
+         self.fea_dim = fea_dim
+         self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim))  # M x C
+         self.bias = None
+         self.shrink_thres = shrink_thres
+         # self.hard_sparse_shrink_opt = nn.Hardshrink(lambd=shrink_thres)
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         stdv = 1. / math.sqrt(self.weight.size(1))
+         self.weight.data.uniform_(-stdv, stdv)
+         if self.bias is not None:
+             self.bias.data.uniform_(-stdv, stdv)
+
+     def forward(self, input, period_score=None):
+         att_weight = F.linear(input, self.weight)  # Fea x Mem^T, (TxC) x (CxM) = TxM
+         if period_score is not None:
+             # boost the attention of the memory slots around each sample's predicted period index
+             score, indices = torch.max(period_score, dim=1)
+             indices = (torch.floor((indices / 126) * self.mem_dim).cpu().numpy()).astype(int)
+             for i in range(indices.shape[0]):
+                 lo = max(indices[i] - 7, 0)
+                 hi = indices[i] + 8
+                 att_weight[:, lo:hi] = att_weight[:, lo:hi] + att_weight[:, lo:hi].clone() * score[i]
+         att_weight = F.softmax(att_weight, dim=1)  # TxM
+         # ReLU based shrinkage, hard shrinkage for positive values
+         if self.shrink_thres > 0:
+             att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
+             # att_weight = F.softshrink(att_weight, lambd=self.shrink_thres)
+             # re-normalize after shrinkage
+             att_weight = F.normalize(att_weight, p=1, dim=1)
+             # att_weight = F.softmax(att_weight, dim=1)
+             # att_weight = self.hard_sparse_shrink_opt(att_weight)
+
+         mem_trans = self.weight.permute(1, 0)  # Mem^T, CxM
+         output = F.linear(att_weight, mem_trans)  # AttWeight x Mem, (TxM) x (MxC) = TxC
+         return {'output': output, 'att': att_weight}
+
+     def extra_repr(self):
+         return 'mem_dim={}, fea_dim={}'.format(self.mem_dim, self.fea_dim)
+
+
+ # NxCxHxW -> (NxHxW)xC -> addressing Mem, (NxHxW)xC -> NxCxHxW
+ class MemModule(nn.Module):
+     def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025, device='cuda'):
+         super(MemModule, self).__init__()
+         self.mem_dim = mem_dim
+         self.fea_dim = fea_dim
+         self.shrink_thres = shrink_thres
+         self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)
+
+     def forward(self, input, period_score=None):
+         s = input.data.shape
+         l = len(s)  # 5 for video features
+         if l == 3:
+             x = input.permute(0, 2, 1)
+         elif l == 4:
+             x = input.permute(0, 2, 3, 1)
+         elif l == 5:
+             x = input.permute(0, 2, 3, 4, 1)
+         else:
+             raise ValueError('wrong feature map size: expected 3-5 dims, got {}'.format(l))
+         x = x.contiguous()
+         x = x.view(-1, s[1])
+
+         y_and = self.memory(x, period_score)
+         y = y_and['output']
+         att = y_and['att']
+
+         if l == 3:
+             y = y.view(s[0], s[2], s[1])
+             y = y.permute(0, 2, 1)
+             att = att.view(s[0], s[2], self.mem_dim)
+             att = att.permute(0, 2, 1)
+         elif l == 4:
+             y = y.view(s[0], s[2], s[3], s[1])
+             y = y.permute(0, 3, 1, 2)
+             att = att.view(s[0], s[2], s[3], self.mem_dim)
+             att = att.permute(0, 3, 1, 2)
+         elif l == 5:
+             y = y.view(s[0], s[2], s[3], s[4], s[1])
+             y = y.permute(0, 4, 1, 2, 3)
+             att = att.view(s[0], s[2], s[3], s[4], self.mem_dim)
+             att = att.permute(0, 4, 1, 2, 3)
+         return {'output': y, 'att': att}
+
+
+ # ReLU based hard shrinkage function, only works for positive values
+ def hard_shrink_relu(input, lambd=0, epsilon=1e-12):
+     output = (F.relu(input - lambd) * input) / (torch.abs(input - lambd) + epsilon)
+     return output
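The module flattens any (N, C, ...) feature map to rows, addresses the M memory slots per row, and reshapes both the read-out and the attention back. A minimal sketch with illustrative sizes (period_score is optional given the default above):

    import torch

    mem = MemModule(mem_dim=2000, fea_dim=256)
    feat = torch.randn(2, 256, 4, 8, 8)           # (N, C, D, H, W)
    res = mem(feat)                               # no period prior
    print(res['output'].shape, res['att'].shape)  # (2, 256, 4, 8, 8) and (2, 2000, 4, 8, 8)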
IPAD/model/pseudoanomaly_utils.py ADDED
@@ -0,0 +1,298 @@
+ import random
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+ import copy
+
+
+ def create_pseudoanomaly_cifar_smooth(img, cifar_img, max_size, h, w, dataset, max_move=0):
+     assert 0 <= max_size <= 1
+
+     pil_img = transforms.ToPILImage()(cifar_img)
+     pil_img = transforms.Grayscale(num_output_channels=1)(pil_img)
+     cifar_img = transforms.ToTensor()(pil_img)
+
+     cifar_img = transforms.Normalize(mean=[0.5], std=[0.5])(cifar_img)
+
+     cifar_patch = F.interpolate(cifar_img.unsqueeze(0), size=(h, w), mode='bilinear', align_corners=True)
+
+     x_mu, y_mu = random.randint(0, w), random.randint(0, h)
+     x_sigma = max(10, int(np.random.uniform(high=max_size) * w))
+     y_sigma = max(10, int(np.random.uniform(high=max_size) * h))
+     if max_move == 0:
+         mask = torch.tensor(_get_gaussian_mask(x_mu, y_mu, x_sigma, y_sigma, h, w)).to(img.device).float()
+         img = mask * cifar_patch.to(img.device) + (1 - mask) * img
+     else:
+         mask = []
+         for frame_idx in range(img.size(1)):
+             delta_x = np.random.randint(-max_move, max_move + 1)
+             delta_y = np.random.randint(-max_move, max_move + 1)
+             mask_ = torch.tensor(_get_gaussian_mask(x_mu, y_mu, x_sigma, y_sigma, h, w)).to(img.device).float()
+
+             img[:, frame_idx] = mask_ * cifar_patch.to(img.device) + (1 - mask_) * img[:, frame_idx]
+             mask.append(mask_)
+
+             x_mu = min(max(0, x_mu + delta_x), w)
+             y_mu = min(max(0, y_mu + delta_y), h)
+
+         mask = torch.stack(mask, dim=0)
+
+     return img, mask
+
+
+ def create_pseudoanomaly_cifar_smoothborder(img, cifar_img, max_size, h, w, dataset, max_move=0):
+     assert 0 <= max_size <= 1
+
+     pil_img = transforms.ToPILImage()(cifar_img)
+     pil_img = transforms.Grayscale(num_output_channels=1)(pil_img)
+     cifar_img = transforms.ToTensor()(pil_img)
+
+     cifar_img = transforms.Normalize(mean=[0.5], std=[0.5])(cifar_img)
+
+     cifar_patch = F.interpolate(cifar_img.unsqueeze(0), size=(h, w), mode='bilinear', align_corners=True)
+
+     cx, cy = np.random.randint(w), np.random.randint(h)
+
+     cut_w = max(10, int(np.random.uniform(high=max_size) * w))
+     cut_h = max(10, int(np.random.uniform(high=max_size) * h))
+     if max_move == 0:
+         mask = torch.tensor(_get_smoothborder_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+         img = mask * cifar_patch.to(img.device) + (1 - mask) * img
+     else:
+         mask = []
+         for frame_idx in range(img.size(1)):
+             delta_x = np.random.randint(-max_move, max_move + 1)
+             delta_y = np.random.randint(-max_move, max_move + 1)
+             mask_ = torch.tensor(_get_smoothborder_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+
+             img[:, frame_idx] = mask_ * cifar_patch.to(img.device) + (1 - mask_) * img[:, frame_idx]
+             mask.append(mask_)
+
+             cx = min(max(0, cx + delta_x), w)
+             cy = min(max(0, cy + delta_y), h)
+
+         mask = torch.stack(mask, dim=0)
+
+     return img, mask
+
+
+ def create_pseudoanomaly_cifar_cutmix(img, cifar_img, max_size, h, w, dataset, max_move=0):
+     assert 0 <= max_size <= 1
+
+     pil_img = transforms.ToPILImage()(cifar_img)
+     pil_img = transforms.Grayscale(num_output_channels=1)(pil_img)
+     cifar_img = transforms.ToTensor()(pil_img)
+
+     cifar_img = transforms.Normalize(mean=[0.5], std=[0.5])(cifar_img)
+
+     cifar_patch = F.interpolate(cifar_img.unsqueeze(0), size=(h, w), mode='bilinear', align_corners=True)
+
+     cx, cy = np.random.randint(w), np.random.randint(h)
+
+     cut_w = max(10, int(np.random.uniform(high=max_size) * w))
+     cut_h = max(10, int(np.random.uniform(high=max_size) * h))
+     if max_move == 0:
+         mask = torch.tensor(_get_cutmix_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+         img = mask * cifar_patch.to(img.device) + (1 - mask) * img
+     else:
+         mask = []
+         for frame_idx in range(img.size(1)):
+             delta_x = np.random.randint(-max_move, max_move + 1)
+             delta_y = np.random.randint(-max_move, max_move + 1)
+             mask_ = torch.tensor(_get_cutmix_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+
+             img[:, frame_idx] = mask_ * cifar_patch.to(img.device) + (1 - mask_) * img[:, frame_idx]
+             mask.append(mask_)
+
+             cx = min(max(0, cx + delta_x), w)
+             cy = min(max(0, cy + delta_y), h)
+
+         mask = torch.stack(mask, dim=0)
+
+     return img, mask
+
+
+ def create_pseudoanomaly_cifar_mixupcutmix(img, cifar_img, max_size, h, w, dataset, max_move=0):
+     assert 0 <= max_size <= 1
+
+     pil_img = transforms.ToPILImage()(cifar_img)
+     pil_img = transforms.Grayscale(num_output_channels=1)(pil_img)
+     cifar_img = transforms.ToTensor()(pil_img)
+
+     cifar_img = transforms.Normalize(mean=[0.5], std=[0.5])(cifar_img)
+
+     cifar_patch = F.interpolate(cifar_img.unsqueeze(0), size=(h, w), mode='bilinear', align_corners=True)
+
+     cx, cy = np.random.randint(w), np.random.randint(h)
+
+     cut_w = max(10, int(np.random.uniform(high=max_size) * w))
+     cut_h = max(10, int(np.random.uniform(high=max_size) * h))
+     if max_move == 0:
+         mask = torch.tensor(_get_cutmix_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+         img = mask * 0.5 * cifar_patch.to(img.device) + mask * 0.5 * img + (1 - mask) * img
+     else:
+         mask = []
+         for frame_idx in range(img.size(1)):
+             delta_x = np.random.randint(-max_move, max_move + 1)
+             delta_y = np.random.randint(-max_move, max_move + 1)
+             mask_ = torch.tensor(_get_cutmix_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+
+             img[:, frame_idx] = mask_ * 0.5 * cifar_patch.to(img.device) + mask_ * 0.5 * img[:, frame_idx] + (1 - mask_) * img[:, frame_idx]
+             mask.append(mask_)
+
+             cx = min(max(0, cx + delta_x), w)
+             cy = min(max(0, cy + delta_y), h)
+
+         mask = torch.stack(mask, dim=0)
+
+     return img, mask
+
+
+ def create_pseudoanomaly_seq_smoothborder(img, seq, max_size, h, w, dataset, max_move=0):
+     assert 0 <= max_size <= 1
+
+     cx, cy = np.random.randint(w), np.random.randint(h)
+
+     cut_w = max(10, int(np.random.uniform(high=max_size) * w))
+     cut_h = max(10, int(np.random.uniform(high=max_size) * h))
+     if max_move == 0:
+         mask = torch.tensor(_get_smoothborder_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+         img = mask * seq.to(img.device) + (1 - mask) * img
+     else:
+         mask = []
+         for frame_idx in range(img.size(1)):
+             delta_x = np.random.randint(-max_move, max_move + 1)
+             delta_y = np.random.randint(-max_move, max_move + 1)
+             mask_ = torch.tensor(_get_smoothborder_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+
+             img[:, frame_idx] = mask_ * seq[:, frame_idx].to(img.device) + (1 - mask_) * img[:, frame_idx]
+             mask.append(mask_)
+
+             cx = min(max(0, cx + delta_x), w)
+             cy = min(max(0, cy + delta_y), h)
+
+         mask = torch.stack(mask, dim=0)
+
+     return img, mask
+
+
+ def _get_gaussian_mask(x_mu, y_mu, x_sigma, y_sigma, h, w):
+     x, y = np.arange(w), np.arange(h)
+
+     gx = np.exp(-(x - x_mu) ** 2 / (2 * x_sigma ** 2))
+     gy = np.exp(-(y - y_mu) ** 2 / (2 * y_sigma ** 2))
+     g = np.outer(gx, gy)
+     # g /= np.sum(g)  # normalize, if you want that
+
+     return g
+
+
+ def _get_smoothborder_mask(cx, cy, Cut_h, Cut_w, h, w):
+     lam = np.random.beta(1, 1)
+     percentage = 0.1
+     cut_rat = np.sqrt(1. - lam)
+
+     # Cut_w = min(np.int(max_size*w), max(2, np.int(w * cut_rat)))
+     # Cut_h = min(np.int(max_size*h), max(2, np.int(h * cut_rat)))
+
+     bbx1 = np.clip(cx - Cut_w // 2, 0, w)  # top left x
+     bby1 = np.clip(cy - Cut_h // 2, 0, h)  # top left y
+     bbx2 = np.clip(cx + Cut_w // 2, 0, w)  # bottom right x
+     bby2 = np.clip(cy + Cut_h // 2, 0, h)  # bottom right y
+
+     img = np.zeros((w, h))
+     img2, img3 = np.ones_like(img), np.ones_like(img)
+     img[bbx1:bbx2, bby1:bby2] = img2[bbx1:bbx2, bby1:bby2]
+
+     lo = bbx1 - (Cut_w // 2) * percentage  # left side: start of the linear ramp from 0
+     li = bbx1                              # left side: end of the ramp at 1
+     ri = bbx2                              # right side: start of the ramp from 1
+     ro = bbx2 + (Cut_w // 2) * percentage  # right side: end of the ramp at 0
+
+     to = bby1 - (Cut_h // 2) * percentage  # top: start of the ramp from 0
+     ti = bby1                              # top: end of the ramp at 1
+     bi = bby2                              # bottom: start of the ramp from 1
+     bo = bby2 + (Cut_h // 2) * percentage  # bottom: end of the ramp at 0
+
+     for i in range(w):
+         for j in range(h):
+             if i < cx:
+                 img2[j][i] = ((i - lo) / (li - lo))    # linear going up
+             else:
+                 img2[j][i] = (-(i - ro) / (ro - ri))   # linear going down
+             if j < cy:
+                 img3[j][i] = ((j - to) / (ti - to))
+             else:
+                 img3[j][i] = (-(j - bo) / (bo - bi))
+
+     img2[img2 < 0] = 0
+     img2[img2 > 1] = 1
+
+     img3[img3 < 0] = 0
+     img3[img3 > 1] = 1
+
+     img4 = np.multiply(img2, img3)
+
+     return img4  # , lam
+
+
+ def _get_cutmix_mask(cx, cy, Cut_h, Cut_w, h, w):
+     lam = np.random.beta(1, 1)
+
+     bbx1 = np.clip(cx - Cut_w // 2, 0, w)  # top left x
+     bby1 = np.clip(cy - Cut_h // 2, 0, h)  # top left y
+     bbx2 = np.clip(bbx1 + Cut_w, 0, w)     # bottom right x
+     bby2 = np.clip(bby1 + Cut_h, 0, h)     # bottom right y
+
+     img = np.zeros((w, h))
+     img2 = np.ones_like(img)
+     img[bby1:bby2, bbx1:bbx2] = img2[bby1:bby2, bbx1:bbx2]
+
+     return img  # , lam
+
+ # a = _get_cutmix_mask(100, 100, 15, 30, 256, 256)
+ # a = _get_smoothborder_mask(100, 100, 15, 30, 256, 256)
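Each create_pseudoanomaly_* variant pastes an intruding patch into a clip under a differently shaped blending mask (Gaussian, smooth-border, or hard cutmix box). A minimal sketch of the two simplest mask builders on their own, with illustrative sizes:

    import numpy as np

    g = _get_gaussian_mask(x_mu=128, y_mu=128, x_sigma=30, y_sigma=30, h=256, w=256)
    box = _get_cutmix_mask(cx=128, cy=128, Cut_h=40, Cut_w=40, h=256, w=256)
    print(g.shape, g.max() <= 1.0)        # (256, 256) True: soft blend weights
    print(box.shape, np.unique(box))      # (256, 256) [0. 1.]: hard paste region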
IPAD/model/reconstruction_model.py ADDED
@@ -0,0 +1,104 @@
+ import torch.nn as nn
+ from functools import reduce
+ from operator import mul
+ import torch
+
+
+ class Reconstruction3DEncoder(nn.Module):
+     def __init__(self, chnum_in):
+         super(Reconstruction3DEncoder, self).__init__()
+
+         # Dong Gong's paper code
+         self.chnum_in = chnum_in
+         feature_num = 128
+         feature_num_2 = 96
+         feature_num_x2 = 256
+         self.encoder = nn.Sequential(
+             nn.Conv3d(self.chnum_in, feature_num_2, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_2),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(feature_num_2, feature_num, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(feature_num, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_x2),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(feature_num_x2, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_x2),
+             nn.LeakyReLU(0.2, inplace=True)
+         )
+
+     def forward(self, x):
+         x = self.encoder(x)
+         return x
+
+
+ class Reconstruction3DDecoder(nn.Module):
+     def __init__(self, chnum_in):
+         super(Reconstruction3DDecoder, self).__init__()
+
+         # Dong Gong's paper code + Tanh
+         self.chnum_in = chnum_in
+         feature_num = 128
+         feature_num_2 = 96
+         feature_num_x2 = 256
+
+         self.decoder = nn.Sequential(
+             nn.ConvTranspose3d(feature_num_x2, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
+                                output_padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_x2),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose3d(feature_num_x2, feature_num, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
+                                output_padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose3d(feature_num, feature_num_2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
+                                output_padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_2),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose3d(feature_num_2, self.chnum_in, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1),
+                                output_padding=(0, 1, 1)),
+             nn.Tanh()
+         )
+
+     def forward(self, x):
+         x = self.decoder(x)
+         return x
+
+
+ class VST3DDecoder(nn.Module):
+     def __init__(self, chnum_out):
+         super(VST3DDecoder, self).__init__()
+
+         # Dong Gong's paper code + Tanh
+         self.chnum_out = chnum_out
+         feature_num = 128
+         feature_num_2 = 96
+         feature_num_x2 = 256
+         feature_num_in = 768
+         self.transformer_decoder = nn.Sequential(
+             # (4, 768, 4, 8, 8)
+             nn.ConvTranspose3d(feature_num_in, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
+                                output_padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_x2),
+             nn.LeakyReLU(0.2, inplace=True),
+             # (4, 256, 8, 16, 16)
+             nn.ConvTranspose3d(feature_num_x2, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
+                                output_padding=(1, 1, 1)),
+             nn.BatchNorm3d(feature_num_x2),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose3d(feature_num_x2, feature_num, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1),
+                                output_padding=(0, 1, 1)),
+             nn.BatchNorm3d(feature_num),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose3d(feature_num, feature_num_2, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1),
+                                output_padding=(0, 1, 1)),
+             nn.BatchNorm3d(feature_num_2),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose3d(feature_num_2, self.chnum_out, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1),
+                                output_padding=(0, 1, 1)),
+             nn.Tanh()
+         )
+
+     def forward(self, x):
+         x = self.transformer_decoder(x)
+         return x
IPAD/model/utils.py ADDED
@@ -0,0 +1,142 @@
+ import numpy as np
+ from collections import OrderedDict
+ import os
+ import glob
+ import cv2
+ import torch.utils.data as data
+ import random
+ from PIL import Image
+
+
+ rng = np.random.RandomState(2020)
+
+
+ def np_load_frame(filename, resize_height, resize_width, grayscale=False):
+     """
+     Load an image from its path and convert it to a numpy.ndarray. Note that the color channels
+     are BGR and the color space is normalized from [0, 255] to [-1, 1].
+
+     :param filename: the full path of the image
+     :param resize_height: resized height
+     :param resize_width: resized width
+     :return: numpy.ndarray
+     """
+     grayscale = False  # grayscale loading is forced off here
+     if grayscale:
+         image_decoded = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
+     else:
+         image_decoded = cv2.imread(filename)
+     image_resized = cv2.resize(image_decoded, (resize_width, resize_height))
+     # image_resized = np.copy(image_decoded)
+     image_resized = image_resized.astype(dtype=np.float32)
+     image_resized = (image_resized / 127.5) - 1.0
+     return image_resized
+
+
+ class Reconstruction3DDataLoader(data.Dataset):
+     def __init__(self, video_folder, transform, resize_height, resize_width, num_frames=16,
+                  img_extension='.jpg', dataset='ped2', jump=[2], hold=[2], return_normal_seq=False):
+         self.dir = video_folder
+         self.transform = transform
+         self.videos = OrderedDict()
+         self._resize_height = resize_height
+         self._resize_width = resize_width
+         self._num_frames = num_frames
+
+         self.extension = img_extension
+         self.dataset = dataset
+
+         self.setup()
+         self.samples, self.background_models = self.get_all_samples()
+
+         self.jump = jump
+         self.hold = hold
+         self.return_normal_seq = return_normal_seq  # for fast and slow moving
+
+     def setup(self):
+         videos = glob.glob(os.path.join(self.dir, '*/'))
+         for video in sorted(videos):
+             print(video)
+             video_name = video.split('/')[-2]
+             self.videos[video_name] = {}
+             self.videos[video_name]['path'] = video
+             self.videos[video_name]['frame'] = glob.glob(os.path.join(video, '*' + self.extension))
+             self.videos[video_name]['frame'].sort()
+             self.videos[video_name]['length'] = len(self.videos[video_name]['frame'])
+
+     def get_all_samples(self):
+         frames = []
+         background_models = []
+         videos = glob.glob(os.path.join(self.dir, '*/'))
+         for video in sorted(videos):
+             video_name = video.split('/')[-2]
+             for i in range(len(self.videos[video_name]['frame']) - self._num_frames + 1):
+                 frames.append(self.videos[video_name]['frame'][i])
+                 # background_models.append(bg_filename)
+
+         return frames, background_models
+
+     def __getitem__(self, index):
+         video_name = self.samples[index].split('/')[-2]
+         if self.dataset == 'shanghai' and 'training' in self.samples[index]:
+             frame_name = int(self.samples[index].split('/')[-1].split('.')[-2]) - 1
+         else:
+             frame_name = int(self.samples[index].split('/')[-1].split('.')[-2])
+
+         batch = []
+         for i in range(self._num_frames):
+             image = np_load_frame(self.videos[video_name]['frame'][frame_name + i], self._resize_height,
+                                   self._resize_width, grayscale=True)
+             if self.transform is not None:
+                 batch.append(self.transform(image))
+         # batch: len=16, batch[0]: torch(3,256,256)
+         img = OrderedDict()
+         img['batch'] = np.stack(batch, axis=1)
+         img['index'] = frame_name * 200 // len(self.videos[video_name]['frame'])
+         return img
+
+     def __len__(self):
+         return len(self.samples)
+
+
+ class Reconstruction3DDataLoaderJump(Reconstruction3DDataLoader):
+     def __getitem__(self, index):
+         video_name = self.samples[index].split('/')[-2]
+         if self.dataset == 'shanghai' and 'training' in self.samples[index]:  # because the ShanghaiTech frame numbers start from 1
+             frame_name = int(self.samples[index].split('/')[-1].split('.')[-2]) - 1
+         else:
+             frame_name = int(self.samples[index].split('/')[-1].split('.')[-2])
+
+         batch = []
+         normal_batch = []
+         jump = random.choice(self.jump)
+
+         retry = 0
+         while len(self.videos[video_name]['frame']) < frame_name + (self._num_frames - 1) * jump and retry < 10:
+             # reselect the frame_name
+             frame_name = np.random.randint(len(self.videos[video_name]['frame']))
+             retry += 1
+
+         for i in range(self._num_frames):
+             image = np_load_frame(self.videos[video_name]['frame'][min(frame_name + i * jump, len(self.videos[video_name]['frame']) - 1)],
+                                   self._resize_height, self._resize_width, grayscale=True)
+             if self.transform is not None:
+                 batch.append(self.transform(image))
+
+         if self.return_normal_seq:
+             for i in range(self._num_frames):
+                 image = np_load_frame(self.videos[video_name]['frame'][min(frame_name + i, len(self.videos[video_name]['frame']) - 1)],
+                                       self._resize_height, self._resize_width, grayscale=True)
+                 if self.transform is not None:
+                     normal_batch.append(self.transform(image))
+             return np.stack(batch, axis=1), np.stack(normal_batch, axis=1)
+         else:
+             return np.stack(batch, axis=1), normal_batch
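The loader indexes every valid 16-frame window per video and returns clips in (C, D, H, W) layout plus a coarse position index in [0, 200). A sketch of typical construction; the folder path here is hypothetical (one sub-folder of frames per training video) and the transform is an assumption:

    import torchvision.transforms as transforms

    transform = transforms.ToTensor()
    dataset = Reconstruction3DDataLoader('datasets/ped2/training/frames', transform,
                                         resize_height=256, resize_width=256, num_frames=16)
    sample = dataset[0]  # {'batch': (C, 16, 256, 256) array, 'index': coarse phase index}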
IPAD/model/video_swin_transformer.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ from .reconstruction_model import Reconstruction3DEncoder, Reconstruction3DDecoder, VST3DDecoder
+ from .VST_block import SwinTransformer3D
+ from einops import rearrange
+ from .memory_module import MemModule
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+
+ class VST(torch.nn.Module):
+     def __init__(self, mem_dim=2000, shrink_thres=0.0025):  # for reconstruction
+         super(VST, self).__init__()
+         self.reconstruction = True
+         # self.chnum_in = chnum_in
+
+         # self.encoder = Reconstruction3DEncoder(chnum_in=1)  # black and white
+         # self.decoder = Reconstruction3DDecoder(chnum_in=1)  # black and white
+         self.transformer_encoder = SwinTransformer3D()
+         self.mem_rep = MemModule(mem_dim=mem_dim, fea_dim=768, shrink_thres=shrink_thres)
+         # period head: predicts a 200-way period index used to bias memory addressing
+         self.period = nn.Sequential(
+             nn.Conv3d(768, 768, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
+             nn.BatchNorm3d(768),
+             nn.LeakyReLU(0.2, inplace=True),
+             # (batch_size, 768, 4, 4, 4)
+             nn.Flatten(1),
+             nn.Linear(768 * 4 * 4 * 4, 4096),
+             nn.ReLU(),
+             nn.Linear(4096, 2048),
+             nn.ReLU(),
+             nn.Linear(2048, 200),
+         )
+         self.transformer_decoder = VST3DDecoder(chnum_out=3)
+         # self.encoder = Reconstruction3DEncoder(chnum_in=3)  # RGB
+         # self.decoder = Reconstruction3DDecoder(chnum_in=3)  # RGB
+
+     def forward(self, x):
+         feature = self.transformer_encoder(x)
+         # feature: (batch_size, 768, 4, 8, 8)
+         recon_index = self.period(feature)
+         res_mem = self.mem_rep(feature, recon_index)
+         feature = res_mem['output']
+         att = res_mem['att']
+         output = self.transformer_decoder(feature.clone())
+
+         return {'output': output, 'att': att, 'recon_index': recon_index}
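Putting the pieces together: the backbone maps a clip to a (B, 768, 4, 8, 8) feature (per the comment above), the period head biases memory addressing, and the decoder reconstructs the clip. A sketch with the input size those shapes imply (16 frames at 256x256):

    import torch

    model = VST(mem_dim=2000, shrink_thres=0.0025).eval()
    clip = torch.randn(1, 3, 16, 256, 256)  # (B, C, D, H, W)
    with torch.no_grad():
        out = model(clip)
    print(out['output'].shape)       # torch.Size([1, 3, 16, 256, 256]): reconstruction
    print(out['recon_index'].shape)  # torch.Size([1, 200]): period logits
    print(out['att'].shape)          # torch.Size([1, 2000, 4, 8, 8]): memory attention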