Add IPAD model implementation
Browse files

- IPAD/__init__.py +5 -0
- IPAD/model/VST_block.py +568 -0
- IPAD/model/__init__.py +5 -0
- IPAD/model/autoencoder.py +24 -0
- IPAD/model/entropy_loss.py +50 -0
- IPAD/model/memae_3dconv.py +53 -0
- IPAD/model/memory_module.py +112 -0
- IPAD/model/pseudoanomaly_utils.py +298 -0
- IPAD/model/reconstruction_model.py +104 -0
- IPAD/model/utils.py +142 -0
- IPAD/model/video_swin_transformer.py +48 -0
IPAD/__init__.py
ADDED
@@ -0,0 +1,5 @@
+# IPAD Model Package
+from .model.video_swin_transformer import VST
+from .model.memory_module import MemModule
+
+__all__ = ['VST', 'MemModule']
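
A minimal import sketch (editorial, not part of the committed files), assuming the repository root is on PYTHONPATH and the timm/einops dependencies used below are installed:

# Editorial sketch, not part of the commit.
from IPAD import VST, MemModule  # the two names re-exported via __all__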
IPAD/model/VST_block.py
ADDED
@@ -0,0 +1,568 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from timm.models.layers import DropPath, trunc_normal_
+
+from functools import reduce, lru_cache
+from operator import mul
+from einops import rearrange
+
+import logging
+
+
+class Mlp(nn.Module):
+    """ Multilayer perceptron."""
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, D, H, W, C)
+        window_size (tuple[int]): window size
+    Returns:
+        windows: (B*num_windows, window_size*window_size, C)
+    """
+    B, D, H, W, C = x.shape
+    x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], C)
+    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
+    return windows
+
+
+def window_reverse(windows, window_size, B, D, H, W):
+    """
+    Args:
+        windows: (B*num_windows, window_size, window_size, C)
+        window_size (tuple[int]): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, D, H, W, C)
+    """
+    x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1)
+    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)
+    return x
+
+
+def get_window_size(x_size, window_size, shift_size=None):
+    use_window_size = list(window_size)
+    if shift_size is not None:
+        use_shift_size = list(shift_size)
+    for i in range(len(x_size)):
+        if x_size[i] <= window_size[i]:
+            use_window_size[i] = x_size[i]
+            if shift_size is not None:
+                use_shift_size[i] = 0
+
+    if shift_size is None:
+        return tuple(use_window_size)
+    else:
+        return tuple(use_window_size), tuple(use_shift_size)
+
+
+class WindowAttention3D(nn.Module):
+    """ Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The temporal length, height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wd, Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads))  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_d = torch.arange(self.window_size[0])
+        coords_h = torch.arange(self.window_size[1])
+        coords_w = torch.arange(self.window_size[2])
+        coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))  # 3, Wd, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 2] += self.window_size[2] - 1
+
+        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
+        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
+        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """ Forward function.
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, N, N) or None
+        """
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # B_, nH, N, C
+
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape(
+            N, N, -1)  # Wd*Wh*Ww,Wd*Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wd*Wh*Ww, Wd*Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)  # B_, nH, N, N
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class SwinTransformerBlock3D(nn.Module):
+    """ Swin Transformer Block.
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (tuple[int]): Window size.
+        shift_size (tuple[int]): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0),
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        self.use_checkpoint = use_checkpoint
+
+        assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size"
+        assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size"
+        assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention3D(
+            dim, window_size=self.window_size, num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward_part1(self, x, mask_matrix):
+        B, D, H, W, C = x.shape
+        window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
+
+        x = self.norm1(x)
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = pad_d0 = 0
+        pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0]
+        pad_b = (window_size[1] - H % window_size[1]) % window_size[1]
+        pad_r = (window_size[2] - W % window_size[2]) % window_size[2]
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
+        _, Dp, Hp, Wp, _ = x.shape
+        # cyclic shift
+        if any(i > 0 for i in shift_size):
+            shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+        # partition windows
+        x_windows = window_partition(shifted_x, window_size)  # B*nW, Wd*Wh*Ww, C
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # B*nW, Wd*Wh*Ww, C
+        # merge windows
+        attn_windows = attn_windows.view(-1, *(window_size+(C,)))
+        shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp)  # B D' H' W' C
+        # reverse cyclic shift
+        if any(i > 0 for i in shift_size):
+            x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
+        else:
+            x = shifted_x
+
+        if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
+            x = x[:, :D, :H, :W, :].contiguous()
+        return x
+
+    def forward_part2(self, x):
+        return self.drop_path(self.mlp(self.norm2(x)))
+
+    def forward(self, x, mask_matrix):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, D, H, W, C).
+            mask_matrix: Attention mask for cyclic shift.
+        """
+
+        shortcut = x
+        if self.use_checkpoint:
+            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
+        else:
+            x = self.forward_part1(x, mask_matrix)
+        x = shortcut + self.drop_path(x)
+
+        if self.use_checkpoint:
+            x = x + checkpoint.checkpoint(self.forward_part2, x)
+        else:
+            x = x + self.forward_part2(x)
+
+        return x
+
+
+class PatchMerging(nn.Module):
+    """ Patch Merging Layer
+    Args:
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, D, H, W, C).
+        """
+        B, D, H, W, C = x.shape
+
+        # padding
+        pad_input = (H % 2 == 1) or (W % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+        x0 = x[:, :, 0::2, 0::2, :]  # B D H/2 W/2 C
+        x1 = x[:, :, 1::2, 0::2, :]  # B D H/2 W/2 C
+        x2 = x[:, :, 0::2, 1::2, :]  # B D H/2 W/2 C
+        x3 = x[:, :, 1::2, 1::2, :]  # B D H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B D H/2 W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+
+# cache each stage results
+@lru_cache()
+def compute_mask(D, H, W, window_size, shift_size, device):
+    img_mask = torch.zeros((1, D, H, W, 1), device=device)  # 1 Dp Hp Wp 1
+    cnt = 0
+    for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
+        for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
+            for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
+                img_mask[:, d, h, w, :] = cnt
+                cnt += 1
+    mask_windows = window_partition(img_mask, window_size)  # nW, ws[0]*ws[1]*ws[2], 1
+    mask_windows = mask_windows.squeeze(-1)  # nW, ws[0]*ws[1]*ws[2]
+    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+    return attn_mask
+
+
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+    Args:
+        dim (int): Number of feature channels
+        depth (int): Depth of this stage.
+        num_heads (int): Number of attention heads.
+        window_size (tuple[int]): Local window size. Default: (1,7,7).
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+    """
+
+    def __init__(self,
+                 dim,
+                 depth,
+                 num_heads,
+                 window_size=(1,7,7),
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None,
+                 use_checkpoint=False):
+        super().__init__()
+        self.window_size = window_size
+        self.shift_size = tuple(i // 2 for i in window_size)
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock3D(
+                dim=dim,
+                num_heads=num_heads,
+                window_size=window_size,
+                shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop,
+                attn_drop=attn_drop,
+                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                norm_layer=norm_layer,
+                use_checkpoint=use_checkpoint,
+            )
+            for i in range(depth)])
+
+        self.downsample = downsample
+        if self.downsample is not None:
+            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+
+    def forward(self, x):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, C, D, H, W).
+        """
+        # calculate attention mask for SW-MSA
+        B, C, D, H, W = x.shape
+        window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
+        x = rearrange(x, 'b c d h w -> b d h w c')
+        Dp = int(np.ceil(D / window_size[0])) * window_size[0]
+        Hp = int(np.ceil(H / window_size[1])) * window_size[1]
+        Wp = int(np.ceil(W / window_size[2])) * window_size[2]
+        attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device)
+        for blk in self.blocks:
+            x = blk(x, attn_mask)
+        x = x.view(B, D, H, W, -1)
+
+        if self.downsample is not None:
+            x = self.downsample(x)
+        x = rearrange(x, 'b d h w c -> b c d h w')
+        return x
+
+
+class PatchEmbed3D(nn.Module):
+    """ Video to Patch Embedding.
+    Args:
+        patch_size (int): Patch token size. Default: (2,4,4).
+        in_chans (int): Number of input video channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+    def __init__(self, patch_size=(2,4,4), in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        self.patch_size = patch_size
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        """Forward function."""
+        # padding
+        _, _, D, H, W = x.size()
+        if W % self.patch_size[2] != 0:
+            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
+        if H % self.patch_size[1] != 0:
+            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
+        if D % self.patch_size[0] != 0:
+            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
+
+        x = self.proj(x)  # B C D Wh Ww
+        if self.norm is not None:
+            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
+            x = x.flatten(2).transpose(1, 2)
+            x = self.norm(x)
+            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
+
+        return x
+
+
+class SwinTransformer3D(nn.Module):
+    """ Swin Transformer backbone.
+    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+    https://arxiv.org/pdf/2103.14030
+    Args:
+        patch_size (int | tuple(int)): Patch size. Default: (4,4,4).
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        depths (tuple[int]): Depths of each Swin Transformer stage.
+        num_heads (tuple[int]): Number of attention head of each stage.
+        window_size (int): Window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        drop_rate (float): Dropout rate.
+        attn_drop_rate (float): Attention dropout rate. Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+        norm_layer: Normalization layer. Default: nn.LayerNorm.
+        patch_norm (bool): If True, add normalization after patch embedding. Default: False.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+    """
+
+    def __init__(self,
+                 pretrained=None,
+                 pretrained2d=True,
+                 patch_size=(4,4,4),
+                 in_chans=3,
+                 embed_dim=96,
+                 depths=[2, 2, 6, 2],
+                 num_heads=[3, 6, 12, 24],
+                 window_size=(2,7,7),
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.2,
+                 norm_layer=nn.LayerNorm,
+                 patch_norm=False,
+                 frozen_stages=-1,
+                 use_checkpoint=False):
+        super().__init__()
+
+        self.pretrained = pretrained
+        self.pretrained2d = pretrained2d
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.patch_norm = patch_norm
+        self.frozen_stages = frozen_stages
+        self.window_size = window_size
+        self.patch_size = patch_size
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed3D(
+            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(
+                dim=int(embed_dim * 2**i_layer),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=PatchMerging if i_layer < self.num_layers-1 else None,
+                use_checkpoint=use_checkpoint)
+            self.layers.append(layer)
+
+        self.num_features = int(embed_dim * 2**(self.num_layers-1))
+
+        # add a norm layer for each output
+        self.norm = norm_layer(self.num_features)
+
+        self._freeze_stages()
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+        if self.frozen_stages >= 1:
+            self.pos_drop.eval()
+            for i in range(0, self.frozen_stages):
+                m = self.layers[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def forward(self, x):
+        """Forward function."""
+        x = self.patch_embed(x)
+
+        x = self.pos_drop(x)
+
+        for layer in self.layers:
+            x = layer(x.contiguous())
+        x = rearrange(x, 'n c d h w -> n d h w c')
+        x = self.norm(x)
+        x = rearrange(x, 'n d h w c -> n c d h w')
+        return x
+
+    # def train(self, mode=True):
+    #     """Convert the model into training mode while keep layers freezed."""
+    #     super(SwinTransformer3D, self).train(mode)
+    #     self._freeze_stages()
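
A quick shape check makes the stage arithmetic above concrete. The sketch below is editorial, not part of the commit: with the defaults shown here, patch_size (4,4,4) plus three PatchMerging steps give a spatial stride of 32, a temporal stride of 4, and 8x embed_dim output channels.

# Editorial sketch, not part of the commit: dummy forward pass to
# confirm the backbone's output shape (requires torch, timm, einops).
import torch
from IPAD.model.VST_block import SwinTransformer3D

model = SwinTransformer3D()              # embed_dim=96, depths=[2, 2, 6, 2]
clip = torch.randn(1, 3, 16, 224, 224)   # (B, C, D, H, W)
with torch.no_grad():
    feats = model(clip)
# temporal stride 4 (patch embed), spatial stride 4*2*2*2 = 32,
# channels 96 * 2**3 = 768  ->  torch.Size([1, 768, 4, 7, 7])
print(feats.shape)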
IPAD/model/__init__.py
ADDED
@@ -0,0 +1,5 @@
+from __future__ import absolute_import, print_function
+
+# package-relative imports so that `import IPAD.model` resolves cleanly
+from .memory_module import *
+from .memae_3dconv import *
+from .entropy_loss import *
IPAD/model/autoencoder.py
ADDED
@@ -0,0 +1,24 @@
+import torch
+from .reconstruction_model import Reconstruction3DEncoder, Reconstruction3DDecoder
+
+class convAE(torch.nn.Module):
+    def __init__(self):  # for reconstruction
+        super(convAE, self).__init__()
+
+        self.reconstruction = True
+
+        # self.encoder = Reconstruction3DEncoder(chnum_in=1)  # black and white
+        # self.decoder = Reconstruction3DDecoder(chnum_in=1)  # black and white
+        self.encoder = Reconstruction3DEncoder(chnum_in=3)  # RGB
+        self.decoder = Reconstruction3DDecoder(chnum_in=3)  # RGB
+
+    def forward(self, x):
+        # print(x.shape)
+        fea = self.encoder(x)
+        # print(fea.shape)
+        output = self.decoder(fea.clone())
+        # print(output.shape)
+
+        return output
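
Since the decoder in reconstruction_model.py mirrors the encoder strides (one (1,2,2) stage and three (2,2,2) stages), convAE maps a clip to a same-sized reconstruction. A minimal sketch (editorial, not part of the commit):

# Editorial sketch, not part of the commit.
import torch
from IPAD.model.autoencoder import convAE

ae = convAE()
clip = torch.randn(2, 3, 16, 64, 64)   # (B, C=3 RGB, D, H, W)
recon = ae(clip)                        # decoder mirrors encoder strides
assert recon.shape == clip.shape        # same-sized reconstruction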
IPAD/model/entropy_loss.py
ADDED
@@ -0,0 +1,50 @@
+from __future__ import absolute_import, print_function
+import torch
+from torch import nn
+
+
+def feature_map_permute(input):
+    s = input.data.shape
+    l = len(s)
+
+    # permute feature channel to the last:
+    # NxCxDxHxW --> NxDxHxW x C
+    if l == 2:
+        x = input  # NxC
+    elif l == 3:
+        x = input.permute(0, 2, 1)
+    elif l == 4:
+        x = input.permute(0, 2, 3, 1)
+    elif l == 5:
+        x = input.permute(0, 2, 3, 4, 1)
+    else:
+        # fail loudly instead of printing and then crashing on a list
+        raise ValueError('wrong feature map size')
+    x = x.contiguous()
+    # NxDxHxW x C --> (NxDxHxW) x C
+    x = x.view(-1, s[1])
+    return x
+
+
+class EntropyLoss(nn.Module):
+    def __init__(self, eps=1e-12):
+        super(EntropyLoss, self).__init__()
+        self.eps = eps
+
+    def forward(self, x):
+        b = x * torch.log(x + self.eps)
+        b = -1.0 * b.sum(dim=1)
+        b = b.mean()
+        return b
+
+
+class EntropyLossEncap(nn.Module):
+    def __init__(self, eps=1e-12):
+        super(EntropyLossEncap, self).__init__()
+        self.eps = eps
+        self.entropy_loss = EntropyLoss(eps)
+
+    def forward(self, input):
+        score = feature_map_permute(input)
+        ent_loss_val = self.entropy_loss(score)
+        return ent_loss_val
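
EntropyLossEncap flattens a (B, M, ...) attention map into rows of M memory scores and penalizes high-entropy (unfocused) addressing. A small sketch (editorial, not part of the commit) with a softmax-normalized dummy map:

# Editorial sketch, not part of the commit.
import torch
from IPAD.model.entropy_loss import EntropyLossEncap

ent = EntropyLossEncap()
att = torch.softmax(torch.randn(2, 2000, 4, 8, 8), dim=1)  # B x M x D x H x W
loss = ent(att)   # mean per-location entropy; lower = sharper addressing
print(loss.item())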
IPAD/model/memae_3dconv.py
ADDED
@@ -0,0 +1,53 @@
+from __future__ import absolute_import, print_function
+import torch
+from torch import nn
+
+from .memory_module import MemModule
+
+class AutoEncoderCov3DMem(nn.Module):
+    def __init__(self, chnum_in, mem_dim, shrink_thres=0.0025):
+        super(AutoEncoderCov3DMem, self).__init__()
+        print('AutoEncoderCov3DMem')
+        self.chnum_in = chnum_in
+        feature_num = 128
+        feature_num_2 = 96
+        feature_num_x2 = 256
+        self.encoder = nn.Sequential(
+            nn.Conv3d(self.chnum_in, feature_num_2, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
+            nn.BatchNorm3d(feature_num_2),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv3d(feature_num_2, feature_num, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
+            nn.BatchNorm3d(feature_num),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv3d(feature_num, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
+            nn.BatchNorm3d(feature_num_x2),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv3d(feature_num_x2, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
+            nn.BatchNorm3d(feature_num_x2),
+            nn.LeakyReLU(0.2, inplace=True)
+        )
+        self.mem_rep = MemModule(mem_dim=mem_dim, fea_dim=feature_num_x2, shrink_thres=shrink_thres)
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose3d(feature_num_x2, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
+                               output_padding=(1, 1, 1)),
+            nn.BatchNorm3d(feature_num_x2),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.ConvTranspose3d(feature_num_x2, feature_num, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
+                               output_padding=(1, 1, 1)),
+            nn.BatchNorm3d(feature_num),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.ConvTranspose3d(feature_num, feature_num_2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
+                               output_padding=(1, 1, 1)),
+            nn.BatchNorm3d(feature_num_2),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.ConvTranspose3d(feature_num_2, self.chnum_in, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1),
+                               output_padding=(0, 1, 1))
+        )
+
+    def forward(self, x, period_score):
+        # MemModule.forward requires a period_score, so it is threaded through
+        # here (the original forward called self.mem_rep(f) without it, which
+        # would raise a TypeError)
+        f = self.encoder(x)
+        res_mem = self.mem_rep(f, period_score)
+        f = res_mem['output']
+        att = res_mem['att']
+        output = self.decoder(f)
+        return {'output': output, 'att': att}
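
A hedged usage sketch (editorial, not part of the commit). The 126-bin period_score shape is an assumption inferred from MemoryUnit's indices/126 slot mapping, not documented anywhere in the commit:

# Editorial sketch, not part of the commit. period_score shape (B, 126)
# is an inference from MemoryUnit's floor((idx/126)*mem_dim) mapping.
import torch
from IPAD.model.memae_3dconv import AutoEncoderCov3DMem

net = AutoEncoderCov3DMem(chnum_in=3, mem_dim=2000)
clip = torch.randn(2, 3, 16, 64, 64)                 # (B, C, D, H, W)
period_score = torch.softmax(torch.randn(2, 126), dim=1)
out = net(clip, period_score)
print(out['output'].shape)   # reconstruction, same shape as clip
print(out['att'].shape)      # memory addressing weights, (B, 2000, 2, 4, 4)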
IPAD/model/memory_module.py
ADDED
@@ -0,0 +1,112 @@
+from __future__ import absolute_import, print_function
+import torch
+from torch import nn
+import math
+from torch.nn.parameter import Parameter
+from torch.nn import functional as F
+import numpy as np
+
+
+class MemoryUnit(nn.Module):
+    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
+        super(MemoryUnit, self).__init__()
+        self.mem_dim = mem_dim
+        self.fea_dim = fea_dim
+        self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim))  # M x C
+        self.bias = None
+        self.shrink_thres = shrink_thres
+        # self.hard_sparse_shrink_opt = nn.Hardshrink(lambd=shrink_thres)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1. / math.sqrt(self.weight.size(1))
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.uniform_(-stdv, stdv)
+
+    def forward(self, input, period_score):
+        score, indices = torch.max(period_score, dim=1)
+        indices = (torch.floor((indices / 126) * self.mem_dim).cpu().numpy()).astype(int)
+        att_weight = F.linear(input, self.weight)  # Fea x Mem^T, (TxC) x (CxM) = TxM
+        # boost the attention on the memory slots around each sample's period
+        # index (the original code referenced score[i]/indices[i] without a
+        # loop; iterating over the batch indices is the most plausible reading)
+        for i in range(len(indices)):
+            att_weight[:, indices[i]-7:indices[i]+8] = att_weight[:, indices[i]-7:indices[i]+8] + att_weight[:, indices[i]-7:indices[i]+8].clone() * score[i]
+        att_weight = F.softmax(att_weight, dim=1)  # TxM
+        # ReLU based shrinkage, hard shrinkage for positive value
+        if self.shrink_thres > 0:
+            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
+            # att_weight = F.softshrink(att_weight, lambd=self.shrink_thres)
+            # re-normalize after shrinkage
+            att_weight = F.normalize(att_weight, p=1, dim=1)
+            # att_weight = F.softmax(att_weight, dim=1)
+            # att_weight = self.hard_sparse_shrink_opt(att_weight)
+
+        mem_trans = self.weight.permute(1, 0)  # Mem^T, MxC
+        output = F.linear(att_weight, mem_trans)  # AttWeight x Mem^T^T = AW x Mem, (TxM) x (MxC) = TxC
+        return {'output': output, 'att': att_weight}  # output, att_weight
+
+    def extra_repr(self):
+        return 'mem_dim={}, fea_dim={}'.format(
+            self.mem_dim, self.fea_dim
+        )
+
+
+# NxCxHxW -> (NxHxW)xC -> addressing Mem, (NxHxW)xC -> NxCxHxW
+class MemModule(nn.Module):
+    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025, device='cuda'):
+        super(MemModule, self).__init__()
+        self.mem_dim = mem_dim
+        self.fea_dim = fea_dim
+        self.shrink_thres = shrink_thres
+        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)
+
+    def forward(self, input, period_score):
+        s = input.data.shape
+        l = len(s)  # 5
+        if l == 3:
+            x = input.permute(0, 2, 1)
+        elif l == 4:
+            x = input.permute(0, 2, 3, 1)
+        elif l == 5:
+            x = input.permute(0, 2, 3, 4, 1)
+        else:
+            # fail loudly instead of printing and then crashing on a list
+            raise ValueError('wrong feature map size')
+        x = x.contiguous()
+        x = x.view(-1, s[1])
+
+        y_and = self.memory(x, period_score)
+
+        y = y_and['output']
+        att = y_and['att']
+
+        if l == 3:
+            y = y.view(s[0], s[2], s[1])
+            y = y.permute(0, 2, 1)
+            att = att.view(s[0], s[2], self.mem_dim)
+            att = att.permute(0, 2, 1)
+        elif l == 4:
+            y = y.view(s[0], s[2], s[3], s[1])
+            y = y.permute(0, 3, 1, 2)
+            att = att.view(s[0], s[2], s[3], self.mem_dim)
+            att = att.permute(0, 3, 1, 2)
+        elif l == 5:
+            y = y.view(s[0], s[2], s[3], s[4], s[1])
+            y = y.permute(0, 4, 1, 2, 3)
+            att = att.view(s[0], s[2], s[3], s[4], self.mem_dim)
+            att = att.permute(0, 4, 1, 2, 3)
+        return {'output': y, 'att': att}
+
+
+# relu based hard shrinkage function, only works for positive values
+def hard_shrink_relu(input, lambd=0, epsilon=1e-12):
+    output = (F.relu(input - lambd) * input) / (torch.abs(input - lambd) + epsilon)
+    return output
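
The hard shrinkage leaves an attention weight essentially unchanged when it exceeds shrink_thres and drives it to zero otherwise, which is what makes the re-normalized memory addressing sparse. A tiny numeric check (editorial, not part of the commit):

# Editorial sketch, not part of the commit.
import torch
from IPAD.model.memory_module import hard_shrink_relu

w = torch.tensor([0.0005, 0.0020, 0.0100, 0.2000])
print(hard_shrink_relu(w, lambd=0.0025))
# entries below lambd collapse to 0; larger entries pass through almost
# unchanged, since relu(w - lambd) / |w - lambd| -> 1 for w >> lambd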
IPAD/model/pseudoanomaly_utils.py
ADDED
@@ -0,0 +1,298 @@
+import random
+import torch
+import numpy as np
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import copy
+
+def create_pseudoanomaly_cifar_smooth(img, cifar_img, max_size, h, w, dataset, max_move=0):
+    assert 0 <= max_size <= 1
+
+    pil_img = transforms.ToPILImage()(cifar_img)
+    pil_img = transforms.Grayscale(num_output_channels=1)(pil_img)
+    cifar_img = transforms.ToTensor()(pil_img)
+
+    cifar_img = transforms.Normalize(mean=[0.5], std=[0.5])(cifar_img)
+
+    cifar_patch = F.interpolate(cifar_img.unsqueeze(0), size=(h, w), mode='bilinear', align_corners=True)
+
+    x_mu, y_mu = random.randint(0, w), random.randint(0, h)
+    x_sigma = max(10, int(np.random.uniform(high=max_size) * w))
+    y_sigma = max(10, int(np.random.uniform(high=max_size) * h))
+    if max_move == 0:
+        mask = torch.tensor(_get_gaussian_mask(x_mu, y_mu, x_sigma, y_sigma, h, w)).to(img.device).float()
+        img = mask * cifar_patch.to(img.device) + (1-mask) * img
+    else:
+        mask = []
+        for frame_idx in range(img.size(1)):
+            delta_x = np.random.randint(-max_move, max_move + 1)
+            delta_y = np.random.randint(-max_move, max_move + 1)
+            mask_ = torch.tensor(_get_gaussian_mask(x_mu, y_mu, x_sigma, y_sigma, h, w)).to(img.device).float()
+
+            img[:, frame_idx] = mask_ * cifar_patch.to(img.device) + (1-mask_) * img[:, frame_idx]
+            mask.append(mask_)
+
+            x_mu = min(max(0, x_mu + delta_x), w)
+            y_mu = min(max(0, y_mu + delta_y), h)
+
+        mask = torch.stack(mask, dim=0)
+
+    return img, mask
+
+
+def create_pseudoanomaly_cifar_smoothborder(img, cifar_img, max_size, h, w, dataset, max_move=0):
+    assert 0 <= max_size <= 1
+
+    pil_img = transforms.ToPILImage()(cifar_img)
+    pil_img = transforms.Grayscale(num_output_channels=1)(pil_img)
+    cifar_img = transforms.ToTensor()(pil_img)
+
+    cifar_img = transforms.Normalize(mean=[0.5], std=[0.5])(cifar_img)
+
+    cifar_patch = F.interpolate(cifar_img.unsqueeze(0), size=(h, w), mode='bilinear', align_corners=True)
+
+    cx, cy = np.random.randint(w), np.random.randint(h)
+
+    cut_w = max(10, int(np.random.uniform(high=max_size) * w))
+    cut_h = max(10, int(np.random.uniform(high=max_size) * h))
+    if max_move == 0:
+        mask = torch.tensor(_get_smoothborder_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+        img = mask * cifar_patch.to(img.device) + (1-mask) * img
+
+    else:
+        mask = []
+        for frame_idx in range(img.size(1)):
+            delta_x = np.random.randint(-max_move, max_move + 1)
+            delta_y = np.random.randint(-max_move, max_move + 1)
+            mask_ = torch.tensor(_get_smoothborder_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+
+            img[:, frame_idx] = mask_ * cifar_patch.to(img.device) + (1 - mask_) * img[:, frame_idx]
+            mask.append(mask_)
+
+            cx = min(max(0, cx + delta_x), w)
+            cy = min(max(0, cy + delta_y), h)
+
+        mask = torch.stack(mask, dim=0)
+
+    return img, mask
+
+
+def create_pseudoanomaly_cifar_cutmix(img, cifar_img, max_size, h, w, dataset, max_move=0):
+    assert 0 <= max_size <= 1
+
+    pil_img = transforms.ToPILImage()(cifar_img)
+    pil_img = transforms.Grayscale(num_output_channels=1)(pil_img)
+    cifar_img = transforms.ToTensor()(pil_img)
+
+    cifar_img = transforms.Normalize(mean=[0.5], std=[0.5])(cifar_img)
+
+    cifar_patch = F.interpolate(cifar_img.unsqueeze(0), size=(h, w), mode='bilinear', align_corners=True)
+
+    cx, cy = np.random.randint(w), np.random.randint(h)
+
+    cut_w = max(10, int(np.random.uniform(high=max_size) * w))
+    cut_h = max(10, int(np.random.uniform(high=max_size) * h))
+    if max_move == 0:
+        mask = torch.tensor(_get_cutmix_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+        img = mask * cifar_patch.to(img.device) + (1-mask) * img
+
+    else:
+        mask = []
+        for frame_idx in range(img.size(1)):
+            delta_x = np.random.randint(-max_move, max_move + 1)
+            delta_y = np.random.randint(-max_move, max_move + 1)
+            mask_ = torch.tensor(_get_cutmix_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+
+            img[:, frame_idx] = mask_ * cifar_patch.to(img.device) + (1 - mask_) * img[:, frame_idx]
+            mask.append(mask_)
+
+            cx = min(max(0, cx + delta_x), w)
+            cy = min(max(0, cy + delta_y), h)
+
+        mask = torch.stack(mask, dim=0)
+
+    return img, mask
+
+
+def create_pseudoanomaly_cifar_mixupcutmix(img, cifar_img, max_size, h, w, dataset, max_move=0):
+    assert 0 <= max_size <= 1
+
+    pil_img = transforms.ToPILImage()(cifar_img)
+    pil_img = transforms.Grayscale(num_output_channels=1)(pil_img)
+    cifar_img = transforms.ToTensor()(pil_img)
+
+    cifar_img = transforms.Normalize(mean=[0.5], std=[0.5])(cifar_img)
+
+    cifar_patch = F.interpolate(cifar_img.unsqueeze(0), size=(h, w), mode='bilinear', align_corners=True)
+
+    cx, cy = np.random.randint(w), np.random.randint(h)
+
+    cut_w = max(10, int(np.random.uniform(high=max_size) * w))
+    cut_h = max(10, int(np.random.uniform(high=max_size) * h))
+    if max_move == 0:
+        mask = torch.tensor(_get_cutmix_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+        img = mask * 0.5 * cifar_patch.to(img.device) + mask * 0.5 * img + (1-mask) * img
+
+    else:
+        mask = []
+        for frame_idx in range(img.size(1)):
+            delta_x = np.random.randint(-max_move, max_move + 1)
+            delta_y = np.random.randint(-max_move, max_move + 1)
+            mask_ = torch.tensor(_get_cutmix_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+
+            img[:, frame_idx] = mask_ * 0.5 * cifar_patch.to(img.device) + mask_ * 0.5 * img[:, frame_idx] + (1 - mask_) * img[:, frame_idx]
+            mask.append(mask_)
+
+            cx = min(max(0, cx + delta_x), w)
+            cy = min(max(0, cy + delta_y), h)
+
+        mask = torch.stack(mask, dim=0)
+
+    return img, mask
+
+
+def create_pseudoanomaly_seq_smoothborder(img, seq, max_size, h, w, dataset, max_move=0):
+    assert 0 <= max_size <= 1
+
+    cx, cy = np.random.randint(w), np.random.randint(h)
+
+    cut_w = max(10, int(np.random.uniform(high=max_size) * w))
+    cut_h = max(10, int(np.random.uniform(high=max_size) * h))
+    if max_move == 0:
+        mask = torch.tensor(_get_smoothborder_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+        img = mask * seq.to(img.device) + (1-mask) * img
+    else:
+        mask = []
+        for frame_idx in range(img.size(1)):
+            delta_x = np.random.randint(-max_move, max_move + 1)
+            delta_y = np.random.randint(-max_move, max_move + 1)
+            mask_ = torch.tensor(_get_smoothborder_mask(cx, cy, cut_h, cut_w, h, w)).to(img.device).float()
+
+            img[:, frame_idx] = mask_ * seq[:, frame_idx].to(img.device) + (1 - mask_) * img[:, frame_idx]
+            mask.append(mask_)
+
+            cx = min(max(0, cx + delta_x), w)
+            cy = min(max(0, cy + delta_y), h)
+
+        mask = torch.stack(mask, dim=0)
+
+    return img, mask
+
+
+def _get_gaussian_mask(x_mu, y_mu, x_sigma, y_sigma, h, w):
+    x, y = np.arange(w), np.arange(h)
+
+    # x_mu, y_mu = random.randint(0, w), random.randint(0, h)
+    # x_sigma = max(10, int(np.random.uniform(high=max_size) * w))
+    # y_sigma = max(10, int(np.random.uniform(high=max_size) * h))
+
+    gx = np.exp(-(x - x_mu) ** 2 / (2 * x_sigma ** 2))
+    gy = np.exp(-(y - y_mu) ** 2 / (2 * y_sigma ** 2))
+    g = np.outer(gx, gy)
+    # g /= np.sum(g)  # normalize, if you want that
+
+    # sum_g = np.sum(g)
+    # lam = sum_g / (w * h)
+    # print(lam)
+
+    # plt.imshow(g, interpolation="nearest", origin="lower")
+    # plt.show()
+    # g = np.dstack([g, g, g])
+
+    return g
+
+# a = _get_gaussian_mask(0.5, 256, 256)
+
+def _get_smoothborder_mask(cx, cy, Cut_h, Cut_w, h, w):
+    lam = np.random.beta(1, 1)
+    percentage = 0.1
+    cut_rat = np.sqrt(1. - lam)
+
+    # Cut_w = min(np.int(max_size*w), max(2, np.int(w * cut_rat)))
+    # Cut_h = min(np.int(max_size*h), max(2, np.int(h * cut_rat)))
+
+    # cx, cy = np.random.randint(w), np.random.randint(h)
+
+    bbx1 = np.clip(cx - Cut_w // 2, 0, w)  # top left x
+    bby1 = np.clip(cy - Cut_h // 2, 0, h)  # top left y
+    bbx2 = np.clip(cx + Cut_w // 2, 0, w)  # bottom right x
+    bby2 = np.clip(cy + Cut_h // 2, 0, h)  # bottom right y
+
+    img = np.zeros((w, h))
+    img2, img3 = np.ones_like(img), np.ones_like(img)
+    img[bbx1:bbx2, bby1:bby2] = img2[bbx1:bbx2, bby1:bby2]
+
+    lo = bbx1 - (Cut_w // 2) * percentage  # left side: beginning linear from 0
+    li = bbx1  # + (Cut_w // 2) * percentage  # left side: end of linear at 1
+    ri = bbx2  # - (Cut_w // 2) * percentage  # right: start linear from 1
+    ro = bbx2 + (Cut_w // 2) * percentage  # right: end linear at 0
+
+    to = bby1 - (Cut_h // 2) * percentage  # top: start linear from 0
+    ti = bby1  # + (Cut_h // 2) * percentage  # top: end linear at 1
+    bi = bby2  # - (Cut_h // 2) * percentage  # bottom: start linear from 1
+    bo = bby2 + (Cut_h // 2) * percentage  # bottom: end linear at 0
+
+    # glx = lambda x: ((x - lo) / (li - lo))
+    # grx = lambda x: (-(x - ro) / (ro - ri))
+    # gtx = lambda x: ((x - to) / (ti - to))
+    # gbx = lambda x: (-(x - bo) / (bo - bi))
+
+    for i in range(w):
+        for j in range(h):
+            if i < cx:
+                img2[j][i] = ((i - lo) / (li - lo))  # linear going up
+            else:
+                img2[j][i] = (-(i - ro) / (ro - ri))  # linear going down
+            if j < cy:
+                img3[j][i] = ((j - to) / (ti - to))
+            else:
+                img3[j][i] = (-(j - bo) / (bo - bi))
+
+    img2[img2 < 0] = 0
+    img2[img2 > 1] = 1
+
+    img3[img3 < 0] = 0
+    img3[img3 > 1] = 1
+
+    # plt.figure()
+    # plt.subplot(131)
+    # plt.imshow(img2)
+    # plt.subplot(132)
+    # plt.imshow(img3)
+    img4 = np.multiply(img2, img3)
+    # sum_img4 = np.sum(img4)
+    # lam = sum_img4 / (w * h)
+
+    # plt.subplot(133)
+    # plt.imshow(img4)
+    # plt.show()
+    return img4  # , lam
+
+# a = _get_smoothborder_mask(0.5, 256, 256)
+
+
+def _get_cutmix_mask(cx, cy, Cut_h, Cut_w, h, w):
+    lam = np.random.beta(1, 1)
+
+    bbx1 = np.clip(cx - Cut_w // 2, 0, w)  # top left x
+    bby1 = np.clip(cy - Cut_h // 2, 0, h)  # top left y
+    bbx2 = np.clip(bbx1 + Cut_w, 0, w)  # bottom right x
+    bby2 = np.clip(bby1 + Cut_h, 0, h)  # bottom right y
+
+    img = np.zeros((w, h))
+    img2 = np.ones_like(img)
+    img[bby1:bby2, bbx1:bbx2] = img2[bby1:bby2, bbx1:bbx2]
+
+    return img  # , lam
+
+# a = _get_cutmix_mask(100, 100, 15, 30, 256, 256)
+# a = _get_smoothborder_mask(100, 100, 15, 30, 256, 256)
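
These helpers paste a grayscale-normalized intruder patch into every frame of a clip under a binary (cutmix), linearly feathered (smoothborder), or Gaussian mask. A hedged sketch (editorial, not part of the commit); the (C, T, H, W) clip layout is inferred from the img[:, frame_idx] indexing above:

# Editorial sketch, not part of the commit. Clip layout (C, T, H, W) and
# the 'ped2' dataset tag are assumptions for illustration only.
import torch
from IPAD.model.pseudoanomaly_utils import create_pseudoanomaly_cifar_cutmix

clip = torch.rand(1, 16, 256, 256)   # grayscale clip, (C=1, T, H, W)
cifar = torch.rand(3, 32, 32)        # intruder image, grayscaled internally
aug, mask = create_pseudoanomaly_cifar_cutmix(
    clip, cifar, max_size=0.2, h=256, w=256, dataset='ped2', max_move=0)
print(aug.shape, mask.shape)         # clip shape preserved; (256, 256) mask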
IPAD/model/reconstruction_model.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn


class Reconstruction3DEncoder(nn.Module):
    def __init__(self, chnum_in):
        super(Reconstruction3DEncoder, self).__init__()

        # Architecture follows Dong Gong's paper code.
        self.chnum_in = chnum_in
        feature_num = 128
        feature_num_2 = 96
        feature_num_x2 = 256
        self.encoder = nn.Sequential(
            nn.Conv3d(self.chnum_in, feature_num_2, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(feature_num_2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(feature_num_2, feature_num, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(feature_num),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(feature_num, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(feature_num_x2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(feature_num_x2, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(feature_num_x2),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        return self.encoder(x)


class Reconstruction3DDecoder(nn.Module):
    def __init__(self, chnum_in):
        super(Reconstruction3DDecoder, self).__init__()

        # Mirror of the encoder (Dong Gong's paper code) plus a final Tanh.
        self.chnum_in = chnum_in
        feature_num = 128
        feature_num_2 = 96
        feature_num_x2 = 256

        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(feature_num_x2, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
                               output_padding=(1, 1, 1)),
            nn.BatchNorm3d(feature_num_x2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose3d(feature_num_x2, feature_num, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
                               output_padding=(1, 1, 1)),
            nn.BatchNorm3d(feature_num),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose3d(feature_num, feature_num_2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
                               output_padding=(1, 1, 1)),
            nn.BatchNorm3d(feature_num_2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose3d(feature_num_2, self.chnum_in, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1),
                               output_padding=(0, 1, 1)),
            nn.Tanh()
        )

    def forward(self, x):
        return self.decoder(x)


class VST3DDecoder(nn.Module):
    def __init__(self, chnum_out):
        super(VST3DDecoder, self).__init__()

        # Same convolutional style (Dong Gong's paper code) plus a final Tanh,
        # starting from 768-channel Video Swin Transformer features.
        self.chnum_out = chnum_out
        feature_num = 128
        feature_num_2 = 96
        feature_num_x2 = 256
        feature_num_in = 768
        self.transformer_decoder = nn.Sequential(
            # input: (B, 768, 4, 8, 8)
            nn.ConvTranspose3d(feature_num_in, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
                               output_padding=(1, 1, 1)),
            nn.BatchNorm3d(feature_num_x2),
            nn.LeakyReLU(0.2, inplace=True),
            # (B, 256, 8, 16, 16)
            nn.ConvTranspose3d(feature_num_x2, feature_num_x2, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1),
                               output_padding=(1, 1, 1)),
            nn.BatchNorm3d(feature_num_x2),
            nn.LeakyReLU(0.2, inplace=True),
            # (B, 256, 16, 32, 32)
            nn.ConvTranspose3d(feature_num_x2, feature_num, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1),
                               output_padding=(0, 1, 1)),
            nn.BatchNorm3d(feature_num),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose3d(feature_num, feature_num_2, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1),
                               output_padding=(0, 1, 1)),
            nn.BatchNorm3d(feature_num_2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose3d(feature_num_2, self.chnum_out, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1),
                               output_padding=(0, 1, 1)),
            nn.Tanh()
        )

    def forward(self, x):
        return self.transformer_decoder(x)
IPAD/model/utils.py
ADDED
@@ -0,0 +1,142 @@
import numpy as np
from collections import OrderedDict
import os
import glob
import cv2
import torch.utils.data as data
import random


rng = np.random.RandomState(2020)


def np_load_frame(filename, resize_height, resize_width, grayscale=False):
    """
    Load an image from disk and convert it to a numpy.ndarray. Note that the
    color channels are BGR and the intensity range is normalized from
    [0, 255] to [-1, 1].

    :param filename: the full path of the image
    :param resize_height: resized height
    :param resize_width: resized width
    :return: numpy.ndarray
    """
    grayscale = False  # grayscale loading is forcibly disabled here
    if grayscale:
        image_decoded = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
    else:
        image_decoded = cv2.imread(filename)
    image_resized = cv2.resize(image_decoded, (resize_width, resize_height))
    image_resized = image_resized.astype(dtype=np.float32)
    image_resized = (image_resized / 127.5) - 1.0
    return image_resized


class Reconstruction3DDataLoader(data.Dataset):
    def __init__(self, video_folder, transform, resize_height, resize_width, num_frames=16,
                 img_extension='.jpg', dataset='ped2', jump=[2], hold=[2], return_normal_seq=False):
        self.dir = video_folder
        self.transform = transform
        self.videos = OrderedDict()
        self._resize_height = resize_height
        self._resize_width = resize_width
        self._num_frames = num_frames

        self.extension = img_extension
        self.dataset = dataset

        self.setup()
        self.samples, self.background_models = self.get_all_samples()

        self.jump = jump
        self.hold = hold
        self.return_normal_seq = return_normal_seq  # for fast- and slow-moving pseudo anomalies

    def setup(self):
        videos = glob.glob(os.path.join(self.dir, '*/'))
        for video in sorted(videos):
            print(video)
            video_name = video.split('/')[-2]
            self.videos[video_name] = {}
            self.videos[video_name]['path'] = video
            self.videos[video_name]['frame'] = glob.glob(os.path.join(video, '*' + self.extension))
            self.videos[video_name]['frame'].sort()
            self.videos[video_name]['length'] = len(self.videos[video_name]['frame'])

    def get_all_samples(self):
        frames = []
        background_models = []
        videos = glob.glob(os.path.join(self.dir, '*/'))
        for video in sorted(videos):
            video_name = video.split('/')[-2]
            # every frame that can start a full clip of _num_frames frames
            for i in range(len(self.videos[video_name]['frame']) - self._num_frames + 1):
                frames.append(self.videos[video_name]['frame'][i])

        return frames, background_models

    def __getitem__(self, index):
        video_name = self.samples[index].split('/')[-2]
        if self.dataset == 'shanghai' and 'training' in self.samples[index]:
            # shanghai training frames are numbered starting from 1
            frame_name = int(self.samples[index].split('/')[-1].split('.')[-2]) - 1
        else:
            frame_name = int(self.samples[index].split('/')[-1].split('.')[-2])

        batch = []
        for i in range(self._num_frames):
            image = np_load_frame(self.videos[video_name]['frame'][frame_name + i], self._resize_height,
                                  self._resize_width, grayscale=True)
            if self.transform is not None:
                batch.append(self.transform(image))
        # batch: list of _num_frames tensors, each (C, H, W)
        img = OrderedDict()
        img['batch'] = np.stack(batch, axis=1)
        # temporal position of the clip inside its video, quantized to [0, 200)
        img['index'] = frame_name * 200 // len(self.videos[video_name]['frame'])
        return img

    def __len__(self):
        return len(self.samples)


class Reconstruction3DDataLoaderJump(Reconstruction3DDataLoader):
    def __getitem__(self, index):
        video_name = self.samples[index].split('/')[-2]
        if self.dataset == 'shanghai' and 'training' in self.samples[index]:
            # shanghai training frames are numbered starting from 1
            frame_name = int(self.samples[index].split('/')[-1].split('.')[-2]) - 1
        else:
            frame_name = int(self.samples[index].split('/')[-1].split('.')[-2])

        batch = []
        normal_batch = []
        jump = random.choice(self.jump)

        # if the skip-frame clip would run past the end of the video, pick a new start
        retry = 0
        while len(self.videos[video_name]['frame']) < frame_name + (self._num_frames - 1) * jump and retry < 10:
            frame_name = np.random.randint(len(self.videos[video_name]['frame']))
            retry += 1

        for i in range(self._num_frames):
            image = np_load_frame(
                self.videos[video_name]['frame'][min(frame_name + i * jump, len(self.videos[video_name]['frame']) - 1)],
                self._resize_height, self._resize_width, grayscale=True)
            if self.transform is not None:
                batch.append(self.transform(image))

        if self.return_normal_seq:
            for i in range(self._num_frames):
                image = np_load_frame(
                    self.videos[video_name]['frame'][min(frame_name + i, len(self.videos[video_name]['frame']) - 1)],
                    self._resize_height, self._resize_width, grayscale=True)
                if self.transform is not None:
                    normal_batch.append(self.transform(image))
            return np.stack(batch, axis=1), np.stack(normal_batch, axis=1)
        else:
            return np.stack(batch, axis=1), normal_batch
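A hedged sketch of how these datasets might be instantiated (the frames path and the ToTensor transform are illustrative assumptions, not part of this file). Because np_load_frame already returns float32 arrays in [-1, 1], ToTensor only converts the HWC array to a CHW tensor without rescaling.

import torch.utils.data
from torchvision import transforms

transform = transforms.ToTensor()

train_set = Reconstruction3DDataLoader('dataset/ped2/training/frames', transform,
                                       resize_height=256, resize_width=256, num_frames=16)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)

sample = next(iter(train_loader))
print(sample['batch'].shape)   # torch.Size([4, 3, 16, 256, 256]); C=3 since grayscale is disabled
print(sample['index'])         # temporal-position labels in [0, 200)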
IPAD/model/video_swin_transformer.py
ADDED
@@ -0,0 +1,48 @@
import torch
import torch.nn as nn
from .reconstruction_model import VST3DDecoder
from .VST_block import SwinTransformer3D
from .memory_module import MemModule


class VST(torch.nn.Module):
    def __init__(self, mem_dim=2000, shrink_thres=0.0025):  # for reconstruction
        super(VST, self).__init__()
        self.reconstruction = True

        # Video Swin Transformer encoder producing (B, 768, 4, 8, 8) features.
        self.transformer_encoder = SwinTransformer3D()
        # Memory module operating on the 768-dim feature vectors.
        self.mem_rep = MemModule(mem_dim=mem_dim, fea_dim=768, shrink_thres=shrink_thres)
        # Head predicting the temporal position of the clip (one of 200 bins).
        self.period = nn.Sequential(
            nn.Conv3d(768, 768, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(768),
            nn.LeakyReLU(0.2, inplace=True),
            # (B, 768, 4, 4, 4)
            nn.Flatten(1),
            nn.Linear(768 * 4 * 4 * 4, 4096),
            nn.ReLU(),
            nn.Linear(4096, 2048),
            nn.ReLU(),
            nn.Linear(2048, 200),
        )
        self.transformer_decoder = VST3DDecoder(chnum_out=3)

    def forward(self, x):
        feature = self.transformer_encoder(x)
        # feature: (B, 768, 4, 8, 8)
        recon_index = self.period(feature)
        res_mem = self.mem_rep(feature, recon_index)
        feature = res_mem['output']
        att = res_mem['att']
        output = self.transformer_decoder(feature.clone())

        return {'output': output, 'att': att, 'recon_index': recon_index}
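As a standalone check of the period head's geometry (a sketch assuming the (B, 768, 4, 8, 8) feature shape noted in the code): the strided conv halves only the two spatial dims, giving 768 * 4 * 4 * 4 = 49152 flattened features, which matches the first Linear layer.

import torch
import torch.nn as nn

period = nn.Sequential(
    nn.Conv3d(768, 768, (3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
    nn.BatchNorm3d(768),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Flatten(1),
    nn.Linear(768 * 4 * 4 * 4, 4096),
    nn.ReLU(),
    nn.Linear(4096, 2048),
    nn.ReLU(),
    nn.Linear(2048, 200),
)

feature = torch.randn(2, 768, 4, 8, 8)   # assumed Swin encoder output
print(period(feature).shape)             # torch.Size([2, 200]), logits over 200 temporal bins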