Staty committed on
Commit 2b21abc · verified · 1 Parent(s): 1a8f2a9

Upload 50 files

Files changed (50)
  1. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  2. .idea/misc.xml +7 -0
  3. .idea/modules.xml +8 -0
  4. .idea/upload.iml +12 -0
  5. .idea/workspace.xml +42 -0
  6. model_LARRES.py +229 -0
  7. model_convlstm.py +186 -0
  8. modules.py +66 -0
  9. test2015.h5 +3 -0
  10. test2020.h5 +3 -0
  11. train2015.h5 +3 -0
  12. train2020.h5 +3 -0
  13. train_simvp2.py +85 -0
  14. utilpack/__init__.py +32 -0
  15. utilpack/__pycache__/__init__.cpython-312.pyc +0 -0
  16. utilpack/__pycache__/convlstm_modules.cpython-312.pyc +0 -0
  17. utilpack/__pycache__/e3dlstm_modules.cpython-312.pyc +0 -0
  18. utilpack/__pycache__/mau_modules.cpython-312.pyc +0 -0
  19. utilpack/__pycache__/mim_modules.cpython-312.pyc +0 -0
  20. utilpack/__pycache__/mmvp_modules.cpython-312.pyc +0 -0
  21. utilpack/__pycache__/phydnet_modules.cpython-312.pyc +0 -0
  22. utilpack/__pycache__/predrnn_modules.cpython-312.pyc +0 -0
  23. utilpack/__pycache__/predrnnpp_modules.cpython-312.pyc +0 -0
  24. utilpack/__pycache__/predrnnv2_modules.cpython-312.pyc +0 -0
  25. utilpack/__pycache__/simvp_modules.cpython-312.pyc +0 -0
  26. utilpack/__pycache__/swinlstm_modules.cpython-312.pyc +0 -0
  27. utilpack/convlstm_modules.py +58 -0
  28. utilpack/e3dlstm_modules.py +119 -0
  29. utilpack/layers/__init__.py +10 -0
  30. utilpack/layers/__pycache__/__init__.cpython-312.pyc +0 -0
  31. utilpack/layers/__pycache__/hornet.cpython-312.pyc +0 -0
  32. utilpack/layers/__pycache__/moganet.cpython-312.pyc +0 -0
  33. utilpack/layers/__pycache__/poolformer.cpython-312.pyc +0 -0
  34. utilpack/layers/__pycache__/uniformer.cpython-312.pyc +0 -0
  35. utilpack/layers/__pycache__/van.cpython-312.pyc +0 -0
  36. utilpack/layers/hornet.py +112 -0
  37. utilpack/layers/moganet.py +140 -0
  38. utilpack/layers/poolformer.py +97 -0
  39. utilpack/layers/uniformer.py +156 -0
  40. utilpack/layers/van.py +119 -0
  41. utilpack/mau_modules.py +66 -0
  42. utilpack/mim_modules.py +211 -0
  43. utilpack/mmvp_modules.py +349 -0
  44. utilpack/phydnet_modules.py +463 -0
  45. utilpack/predrnn_modules.py +79 -0
  46. utilpack/predrnnpp_modules.py +169 -0
  47. utilpack/predrnnv2_modules.py +82 -0
  48. utilpack/simvp_modules.py +586 -0
  49. utilpack/swinlstm_modules.py +317 -0
  50. utilpack/wast_modules.py +577 -0
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="Black">
+     <option name="sdkName" value="Python 3.12" />
+   </component>
+   <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/upload.iml" filepath="$PROJECT_DIR$/.idea/upload.iml" />
+     </modules>
+   </component>
+ </project>
.idea/upload.iml ADDED
@@ -0,0 +1,12 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="inheritedJdk" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+   <component name="PyDocumentationSettings">
+     <option name="format" value="PLAIN" />
+     <option name="myDocStringFormat" value="Plain" />
+   </component>
+ </module>
.idea/workspace.xml ADDED
@@ -0,0 +1,42 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ChangeListManager">
+     <list default="true" id="9591cce3-c276-4022-8bb6-62a293d16241" name="更改" comment="" />
+     <option name="SHOW_DIALOG" value="false" />
+     <option name="HIGHLIGHT_CONFLICTS" value="true" />
+     <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
+     <option name="LAST_RESOLUTION" value="IGNORE" />
+   </component>
+   <component name="ProjectColorInfo"><![CDATA[{
+   "associatedIndex": 3
+ }]]></component>
+   <component name="ProjectId" id="2ssyDJvvBdJ2oeAvJVoyJMXumP7" />
+   <component name="ProjectViewState">
+     <option name="hideEmptyMiddlePackages" value="true" />
+     <option name="showLibraryContents" value="true" />
+   </component>
+   <component name="PropertiesComponent"><![CDATA[{
+   "keyToString": {
+     "RunOnceActivity.ShowReadmeOnStart": "true",
+     "last_opened_file_path": "C:/Users/Administrator/Desktop/upload"
+   }
+ }]]></component>
+   <component name="SharedIndexes">
+     <attachedChunks>
+       <set>
+         <option value="bundled-python-sdk-98f27166c754-ba05f1cad1b1-com.jetbrains.pycharm.community.sharedIndexes.bundled-PC-242.21829.153" />
+       </set>
+     </attachedChunks>
+   </component>
+   <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="应用程序级" UseSingleDictionary="true" transferred="true" />
+   <component name="TaskManager">
+     <task active="true" id="Default" summary="默认任务">
+       <changelist id="9591cce3-c276-4022-8bb6-62a293d16241" name="更改" comment="" />
+       <created>1739258456359</created>
+       <option name="number" value="Default" />
+       <option name="presentableId" value="Default" />
+       <updated>1739258456359</updated>
+     </task>
+     <servers />
+   </component>
+ </project>
model_LARRES.py ADDED
@@ -0,0 +1,229 @@
+ import torch
+ from torch import nn
+ from modules import ConvSC, Inception
+ 
+ from utilpack import (ConvNeXtSubBlock, ConvMixerSubBlock, GASubBlock, gInception_ST,
+                       HorNetSubBlock, MLPMixerSubBlock, MogaSubBlock, PoolFormerSubBlock,
+                       SwinSubBlock, UniformerSubBlock, VANSubBlock, ViTSubBlock, TAUSubBlock)
+ 
+ 
+ def stride_generator(N, reverse=False):
+     strides = [1, 2] * 10
+     if reverse:
+         return list(reversed(strides[:N]))
+     else:
+         return strides[:N]
+ 
+ 
+ class Encoder(nn.Module):
+     def __init__(self, C_in, C_hid, N_S):
+         super(Encoder, self).__init__()
+         strides = stride_generator(N_S)
+         self.enc = nn.Sequential(
+             ConvSC(C_in, C_hid, stride=strides[0]),
+             *[ConvSC(C_hid, C_hid, stride=s) for s in strides[1:]]
+         )
+ 
+     def forward(self, x):  # B*4, 3, 128, 128
+         enc1 = self.enc[0](x)
+         latent = enc1
+         for i in range(1, len(self.enc)):
+             latent = self.enc[i](latent)
+         return latent, enc1
+ 
+ 
+ class Decoder(nn.Module):
+     def __init__(self, C_hid, C_out, N_S):
+         super(Decoder, self).__init__()
+         strides = stride_generator(N_S, reverse=True)
+         self.dec = nn.Sequential(
+             *[ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]],
+             ConvSC(2 * C_hid, C_hid, stride=strides[-1], transpose=True)
+         )
+         self.readout = nn.Conv2d(C_hid, C_out, 1)
+ 
+     def forward(self, hid, enc1=None):
+         for i in range(0, len(self.dec) - 1):
+             hid = self.dec[i](hid)
+         Y = self.dec[-1](torch.cat([hid, enc1], dim=1))
+         Y = self.readout(Y)
+         return Y
+ 
+ 
+ class Mid_Xnet(nn.Module):
+     def __init__(self, channel_in, channel_hid, N_T, incep_ker=[3, 5, 7, 11], groups=8):
+         super(Mid_Xnet, self).__init__()
+ 
+         self.N_T = N_T
+         enc_layers = [Inception(channel_in, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups)]
+         for i in range(1, N_T - 1):
+             enc_layers.append(Inception(channel_hid, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups))
+         enc_layers.append(Inception(channel_hid, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups))
+ 
+         dec_layers = [Inception(channel_hid, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups)]
+         for i in range(1, N_T - 1):
+             dec_layers.append(Inception(2 * channel_hid, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups))
+         dec_layers.append(Inception(2 * channel_hid, channel_hid // 2, channel_in, incep_ker=incep_ker, groups=groups))
+ 
+         self.enc = nn.Sequential(*enc_layers)
+         self.dec = nn.Sequential(*dec_layers)
+ 
+     def forward(self, x):
+         B, T, C, H, W = x.shape
+         x = x.reshape(B, T * C, H, W)
+ 
+         # encoder
+         skips = []
+         z = x
+         for i in range(self.N_T):
+             z = self.enc[i](z)
+             if i < self.N_T - 1:
+                 skips.append(z)
+ 
+         # decoder
+         z = self.dec[0](z)
+         for i in range(1, self.N_T):
+             z = self.dec[i](torch.cat([z, skips[-i]], dim=1))
+ 
+         y = z.reshape(B, T, C, H, W)
+         return y
+ 
+ 
+ class MetaBlock(nn.Module):
+     """The hidden Translator of MetaFormer for SimVP"""
+ 
+     def __init__(self, in_channels, out_channels, input_resolution=None, model_type=None,
+                  mlp_ratio=8., drop=0.0, drop_path=0.0, layer_i=0):
+         super(MetaBlock, self).__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         model_type = model_type.lower() if model_type is not None else 'gsta'
+ 
+         if model_type == 'gsta':
+             self.block = GASubBlock(
+                 in_channels, kernel_size=21, mlp_ratio=mlp_ratio,
+                 drop=drop, drop_path=drop_path, act_layer=nn.GELU)
+         elif model_type == 'convmixer':
+             self.block = ConvMixerSubBlock(in_channels, kernel_size=11, activation=nn.GELU)
+         elif model_type == 'convnext':
+             self.block = ConvNeXtSubBlock(
+                 in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
+         elif model_type == 'hornet':
+             self.block = HorNetSubBlock(in_channels, mlp_ratio=mlp_ratio, drop_path=drop_path)
+         elif model_type in ['mlp', 'mlpmixer']:
+             self.block = MLPMixerSubBlock(
+                 in_channels, input_resolution, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
+         elif model_type in ['moga', 'moganet']:
+             self.block = MogaSubBlock(
+                 in_channels, mlp_ratio=mlp_ratio, drop_rate=drop, drop_path_rate=drop_path)
+         elif model_type == 'poolformer':
+             self.block = PoolFormerSubBlock(
+                 in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
+         elif model_type == 'swin':
+             self.block = SwinSubBlock(
+                 in_channels, input_resolution, layer_i=layer_i, mlp_ratio=mlp_ratio,
+                 drop=drop, drop_path=drop_path)
+         elif model_type == 'uniformer':
+             block_type = 'MHSA' if in_channels == out_channels and layer_i > 0 else 'Conv'
+             self.block = UniformerSubBlock(
+                 in_channels, mlp_ratio=mlp_ratio, drop=drop,
+                 drop_path=drop_path, block_type=block_type)
+         elif model_type == 'van':
+             self.block = VANSubBlock(
+                 in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path, act_layer=nn.GELU)
+         elif model_type == 'vit':
+             self.block = ViTSubBlock(
+                 in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
+         else:
+             raise ValueError('Invalid model_type in SimVP')
+ 
+         if in_channels != out_channels:
+             self.reduction = nn.Conv2d(
+                 in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ 
+     def forward(self, x):
+         z = self.block(x)
+         return z if self.in_channels == self.out_channels else self.reduction(z)
+ 
+ 
+ class MidMetaNet(nn.Module):
+     """The hidden Translator of MetaFormer for SimVP"""
+ 
+     def __init__(self, channel_in, channel_hid, N2,
+                  input_resolution=None, model_type=None,
+                  mlp_ratio=4., drop=0.0, drop_path=0.1):
+         super(MidMetaNet, self).__init__()
+         assert N2 >= 2 and mlp_ratio > 1
+         self.N2 = N2
+         dpr = [  # stochastic depth decay rule
+             x.item() for x in torch.linspace(1e-2, drop_path, self.N2)]
+ 
+         # downsample
+         enc_layers = [MetaBlock(
+             channel_in, channel_hid, input_resolution, model_type,
+             mlp_ratio, drop, drop_path=dpr[0], layer_i=0)]
+         # middle layers
+         for i in range(1, N2 - 1):
+             enc_layers.append(MetaBlock(
+                 channel_hid, channel_hid, input_resolution, model_type,
+                 mlp_ratio, drop, drop_path=dpr[i], layer_i=i))
+         # upsample
+         enc_layers.append(MetaBlock(
+             channel_hid, channel_in, input_resolution, model_type,
+             mlp_ratio, drop, drop_path=drop_path, layer_i=N2 - 1))
+         self.enc = nn.Sequential(*enc_layers)
+ 
+     def forward(self, x):
+         B, T, C, H, W = x.shape
+         x = x.reshape(B, T * C, H, W)
+ 
+         z = x
+         for i in range(self.N2):
+             z = self.enc[i](z)
+ 
+         y = z.reshape(B, T, C, H, W)
+         return y
+ 
+ 
+ class SimVP(nn.Module):
+     def __init__(self, hid_S=32, hid_T=256, N_S=2, N_T=8, incep_ker=[3, 5, 7, 11], groups=4):
+         super(SimVP, self).__init__()
+         T, C, H, W = 36, 1, 72, 72
+         self.enc = Encoder(C, hid_S, N_S)
+         self.hid = MidMetaNet(T * hid_S, hid_T, N_T,
+                               input_resolution=(H, W), model_type="vit",
+                               mlp_ratio=8, drop=0.0, drop_path=0.1)
+         self.dec = Decoder(hid_S, C, N_S)
+ 
+     def forward(self, x_raw):
+         B, T, C, H, W = x_raw.shape
+         x = x_raw.view(B * T, C, H, W)
+ 
+         embed, skip = self.enc(x)
+         _, C_, H_, W_ = embed.shape
+ 
+         z = embed.view(B, T, C_, H_, W_)
+         hid = self.hid(z)
+         hid = hid.reshape(B * T, C_, H_, W_)
+ 
+         Y = self.dec(hid, skip)
+         Y = Y.reshape(B, T, C, H, W)
+         return Y
+ 
+ 
+ class larres(nn.Module):
+     def __init__(self, hid_S=32, hid_T=256, N_S=2, N_T=8, incep_ker=[3, 5, 7, 11], groups=4):
+         super(larres, self).__init__()
+         T, C, H, W = 36, 1, 72, 72
+         self.enc = Encoder(C, hid_S, N_S)
+         self.hid = Mid_Xnet(T * hid_S, hid_T, N_T, incep_ker, groups)
+         self.dec = Decoder(hid_S, C, N_S)
+ 
+     def forward(self, x_raw):
+         B, T, C, H, W = x_raw.shape
+         x = x_raw.view(B * T, C, H, W)
+ 
+         embed, skip = self.enc(x)
+         _, C_, H_, W_ = embed.shape
+ 
+         z = embed.view(B, T, C_, H_, W_)
+         hid = self.hid(z)
+         hid = hid.reshape(B * T, C_, H_, W_)
+ 
+         Y = self.dec(hid, skip)
+         Y = Y.reshape(B, T, C, H, W)
+         return Y
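Both SimVP and larres above hard-code T, C, H, W = 36, 1, 72, 72, so the translator sees 36 × hid_S = 1152 stacked channels and the output keeps the input shape. A minimal shape check (a hedged sketch, not part of the commit):

import torch
from model_LARRES import larres

model = larres()                    # defaults: hid_S=32, hid_T=256, N_S=2, N_T=8
x = torch.randn(2, 36, 1, 72, 72)   # (B, T, C, H, W), as hard-coded in the model
y = model(x)
print(y.shape)                      # torch.Size([2, 36, 1, 72, 72])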
model_convlstm.py ADDED
@@ -0,0 +1,186 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, Tensor
+ import numpy as np
+ import h5py
+ from torch.utils.data import DataLoader, Dataset
+ from torch.utils.data import Subset
+ from sklearn.model_selection import train_test_split
+ 
+ 
+ # Obtained from: https://holmdk.github.io/2020/04/02/video_prediction.html
+ class ConvLSTMCell(nn.Module):
+     def __init__(self, input_dim, hidden_dim, kernel_size, bias):
+         """
+         Initialize ConvLSTM cell.
+ 
+         Parameters
+         ----------
+         input_dim: int
+             Number of channels of input tensor.
+         hidden_dim: int
+             Number of channels of hidden state.
+         kernel_size: (int, int)
+             Size of the convolutional kernel.
+         bias: bool
+             Whether or not to add the bias.
+         """
+         super().__init__()
+ 
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+ 
+         self.kernel_size = kernel_size
+         self.padding = kernel_size[0] // 2, kernel_size[1] // 2
+         self.bias = bias
+ 
+         self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
+                               out_channels=4 * self.hidden_dim,
+                               kernel_size=self.kernel_size,
+                               padding=self.padding,
+                               bias=self.bias)
+ 
+     def forward(self, input_tensor, cur_state):
+         h_cur, c_cur = cur_state
+ 
+         combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
+ 
+         combined_conv = self.conv(combined)
+         cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
+         i = torch.sigmoid(cc_i)
+         f = torch.sigmoid(cc_f)
+         o = torch.sigmoid(cc_o)
+         g = torch.tanh(cc_g)
+ 
+         c_next = f * c_cur + i * g
+         h_next = o * torch.tanh(c_next)
+ 
+         return h_next, c_next
+ 
+     def init_hidden(self, batch_size, image_size):
+         height, width = image_size
+         return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
+                 torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
+ 
+ 
+ def process_highdim_array(arr):
+     """
+     Process a high-dimensional array of shape (1, 60, 1, 71, 73), resizing the
+     last two dimensions from (71, 73) to (72, 72).
+ 
+     Parameters:
+         arr (ndarray): input high-dimensional numpy array, assumed shape (1, 60, 1, 71, 73).
+ 
+     Returns:
+         ndarray: processed array with the last two dimensions resized to (72, 72).
+     """
+     # check that the last two dimensions are (71, 73)
+     if arr.shape[-2:] != (71, 73):
+         raise ValueError("The last two dimensions of the input array must be (71, 73)")
+ 
+     # trim the last column of the final dimension, giving (71, 72)
+     arr_trimmed = arr[..., :-1]
+ 
+     # pad one row of zeros along the second-to-last dimension, giving (72, 72)
+     arr_padded = np.pad(arr_trimmed, ((0, 0), (0, 0), (0, 1), (0, 0)), mode='constant', constant_values=0)
+ 
+     return arr_padded
+ 
+ 
+ class ionexDataset(Dataset):
+     def __init__(self, npy_data, nstepsin=36, nstepsout=12, stride=12):
+         self.data = npy_data.astype(np.float32)
+         self.nstepsin = nstepsin
+         self.nstepsout = nstepsout
+         self.stride = stride
+         self.idx = np.arange(0, len(self.data) - nstepsout - nstepsin + 1, stride)
+ 
+     def __getitem__(self, index):
+         # find the end of this pattern
+         i = self.idx[index]
+         end_ix = i + self.nstepsin
+         # check if we are beyond the sequence
+         if end_ix + self.nstepsout > len(self.data):
+             return None, None
+         # gather input and output parts of the pattern
+         seq_x, seq_y = self.data[i:end_ix], self.data[end_ix:end_ix + self.nstepsout]
+         return process_highdim_array(seq_x), process_highdim_array(seq_y)
+ 
+     def __len__(self):
+         return len(self.idx)
+ 
+     def split_train_val(self, val_split=0.25):
+         train_idx, val_idx = train_test_split(list(range(len(self))), test_size=val_split)
+         return Subset(self, train_idx), Subset(self, val_idx)
+ 
+ 
+ nstepsin = 36
+ nstepsout = 12
+ stride = 12
+ max_epochs = 200
+ # batch_size=2
+ 
+ f = h5py.File('train2015.h5', 'r')
+ train_npy = np.array(f['2020']) / 10
+ f = h5py.File('test2015.h5', 'r')
+ test_npy = np.array(f['2015']) / 10
+ 
+ # f = h5py.File('train2015.h5', 'r')
+ # train_npy = np.array(f['2020']) / 10
+ # f1 = h5py.File('c1pg2015.h5', 'r')
+ # f = h5py.File('test2015.h5', 'r')
+ # test_npy = np.array(f['2015']) / 10 - np.array(f1['2015']) / 10
+ 
+ # f = h5py.File('train2020.h5', 'r')
+ # train_npy = np.array(f['2020']) / 10
+ # f = h5py.File('test2020.h5', 'r')
+ # test_npy = np.array(f['2020']) / 10
+ 
+ # f = h5py.File('train2020.h5', 'r')
+ # train_npy = np.array(f['2020']) / 10
+ # f1 = h5py.File('c1pg2020.h5', 'r')
+ # f = h5py.File('test2020.h5', 'r')
+ # test_npy = np.array(f['2020']) / 10 - np.array(f1['2020']) / 10
+ 
+ f.close()
+ print("Training data:", train_npy.shape)
+ print("Testing data:", test_npy.shape)
+ 
+ 
+ class EncoderDecoderConvLSTM(nn.Module):
+     def __init__(self, nf, in_chan, out_chan, nstepsout=12):
+         super().__init__()
+         self.nstepsout = nstepsout
+         self.encoder_1_convlstm = ConvLSTMCell(input_dim=in_chan, hidden_dim=nf, kernel_size=(3, 3), bias=True)
+         self.encoder_2_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
+         self.encoder_3_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
+         self.decoder_1_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
+         self.decoder_2_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
+         self.decoder_3_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
+         self.conv2d = nn.Conv2d(in_channels=nf, out_channels=1, kernel_size=(1, 1))
+ 
+     def forward(self, x, future_seq=0, hidden_state=None):
+         b, seq_len, _, h, w = x.size()
+ 
+         # encoder
+         # initialize hidden states
+         h1, c1 = self.encoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
+         h2, c2 = self.encoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))
+         h3, c3 = self.encoder_3_convlstm.init_hidden(batch_size=b, image_size=(h, w))
+ 
+         for t in range(seq_len):
+             h1, c1 = self.encoder_1_convlstm(input_tensor=x[:, t, :, :, :], cur_state=[h1, c1])
+             h2, c2 = self.encoder_2_convlstm(input_tensor=h1, cur_state=[h2, c2])
+             h3, c3 = self.encoder_3_convlstm(input_tensor=h2, cur_state=[h3, c3])
+ 
+         # decoder
+         # initialize hidden states from the final encoder states
+         h4, c4 = h1, c1  # self.decoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
+         h5, c5 = h2, c2  # self.decoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))
+         h6, c6 = h3, c3  # self.decoder_3_convlstm.init_hidden(batch_size=b, image_size=(h, w))
+ 
+         outputs = []
+         for t in range(self.nstepsout):
+             h4, c4 = self.decoder_1_convlstm(input_tensor=h3, cur_state=[h4, c4])  # note that h3 is not updated during prediction
+             h5, c5 = self.decoder_2_convlstm(input_tensor=h4, cur_state=[h5, c5])
+             h6, c6 = self.decoder_3_convlstm(input_tensor=h5, cur_state=[h6, c6])
+             outputs.append(self.conv2d(h4))
+         outputs = torch.stack(outputs, 1)
+         return outputs
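The decoder above is seeded with the final encoder states and rolls forward for nstepsout steps, so input and output sequence lengths can differ. A hedged usage sketch with illustrative sizes (note that importing model_convlstm also runs the module-level .h5 loading):

import torch

model = EncoderDecoderConvLSTM(nf=32, in_chan=1, out_chan=1, nstepsout=12)
x = torch.randn(2, 36, 1, 72, 72)   # (batch, input steps, channels, H, W)
y = model(x)
print(y.shape)                      # torch.Size([2, 12, 1, 72, 72])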
modules.py ADDED
@@ -0,0 +1,66 @@
+ from torch import nn
+ 
+ 
+ class BasicConv2d(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, padding, transpose=False, act_norm=False):
+         super(BasicConv2d, self).__init__()
+         self.act_norm = act_norm
+         if not transpose:
+             self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+         else:
+             self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=stride // 2)
+         self.norm = nn.GroupNorm(2, out_channels)
+         self.act = nn.LeakyReLU(0.2, inplace=True)
+ 
+     def forward(self, x):
+         y = self.conv(x)
+         if self.act_norm:
+             y = self.act(self.norm(y))
+         return y
+ 
+ 
+ class ConvSC(nn.Module):
+     def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
+         super(ConvSC, self).__init__()
+         if stride == 1:
+             transpose = False
+         self.conv = BasicConv2d(C_in, C_out, kernel_size=3, stride=stride,
+                                 padding=1, transpose=transpose, act_norm=act_norm)
+ 
+     def forward(self, x):
+         y = self.conv(x)
+         return y
+ 
+ 
+ class GroupConv2d(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False):
+         super(GroupConv2d, self).__init__()
+         self.act_norm = act_norm
+         if in_channels % groups != 0:
+             groups = 1
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
+         self.norm = nn.GroupNorm(groups, out_channels)
+         self.activate = nn.LeakyReLU(0.2, inplace=True)
+ 
+     def forward(self, x):
+         y = self.conv(x)
+         if self.act_norm:
+             y = self.activate(self.norm(y))
+         return y
+ 
+ 
+ class Inception(nn.Module):
+     def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8):
+         super(Inception, self).__init__()
+         self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
+         layers = []
+         for ker in incep_ker:
+             layers.append(GroupConv2d(C_hid, C_out, kernel_size=ker, stride=1, padding=ker // 2, groups=groups, act_norm=True))
+         self.layers = nn.Sequential(*layers)
+ 
+     def forward(self, x):
+         x = self.conv1(x)
+         y = 0
+         for layer in self.layers:
+             y += layer(x)
+         return y
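Inception here is a 1x1 bottleneck followed by a sum of shape-preserving grouped convolutions, one per kernel size in incep_ker. A quick hedged shape check (illustrative values only):

import torch
from modules import Inception

blk = Inception(C_in=64, C_hid=32, C_out=128, incep_ker=[3, 5, 7, 11], groups=8)
x = torch.randn(2, 64, 36, 36)
print(blk(x).shape)                 # torch.Size([2, 128, 36, 36])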
test2015.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a8f012ef4a1fc40d1c993cea1eff972ea56cbda86fd3a433431ea71d82259e09
+ size 181614368
test2020.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f0a1c3c6d19a81a998ab4381bf189ba0ac7b8c6378008ad7b3d1465ffa20edd1
+ size 182111936
train2015.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f17b0b8430647d2cc21a1bc2af719e3a2370bed8d93ac70626fcc08fcc2e546c
+ size 545336576
train2020.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f39605870ce1fc5930277d5d225a29b3aaaff8fc53c4a00c9c6149740d91ebf
+ size 544839008
train_simvp2.py ADDED
@@ -0,0 +1,85 @@
+ import os
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, Tensor
+ import numpy as np
+ import h5py
+ from torch.utils.data import DataLoader, Dataset
+ from torch.utils.data import Subset
+ from sklearn.model_selection import train_test_split
+ import torch.optim as optim
+ 
+ from model_convlstm import ionexDataset, train_npy, nstepsin, nstepsout, stride, EncoderDecoderConvLSTM, max_epochs
+ from model_LARRES import larres
+ 
+ # ionexData = ionexDataset(train_npy, nstepsin=nstepsin, nstepsout=nstepsout, stride=stride)
+ # train_data, val_data = ionexData.split_train_val(val_split=0.2)
+ #
+ # train_loader = DataLoader(train_data, batch_size=16, num_workers=0)
+ # val_loader = DataLoader(val_data, batch_size=16, num_workers=0)
+ 
+ ionexData = ionexDataset(train_npy, nstepsin=nstepsin, nstepsout=nstepsout, stride=stride)
+ train_data, val_data = ionexData.split_train_val(val_split=0.2)
+ 
+ train_loader = DataLoader(train_data, batch_size=16, num_workers=0)
+ val_loader = DataLoader(val_data, batch_size=16, num_workers=0)
+ 
+ for X, y in train_loader:
+     print(f"Shape of X: {X.shape} {X.dtype} [N, T, C, H, W]")
+     print(f"Shape of Y: {y.shape} {y.dtype}")
+     break
+ print(f"Training samples: {len(train_loader.dataset)}")
+ # print(f"Validation samples: {len(val_loader.dataset)}")
+ 
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ 
+ model = larres().to(device)
+ # model.load_state_dict(torch.load("best_model.pth"))
+ optimizer = optim.Adam(model.parameters(), lr=0.0001)
+ criterion = nn.L1Loss()
+ 
+ # training and validation
+ 
+ best_val_loss = float('inf')
+ # num_epochs = 50
+ 
+ for epoch in range(max_epochs):
+     # training phase
+     model.train()
+     all_loss = 0
+     for batch_idx, (data, target) in enumerate(train_loader):
+         data, target = data.to(device), target.to(device)  # move data and target to the device
+         optimizer.zero_grad()
+         output = model(data)
+         target_last = target - data[:, 24:36, :, :, :]
+         # loss = criterion(output, target_last)  # L1 loss
+         loss = criterion(output[:, :12, :, :71, :], target_last[:, :12, :, :71, :])  # L1 loss
+         print(loss)
+         all_loss += loss.item()
+         loss.backward()
+         optimizer.step()
+ 
+     print(f'Epoch {epoch + 1}/{max_epochs}, Train Loss: {all_loss:.4f}')
+ 
+     # validation phase
+     model.eval()
+     val_loss = 0.0
+     with torch.no_grad():
+         for data, target in val_loader:
+             data, target = data.to(device), target.to(device)  # move data and target to the device
+             output = model(data)
+             target_last = target - data[:, 24:36, :, :, :]
+             # loss = criterion(output, target_last)  # L1 loss
+             loss = criterion(output[:, :12, :, :71, :], target_last[:, :12, :, :71, :])  # L1 loss
+             val_loss += loss.item()
+ 
+     val_loss /= len(val_loader)
+     print(f'Epoch {epoch + 1}/{max_epochs}, Val Loss: {val_loss:.4f}')
+ 
+     # save the best model
+     if val_loss < best_val_loss:
+         best_val_loss = val_loss
+         torch.save(model.state_dict(), 'best_model.pth')
+         print('Best model saved!')
+ 
+ print('Training completed.')
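Since the loop above trains on the residual target - data[:, 24:36], turning a saved checkpoint into an actual forecast would presumably add the last 12 input frames back. A hedged inference sketch, assuming the names from train_simvp2.py are in scope:

model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()
with torch.no_grad():
    data, target = next(iter(val_loader))
    data = data.to(device)                       # (B, 36, 1, 72, 72)
    residual = model(data)
    forecast = data[:, 24:36] + residual[:, :12]  # add the persistence baseline back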
utilpack/__init__.py ADDED
@@ -0,0 +1,32 @@
+ # Copyright (c) CAIRI AI Lab. All rights reserved
+ 
+ from .convlstm_modules import ConvLSTMCell
+ from .e3dlstm_modules import Eidetic3DLSTMCell, tf_Conv3d
+ from .mim_modules import MIMBlock, MIMN
+ from .mau_modules import MAUCell
+ from .phydnet_modules import PhyCell, PhyD_ConvLSTM, PhyD_EncoderRNN, K2M
+ from .predrnn_modules import SpatioTemporalLSTMCell
+ from .predrnnpp_modules import CausalLSTMCell, GHU
+ from .predrnnv2_modules import SpatioTemporalLSTMCellv2
+ from .simvp_modules import (BasicConv2d, ConvSC, GroupConv2d,
+                             ConvNeXtSubBlock, ConvMixerSubBlock, GASubBlock, gInception_ST,
+                             HorNetSubBlock, MLPMixerSubBlock, MogaSubBlock, PoolFormerSubBlock,
+                             SwinSubBlock, UniformerSubBlock, VANSubBlock, ViTSubBlock, TAUSubBlock)
+ from .mmvp_modules import (ResBlock, RRDB, ResidualDenseBlock_4C, Up, Conv3D, ConvLayer,
+                            MatrixPredictor3DConv, SimpleMatrixPredictor3DConv_direct, PredictModel)
+ from .swinlstm_modules import UpSample, DownSample, STconvert
+ 
+ __all__ = [
+     'ConvLSTMCell', 'CausalLSTMCell', 'GHU', 'SpatioTemporalLSTMCell', 'SpatioTemporalLSTMCellv2',
+     'MIMBlock', 'MIMN', 'Eidetic3DLSTMCell', 'tf_Conv3d',
+     'PhyCell', 'PhyD_ConvLSTM', 'PhyD_EncoderRNN', 'K2M', 'MAUCell',
+     'BasicConv2d', 'ConvSC', 'GroupConv2d',
+     'ConvNeXtSubBlock', 'ConvMixerSubBlock', 'GASubBlock', 'gInception_ST',
+     'HorNetSubBlock', 'MLPMixerSubBlock', 'MogaSubBlock', 'PoolFormerSubBlock',
+     'SwinSubBlock', 'UniformerSubBlock', 'VANSubBlock', 'ViTSubBlock', 'TAUSubBlock',
+     'ResBlock', 'RRDB', 'ResidualDenseBlock_4C', 'Up', 'Conv3D', 'ConvLayer',
+     'MatrixPredictor3DConv', 'SimpleMatrixPredictor3DConv_direct', 'PredictModel',
+     'UpSample', 'DownSample', 'STconvert',
+ ]
utilpack/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.73 kB)
utilpack/__pycache__/convlstm_modules.cpython-312.pyc ADDED
Binary file (3.37 kB)
utilpack/__pycache__/e3dlstm_modules.cpython-312.pyc ADDED
Binary file (6.96 kB)
utilpack/__pycache__/mau_modules.cpython-312.pyc ADDED
Binary file (4.09 kB)
utilpack/__pycache__/mim_modules.cpython-312.pyc ADDED
Binary file (10.8 kB)
utilpack/__pycache__/mmvp_modules.cpython-312.pyc ADDED
Binary file (25.7 kB)
utilpack/__pycache__/phydnet_modules.cpython-312.pyc ADDED
Binary file (27.2 kB)
utilpack/__pycache__/predrnn_modules.cpython-312.pyc ADDED
Binary file (4.65 kB)
utilpack/__pycache__/predrnnpp_modules.cpython-312.pyc ADDED
Binary file (9.27 kB)
utilpack/__pycache__/predrnnv2_modules.cpython-312.pyc ADDED
Binary file (4.7 kB)
utilpack/__pycache__/simvp_modules.cpython-312.pyc ADDED
Binary file (37.5 kB)
utilpack/__pycache__/swinlstm_modules.cpython-312.pyc ADDED
Binary file (16.6 kB)
utilpack/convlstm_modules.py ADDED
@@ -0,0 +1,58 @@
+ import torch
+ import torch.nn as nn
+ 
+ 
+ class ConvLSTMCell(nn.Module):
+ 
+     def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
+         super(ConvLSTMCell, self).__init__()
+ 
+         self.num_hidden = num_hidden
+         self.padding = filter_size // 2
+         self._forget_bias = 1.0
+         if layer_norm:
+             self.conv_x = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+             self.conv_h = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+             self.conv_o = nn.Sequential(
+                 nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden, height, width])
+             )
+         else:
+             self.conv_x = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_h = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_o = nn.Sequential(
+                 nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+         self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
+                                    stride=1, padding=0, bias=False)
+ 
+     def forward(self, x_t, h_t, c_t):
+         x_concat = self.conv_x(x_t)
+         h_concat = self.conv_h(h_t)
+         i_x, f_x, g_x, o_x = torch.split(x_concat, self.num_hidden, dim=1)
+         i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
+ 
+         i_t = torch.sigmoid(i_x + i_h)
+         f_t = torch.sigmoid(f_x + f_h)
+         g_t = torch.tanh(g_x + g_h)
+ 
+         c_new = f_t * c_t + i_t * g_t
+         o_t = torch.sigmoid(o_x + o_h)
+         h_new = o_t * torch.tanh(c_new)
+         return h_new, c_new
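Unlike the ConvLSTMCell in model_convlstm.py (which takes an (h, c) tuple), this cell takes h and c separately and needs the spatial size at construction time when layer_norm=True. A hedged sketch with illustrative sizes:

import torch
from utilpack.convlstm_modules import ConvLSTMCell

cell = ConvLSTMCell(in_channel=1, num_hidden=32, height=72, width=72,
                    filter_size=5, stride=1, layer_norm=True)
x = torch.randn(2, 1, 72, 72)
h = torch.zeros(2, 32, 72, 72)
c = torch.zeros(2, 32, 72, 72)
h_new, c_new = cell(x, h, c)        # both (2, 32, 72, 72)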
utilpack/e3dlstm_modules.py ADDED
@@ -0,0 +1,119 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ 
+ class tf_Conv3d(nn.Module):
+ 
+     def __init__(self, in_channels, out_channels, *vargs, **kwargs):
+         super(tf_Conv3d, self).__init__()
+         self.conv3d = nn.Conv3d(in_channels, out_channels, *vargs, **kwargs)
+ 
+     def forward(self, input):
+         return F.interpolate(self.conv3d(input), size=input.shape[-3:], mode="nearest")
+ 
+ 
+ class Eidetic3DLSTMCell(nn.Module):
+ 
+     def __init__(self, in_channel, num_hidden, window_length,
+                  height, width, filter_size, stride, layer_norm):
+         super(Eidetic3DLSTMCell, self).__init__()
+ 
+         self._norm_c_t = nn.LayerNorm([num_hidden, window_length, height, width])
+         self.num_hidden = num_hidden
+         self.padding = (0, filter_size[1] // 2, filter_size[2] // 2)
+         self._forget_bias = 1.0
+         if layer_norm:
+             self.conv_x = nn.Sequential(
+                 tf_Conv3d(in_channel, num_hidden * 7, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 7, window_length, height, width])
+             )
+             self.conv_h = nn.Sequential(
+                 tf_Conv3d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, window_length, height, width])
+             )
+             self.conv_gm = nn.Sequential(
+                 tf_Conv3d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, window_length, height, width])
+             )
+             self.conv_new_cell = nn.Sequential(
+                 tf_Conv3d(num_hidden, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden, window_length, height, width])
+             )
+             self.conv_new_gm = nn.Sequential(
+                 tf_Conv3d(num_hidden, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden, window_length, height, width])
+             )
+         else:
+             self.conv_x = nn.Sequential(
+                 tf_Conv3d(in_channel, num_hidden * 7, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_h = nn.Sequential(
+                 tf_Conv3d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_gm = nn.Sequential(
+                 tf_Conv3d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_new_cell = nn.Sequential(
+                 tf_Conv3d(num_hidden, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_new_gm = nn.Sequential(
+                 tf_Conv3d(num_hidden, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+         self.conv_last = tf_Conv3d(num_hidden * 2, num_hidden, kernel_size=1,
+                                    stride=1, padding=0, bias=False)
+ 
+     def _attn(self, in_query, in_keys, in_values):
+         batch, num_channels, _, width, height = in_query.shape
+         query = in_query.reshape(batch, -1, num_channels)
+         keys = in_keys.reshape(batch, -1, num_channels)
+         values = in_values.reshape(batch, -1, num_channels)
+         attn = torch.einsum('bxc,byc->bxy', query, keys)
+         attn = torch.softmax(attn, dim=2)
+         attn = torch.einsum("bxy,byc->bxc", attn, values)
+         return attn.reshape(batch, num_channels, -1, width, height)
+ 
+     def forward(self, x_t, h_t, c_t, global_memory, eidetic_cell):
+         h_concat = self.conv_h(h_t)
+         i_h, g_h, r_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
+ 
+         x_concat = self.conv_x(x_t)
+         i_x, g_x, r_x, o_x, temp_i_x, temp_g_x, temp_f_x = \
+             torch.split(x_concat, self.num_hidden, dim=1)
+ 
+         i_t = torch.sigmoid(i_x + i_h)
+         r_t = torch.sigmoid(r_x + r_h)
+         g_t = torch.tanh(g_x + g_h)
+ 
+         new_cell = c_t + self._attn(r_t, eidetic_cell, eidetic_cell)
+         new_cell = self._norm_c_t(new_cell) + i_t * g_t
+ 
+         new_global_memory = self.conv_gm(global_memory)
+         i_m, f_m, g_m, m_m = torch.split(new_global_memory, self.num_hidden, dim=1)
+ 
+         temp_i_t = torch.sigmoid(temp_i_x + i_m)
+         temp_f_t = torch.sigmoid(temp_f_x + f_m + self._forget_bias)
+         temp_g_t = torch.tanh(temp_g_x + g_m)
+         new_global_memory = temp_f_t * torch.tanh(m_m) + temp_i_t * temp_g_t
+ 
+         o_c = self.conv_new_cell(new_cell)
+         o_m = self.conv_new_gm(new_global_memory)
+ 
+         output_gate = torch.tanh(o_x + o_h + o_c + o_m)
+ 
+         memory = torch.cat((new_cell, new_global_memory), 1)
+         memory = self.conv_last(memory)
+ 
+         output = torch.tanh(memory) * torch.sigmoid(output_gate)
+ 
+         return output, new_cell, global_memory
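The eidetic cell operates on 5D (B, C, window, H, W) tensors and attends over a bank of past cell states (the eidetic_cell argument); tf_Conv3d re-interpolates the conv output back to the input size, TensorFlow-"SAME"-style. A hedged sketch with deliberately small sizes, since the attention is quadratic in window*H*W:

import torch
from utilpack.e3dlstm_modules import Eidetic3DLSTMCell

cell = Eidetic3DLSTMCell(in_channel=1, num_hidden=32, window_length=2,
                         height=16, width=16, filter_size=(2, 5, 5),
                         stride=1, layer_norm=True)
x = torch.randn(2, 1, 2, 16, 16)
h = torch.zeros(2, 32, 2, 16, 16)
c = torch.zeros(2, 32, 2, 16, 16)
m = torch.zeros(2, 32, 2, 16, 16)          # global memory
out, c_new, m_out = cell(x, h, c, m, c)    # here the memory bank is just the current cell state
print(out.shape)                           # torch.Size([2, 32, 2, 16, 16])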
utilpack/layers/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .hornet import HorBlock
+ from .moganet import ChannelAggregationFFN, MultiOrderGatedAggregation, MultiOrderDWConv
+ from .poolformer import PoolFormerBlock
+ from .uniformer import CBlock, SABlock
+ from .van import DWConv, MixMlp, VANBlock
+ 
+ __all__ = [
+     'HorBlock', 'ChannelAggregationFFN', 'MultiOrderGatedAggregation', 'MultiOrderDWConv',
+     'PoolFormerBlock', 'CBlock', 'SABlock', 'DWConv', 'MixMlp', 'VANBlock',
+ ]
utilpack/layers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (582 Bytes)
utilpack/layers/__pycache__/hornet.cpython-312.pyc ADDED
Binary file (7.78 kB)
utilpack/layers/__pycache__/moganet.cpython-312.pyc ADDED
Binary file (8.2 kB)
utilpack/layers/__pycache__/poolformer.cpython-312.pyc ADDED
Binary file (6 kB)
utilpack/layers/__pycache__/uniformer.cpython-312.pyc ADDED
Binary file (11.4 kB)
utilpack/layers/__pycache__/van.cpython-312.pyc ADDED
Binary file (8.25 kB)
utilpack/layers/hornet.py ADDED
@@ -0,0 +1,112 @@
+ # refer to the code from HorNet, Thanks!
+ # https://github.com/raoyongming/HorNet
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from timm.layers import DropPath
+ import torch.fft
+ 
+ 
+ def get_dwconv(dim, kernel, bias):
+     return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel - 1) // 2, bias=bias, groups=dim)
+ 
+ 
+ class gnconv(nn.Module):
+     def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):
+         super().__init__()
+         self.order = order
+         self.dims = [dim // 2 ** i for i in range(order)]
+         self.dims.reverse()
+         self.proj_in = nn.Conv2d(dim, 2 * dim, 1)
+ 
+         if gflayer is None:
+             self.dwconv = get_dwconv(sum(self.dims), 7, True)
+         else:
+             self.dwconv = gflayer(sum(self.dims), h=h, w=w)
+ 
+         self.proj_out = nn.Conv2d(dim, dim, 1)
+ 
+         self.pws = nn.ModuleList(
+             [nn.Conv2d(self.dims[i], self.dims[i + 1], 1) for i in range(order - 1)]
+         )
+ 
+         self.scale = s
+         print('[gnconv]', order, 'order with dims=', self.dims, 'scale=%.4f' % self.scale)
+ 
+     def forward(self, x, mask=None, dummy=False):
+         fused_x = self.proj_in(x)
+         pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)
+ 
+         dw_abc = self.dwconv(abc) * self.scale
+ 
+         dw_list = torch.split(dw_abc, self.dims, dim=1)
+         x = pwa * dw_list[0]
+ 
+         for i in range(self.order - 1):
+             x = self.pws[i](x) * dw_list[i + 1]
+ 
+         x = self.proj_out(x)
+ 
+         return x
+ 
+ 
+ class LayerNorm(nn.Module):
+     r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+     The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+     shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+     with shape (batch_size, channels, height, width).
+     """
+ 
+     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.eps = eps
+         self.data_format = data_format
+         if self.data_format not in ["channels_last", "channels_first"]:
+             raise NotImplementedError
+         self.normalized_shape = (normalized_shape, )
+ 
+     def forward(self, x):
+         if self.data_format == "channels_last":
+             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+         elif self.data_format == "channels_first":
+             u = x.mean(1, keepdim=True)
+             s = (x - u).pow(2).mean(1, keepdim=True)
+             x = (x - u) / torch.sqrt(s + self.eps)
+             x = self.weight[:, None, None] * x + self.bias[:, None, None]
+             return x
+ 
+ 
+ class HorBlock(nn.Module):
+     """ HorNet block """
+ 
+     def __init__(self, dim, order=4, mlp_ratio=4, drop_path=0., init_value=1e-6, gnconv=gnconv):
+         super().__init__()
+ 
+         self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first')
+         self.gnconv = gnconv(dim, order)  # depthwise conv
+         self.norm2 = LayerNorm(dim, eps=1e-6)
+         self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim))  # pointwise/1x1 convs, implemented with linear layers
+         self.act = nn.GELU()
+         self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
+         self.gamma1 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
+         self.gamma2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ 
+     def forward(self, x):
+         B, C, H, W = x.shape
+         gamma1 = self.gamma1.view(C, 1, 1)
+         x = x + self.drop_path(gamma1 * self.gnconv(self.norm1(x)))
+ 
+         input = x
+         x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+         x = self.norm2(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.pwconv2(x)
+         if self.gamma2 is not None:
+             x = self.gamma2 * x
+         x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+ 
+         x = input + self.drop_path(x)
+         return x
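gnconv factors dim into order halved channel groups (dims reversed, e.g. [8, 16, 32, 64] for dim=64, order=4) and gates each stage with a slice of one shared depthwise conv, so dim should be divisible by 2**(order-1). A hedged shape check:

import torch
from utilpack.layers.hornet import HorBlock

blk = HorBlock(dim=64, order=4)
x = torch.randn(2, 64, 36, 36)
print(blk(x).shape)                 # torch.Size([2, 64, 36, 36])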
utilpack/layers/moganet.py ADDED
@@ -0,0 +1,140 @@
+ # refer to the code from MogaNet, Thanks!
+ # https://github.com/Westlake-AI/MogaNet/blob/main/models/moganet.py
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ 
+ class ChannelAggregationFFN(nn.Module):
+     """An implementation of FFN with Channel Aggregation in MogaNet."""
+ 
+     def __init__(self, embed_dims, mlp_hidden_dims, kernel_size=3, act_layer=nn.GELU, ffn_drop=0.):
+         super(ChannelAggregationFFN, self).__init__()
+         self.embed_dims = embed_dims
+         self.mlp_hidden_dims = mlp_hidden_dims
+ 
+         self.fc1 = nn.Conv2d(
+             in_channels=embed_dims, out_channels=self.mlp_hidden_dims, kernel_size=1)
+         self.dwconv = nn.Conv2d(
+             in_channels=self.mlp_hidden_dims, out_channels=self.mlp_hidden_dims, kernel_size=kernel_size,
+             padding=kernel_size // 2, bias=True, groups=self.mlp_hidden_dims)
+         self.act = act_layer()
+         self.fc2 = nn.Conv2d(
+             in_channels=mlp_hidden_dims, out_channels=embed_dims, kernel_size=1)
+         self.drop = nn.Dropout(ffn_drop)
+ 
+         self.decompose = nn.Conv2d(
+             in_channels=self.mlp_hidden_dims, out_channels=1, kernel_size=1)
+         self.sigma = nn.Parameter(
+             1e-5 * torch.ones((1, mlp_hidden_dims, 1, 1)), requires_grad=True)
+         self.decompose_act = act_layer()
+ 
+     def feat_decompose(self, x):
+         x = x + self.sigma * (x - self.decompose_act(self.decompose(x)))
+         return x
+ 
+     def forward(self, x):
+         # proj 1
+         x = self.fc1(x)
+         x = self.dwconv(x)
+         x = self.act(x)
+         x = self.drop(x)
+         # proj 2
+         x = self.feat_decompose(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+ 
+ 
+ class MultiOrderDWConv(nn.Module):
+     """Multi-order Features with Dilated DWConv Kernel in MogaNet."""
+ 
+     def __init__(self, embed_dims, dw_dilation=[1, 2, 3], channel_split=[1, 3, 4]):
+         super(MultiOrderDWConv, self).__init__()
+         self.split_ratio = [i / sum(channel_split) for i in channel_split]
+         self.embed_dims_1 = int(self.split_ratio[1] * embed_dims)
+         self.embed_dims_2 = int(self.split_ratio[2] * embed_dims)
+         self.embed_dims_0 = embed_dims - self.embed_dims_1 - self.embed_dims_2
+         self.embed_dims = embed_dims
+         assert len(dw_dilation) == len(channel_split) == 3
+         assert 1 <= min(dw_dilation) and max(dw_dilation) <= 3
+         assert embed_dims % sum(channel_split) == 0
+ 
+         # basic DW conv
+         self.DW_conv0 = nn.Conv2d(
+             in_channels=self.embed_dims, out_channels=self.embed_dims, kernel_size=5,
+             padding=(1 + 4 * dw_dilation[0]) // 2,
+             groups=self.embed_dims, stride=1, dilation=dw_dilation[0],
+         )
+         # DW conv 1
+         self.DW_conv1 = nn.Conv2d(
+             in_channels=self.embed_dims_1, out_channels=self.embed_dims_1, kernel_size=5,
+             padding=(1 + 4 * dw_dilation[1]) // 2,
+             groups=self.embed_dims_1, stride=1, dilation=dw_dilation[1],
+         )
+         # DW conv 2
+         self.DW_conv2 = nn.Conv2d(
+             in_channels=self.embed_dims_2, out_channels=self.embed_dims_2, kernel_size=7,
+             padding=(1 + 6 * dw_dilation[2]) // 2,
+             groups=self.embed_dims_2, stride=1, dilation=dw_dilation[2],
+         )
+         # a channel convolution
+         self.PW_conv = nn.Conv2d(
+             in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+ 
+     def forward(self, x):
+         x_0 = self.DW_conv0(x)
+         x_1 = self.DW_conv1(
+             x_0[:, self.embed_dims_0: self.embed_dims_0 + self.embed_dims_1, ...])
+         x_2 = self.DW_conv2(
+             x_0[:, self.embed_dims - self.embed_dims_2:, ...])
+         x = torch.cat([
+             x_0[:, :self.embed_dims_0, ...], x_1, x_2], dim=1)
+         x = self.PW_conv(x)
+         return x
+ 
+ 
+ class MultiOrderGatedAggregation(nn.Module):
+     """Spatial Block with Multi-order Gated Aggregation in MogaNet."""
+ 
+     def __init__(self, embed_dims, attn_dw_dilation=[1, 2, 3], attn_channel_split=[1, 3, 4], attn_shortcut=True):
+         super(MultiOrderGatedAggregation, self).__init__()
+         self.embed_dims = embed_dims
+         self.attn_shortcut = attn_shortcut
+         self.proj_1 = nn.Conv2d(
+             in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+         self.gate = nn.Conv2d(
+             in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+         self.value = MultiOrderDWConv(
+             embed_dims=embed_dims, dw_dilation=attn_dw_dilation, channel_split=attn_channel_split)
+         self.proj_2 = nn.Conv2d(
+             in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+ 
+         # activation for gating and value
+         self.act_value = nn.SiLU()
+         self.act_gate = nn.SiLU()
+         # decompose
+         self.sigma = nn.Parameter(1e-5 * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
+ 
+     def feat_decompose(self, x):
+         x = self.proj_1(x)
+         # x_d: [B, C, H, W] -> [B, C, 1, 1]
+         x_d = F.adaptive_avg_pool2d(x, output_size=1)
+         x = x + self.sigma * (x - x_d)
+         x = self.act_value(x)
+         return x
+ 
+     def forward(self, x):
+         if self.attn_shortcut:
+             shortcut = x.clone()
+         # proj 1x1
+         x = self.feat_decompose(x)
+         # gating and value branch
+         g = self.gate(x)
+         v = self.value(x)
+         # aggregation
+         x = self.proj_2(self.act_gate(g) * self.act_gate(v))
+         if self.attn_shortcut:
+             x = x + shortcut
+         return x
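MultiOrderDWConv splits channels 1:3:4 across three dilated depthwise branches, so embed_dims must be divisible by 8; both blocks preserve the (B, C, H, W) shape. A hedged sketch with illustrative sizes:

import torch
from utilpack.layers.moganet import MultiOrderGatedAggregation, ChannelAggregationFFN

x = torch.randn(2, 64, 36, 36)
attn = MultiOrderGatedAggregation(embed_dims=64)
ffn = ChannelAggregationFFN(embed_dims=64, mlp_hidden_dims=256)
print(ffn(attn(x)).shape)           # torch.Size([2, 64, 36, 36])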
utilpack/layers/poolformer.py ADDED
@@ -0,0 +1,97 @@
+ # refer to the code from PoolFormer, Thanks!
+ # https://github.com/sail-sg/poolformer/blob/main/models/poolformer.py
+ 
+ import torch
+ import torch.nn as nn
+ from timm.layers import DropPath, trunc_normal_
+ 
+ 
+ class GroupNorm(nn.GroupNorm):
+     """
+     Group Normalization with 1 group.
+     Input: tensor in shape [B, C, H, W]
+     """
+     def __init__(self, num_channels, **kwargs):
+         super().__init__(1, num_channels, **kwargs)
+ 
+ 
+ class Pooling(nn.Module):
+     """
+     Implementation of pooling for PoolFormer
+     --pool_size: pooling size
+     """
+     def __init__(self, pool_size=3):
+         super().__init__()
+         self.pool = nn.AvgPool2d(
+             pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)
+ 
+     def forward(self, x):
+         return self.pool(x) - x
+ 
+ 
+ class Mlp(nn.Module):
+     """
+     Implementation of MLP with 1*1 convolutions.
+     Input: tensor with shape [B, C, H, W]
+     """
+     def __init__(self, in_features, hidden_features=None,
+                  out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+         self.act = act_layer()
+         self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+         self.drop = nn.Dropout(drop)
+         self.apply(self._init_weights)
+ 
+     def _init_weights(self, m):
+         if isinstance(m, nn.Conv2d):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+ 
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+ 
+ 
+ class PoolFormerBlock(nn.Module):
+     """
+     Implementation of one PoolFormer block.
+     --dim: embedding dim
+     --pool_size: pooling size
+     --mlp_ratio: mlp expansion ratio
+     --act_layer: activation
+     --norm_layer: normalization
+     --drop: dropout rate
+     --drop path: Stochastic Depth,
+         refer to https://arxiv.org/abs/1603.09382
+     --init_value: LayerScale,
+         refer to https://arxiv.org/abs/2103.17239
+     """
+     def __init__(self, dim, pool_size=3, mlp_ratio=4., drop=0., drop_path=0.,
+                  init_value=1e-5, act_layer=nn.GELU, norm_layer=GroupNorm):
+         super().__init__()
+ 
+         self.norm1 = norm_layer(dim)
+         self.token_mixer = Pooling(pool_size=pool_size)
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                        act_layer=act_layer, drop=drop)
+         # The following two techniques are useful to train deep PoolFormers.
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.layer_scale_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
+         self.layer_scale_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
+ 
+     def forward(self, x):
+         x = x + self.drop_path(
+             self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.token_mixer(self.norm1(x)))
+         x = x + self.drop_path(
+             self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
+         return x
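The token mixer is just pool(x) - x, so a PoolFormerBlock is attention-free and shape-preserving. A hedged sketch:

import torch
from utilpack.layers.poolformer import PoolFormerBlock

blk = PoolFormerBlock(dim=64, pool_size=3, mlp_ratio=4.)
x = torch.randn(2, 64, 36, 36)
print(blk(x).shape)                 # torch.Size([2, 64, 36, 36])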
utilpack/layers/uniformer.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # refer to the code from Uniformer, Thanks!
+ # https://github.com/Sense-X/UniFormer/blob/main/image_classification/models/uniformer.py
+
+ import math
+ import torch
+ import torch.nn as nn
+ from timm.layers import DropPath, trunc_normal_
+
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class CMlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+         self.act = act_layer()
+         self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class CBlock(nn.Module):
+     def __init__(self, dim, num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+         self.norm1 = nn.BatchNorm2d(dim)
+         self.conv1 = nn.Conv2d(dim, dim, 1)
+         self.conv2 = nn.Conv2d(dim, dim, 1)
+         self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = nn.BatchNorm2d(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {}
+
+     def forward(self, x):
+         x = x + self.pos_embed(x)
+         x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ class SABlock(nn.Module):
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., init_value=1e-6, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim,
+             num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+             attn_drop=attn_drop, proj_drop=drop)
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+         # layer scale
+         self.gamma_1 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
+         self.gamma_2 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'gamma_1', 'gamma_2'}
+
+     def forward(self, x):
+         x = x + self.pos_embed(x)
+         B, N, H, W = x.shape
+         x = x.flatten(2).transpose(1, 2)
+         x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
+         x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+         x = x.transpose(1, 2).reshape(B, N, H, W)
+         return x
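A minimal usage sketch for the two UniFormer blocks (shapes and hyper-parameters are illustrative, not from the commit):

import torch
cb = CBlock(dim=64)                    # convolutional token mixer
sab = SABlock(dim=64, num_heads=8)     # self-attention mixer with layer scale
x = torch.randn(2, 64, 16, 16)         # (B, C, H, W)
y = sab(cb(x))                         # both blocks preserve the (B, C, H, W) shape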
utilpack/layers/van.py ADDED
@@ -0,0 +1,119 @@
+ # refer to the code from VAN, Thanks!
+ # https://github.com/Visual-Attention-Network/VAN-Classification
+
+ import math
+ import torch
+ import torch.nn as nn
+
+ from timm.layers import DropPath, trunc_normal_
+
+
+ class DWConv(nn.Module):
+     def __init__(self, dim=768):
+         super(DWConv, self).__init__()
+         self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+     def forward(self, x):
+         x = self.dwconv(x)
+         return x
+
+
+ class MixMlp(nn.Module):
+     def __init__(self,
+                  in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Conv2d(in_features, hidden_features, 1)  # 1x1
+         self.dwconv = DWConv(hidden_features)  # CFF: convolutional feed-forward network
+         self.act = act_layer()  # GELU
+         self.fc2 = nn.Conv2d(hidden_features, out_features, 1)  # 1x1
+         self.drop = nn.Dropout(drop)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.dwconv(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class LKA(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+         self.conv_spatial = nn.Conv2d(
+             dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
+         self.conv1 = nn.Conv2d(dim, dim, 1)
+
+     def forward(self, x):
+         u = x.clone()
+         attn = self.conv0(x)
+         attn = self.conv_spatial(attn)
+         attn = self.conv1(attn)
+
+         return u * attn
+
+
+ class Attention(nn.Module):
+     def __init__(self, d_model, attn_shortcut=True):
+         super().__init__()
+
+         self.proj_1 = nn.Conv2d(d_model, d_model, 1)
+         self.activation = nn.GELU()
+         self.spatial_gating_unit = LKA(d_model)
+         self.proj_2 = nn.Conv2d(d_model, d_model, 1)
+         self.attn_shortcut = attn_shortcut
+
+     def forward(self, x):
+         if self.attn_shortcut:
+             shortcut = x.clone()
+         x = self.proj_1(x)
+         x = self.activation(x)
+         x = self.spatial_gating_unit(x)
+         x = self.proj_2(x)
+         if self.attn_shortcut:
+             x = x + shortcut
+         return x
+
+
+ class VANBlock(nn.Module):
+     def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., init_value=1e-2, act_layer=nn.GELU, attn_shortcut=True):
+         super().__init__()
+         self.norm1 = nn.BatchNorm2d(dim)
+         self.attn = Attention(dim, attn_shortcut=attn_shortcut)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+         self.norm2 = nn.BatchNorm2d(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = MixMlp(
+             in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+         self.layer_scale_1 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
+         self.layer_scale_2 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
+
+     def forward(self, x):
+         x = x + self.drop_path(
+             self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
+         x = x + self.drop_path(
+             self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
+         return x
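A minimal usage sketch for VANBlock (illustrative shapes; the dilated depthwise conv in LKA keeps the spatial size unchanged):

import torch
block = VANBlock(dim=32)
x = torch.randn(2, 32, 16, 16)
y = block(x)   # (2, 32, 16, 16): LKA attention and MixMlp are both residual branches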
utilpack/mau_modules.py ADDED
@@ -0,0 +1,66 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+
+ class MAUCell(nn.Module):
+
+     def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, tau, cell_mode):
+         super(MAUCell, self).__init__()
+
+         self.num_hidden = num_hidden
+         # self.padding = (filter_size[0] // 2, filter_size[1] // 2)
+         self.padding = filter_size // 2
+         self.cell_mode = cell_mode
+         self.d = num_hidden * height * width
+         self.tau = tau
+         self.states = ['residual', 'normal']
+         assert self.cell_mode in self.states, f'cell_mode must be one of {self.states}'
+         self.conv_t = nn.Sequential(
+             nn.Conv2d(in_channel, 3 * num_hidden, kernel_size=filter_size,
+                       stride=stride, padding=self.padding),
+             nn.LayerNorm([3 * num_hidden, height, width])
+         )
+         self.conv_t_next = nn.Sequential(
+             nn.Conv2d(in_channel, num_hidden, kernel_size=filter_size,
+                       stride=stride, padding=self.padding),
+             nn.LayerNorm([num_hidden, height, width])
+         )
+         self.conv_s = nn.Sequential(
+             nn.Conv2d(num_hidden, 3 * num_hidden, kernel_size=filter_size,
+                       stride=stride, padding=self.padding),
+             nn.LayerNorm([3 * num_hidden, height, width])
+         )
+         self.conv_s_next = nn.Sequential(
+             nn.Conv2d(num_hidden, num_hidden, kernel_size=filter_size,
+                       stride=stride, padding=self.padding),
+             nn.LayerNorm([num_hidden, height, width])
+         )
+         self.softmax = nn.Softmax(dim=0)
+
+     def forward(self, T_t, S_t, t_att, s_att):
+         s_next = self.conv_s_next(S_t)
+         t_next = self.conv_t_next(T_t)
+         weights_list = []
+         for i in range(self.tau):
+             weights_list.append((s_att[i] * s_next).sum(dim=(1, 2, 3)) / math.sqrt(self.d))
+         weights_list = torch.stack(weights_list, dim=0)
+         weights_list = torch.reshape(weights_list, (*weights_list.shape, 1, 1, 1))
+         weights_list = self.softmax(weights_list)
+         T_trend = t_att * weights_list
+         T_trend = T_trend.sum(dim=0)
+         t_att_gate = torch.sigmoid(t_next)
+         T_fusion = T_t * t_att_gate + (1 - t_att_gate) * T_trend
+         T_concat = self.conv_t(T_fusion)
+         S_concat = self.conv_s(S_t)
+         t_g, t_t, t_s = torch.split(T_concat, self.num_hidden, dim=1)
+         s_g, s_t, s_s = torch.split(S_concat, self.num_hidden, dim=1)
+         T_gate = torch.sigmoid(t_g)
+         S_gate = torch.sigmoid(s_g)
+         T_new = T_gate * t_t + (1 - T_gate) * s_t
+         S_new = S_gate * s_s + (1 - S_gate) * t_s
+
+         if self.cell_mode == 'residual':
+             S_new = S_new + S_t
+         return T_new, S_new
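A minimal usage sketch for MAUCell (hyper-parameters are illustrative; t_att and s_att stack the last tau temporal and spatial states along dim 0):

import torch
cell = MAUCell(in_channel=32, num_hidden=32, height=16, width=16,
               filter_size=5, stride=1, tau=5, cell_mode='normal')
T_t = torch.randn(2, 32, 16, 16)          # temporal state
S_t = torch.randn(2, 32, 16, 16)          # spatial state
t_att = torch.randn(5, 2, 32, 16, 16)     # tau stacked temporal states
s_att = torch.randn(5, 2, 32, 16, 16)     # tau stacked spatial states
T_new, S_new = cell(T_t, S_t, t_att, s_att)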
utilpack/mim_modules.py ADDED
@@ -0,0 +1,211 @@
+ import torch
+ import torch.nn as nn
+
+
+ class MIMBlock(nn.Module):
+
+     def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
+         super(MIMBlock, self).__init__()
+
+         self.convlstm_c = None
+         self.num_hidden = num_hidden
+         self.padding = filter_size // 2
+         self._forget_bias = 1.0
+
+         self.ct_weight = nn.Parameter(torch.zeros(num_hidden * 2, height, width))
+         self.oc_weight = nn.Parameter(torch.zeros(num_hidden, height, width))
+
+         if layer_norm:
+             self.conv_t_cc = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 3, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 3, height, width])
+             )
+             self.conv_s_cc = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+             self.conv_x_cc = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+             self.conv_h_concat = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+             self.conv_x_concat = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+         else:
+             self.conv_t_cc = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 3, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_s_cc = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_x_cc = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_h_concat = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_x_concat = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+         self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
+                                    stride=1, padding=0, bias=False)
+
+     def _init_state(self, inputs):
+         return torch.zeros_like(inputs)
+
+     def MIMS(self, x, h_t, c_t):
+         if h_t is None:
+             h_t = self._init_state(x)
+         if c_t is None:
+             c_t = self._init_state(x)
+
+         h_concat = self.conv_h_concat(h_t)
+         i_h, g_h, f_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
+
+         ct_activation = torch.mul(c_t.repeat(1, 2, 1, 1), self.ct_weight)
+         i_c, f_c = torch.split(ct_activation, self.num_hidden, dim=1)
+
+         i_ = i_h + i_c
+         f_ = f_h + f_c
+         g_ = g_h
+         o_ = o_h
+
+         if x is not None:
+             x_concat = self.conv_x_concat(x)
+             i_x, g_x, f_x, o_x = torch.split(x_concat, self.num_hidden, dim=1)
+
+             i_ = i_ + i_x
+             f_ = f_ + f_x
+             g_ = g_ + g_x
+             o_ = o_ + o_x
+
+         i_ = torch.sigmoid(i_)
+         f_ = torch.sigmoid(f_ + self._forget_bias)
+         c_new = f_ * c_t + i_ * torch.tanh(g_)
+
+         o_c = torch.mul(c_new, self.oc_weight)
+
+         h_new = torch.sigmoid(o_ + o_c) * torch.tanh(c_new)
+
+         return h_new, c_new
+
+     def forward(self, x, diff_h, h, c, m):
+         h = self._init_state(x) if h is None else h
+         c = self._init_state(x) if c is None else c
+         m = self._init_state(x) if m is None else m
+         diff_h = self._init_state(x) if diff_h is None else diff_h
+
+         t_cc = self.conv_t_cc(h)
+         s_cc = self.conv_s_cc(m)
+         x_cc = self.conv_x_cc(x)
+
+         i_s, g_s, f_s, o_s = torch.split(s_cc, self.num_hidden, dim=1)
+         i_t, g_t, o_t = torch.split(t_cc, self.num_hidden, dim=1)
+         i_x, g_x, f_x, o_x = torch.split(x_cc, self.num_hidden, dim=1)
+
+         i = torch.sigmoid(i_x + i_t)
+         i_ = torch.sigmoid(i_x + i_s)
+         g = torch.tanh(g_x + g_t)
+         g_ = torch.tanh(g_x + g_s)
+         f_ = torch.sigmoid(f_x + f_s + self._forget_bias)
+         o = torch.sigmoid(o_x + o_t + o_s)
+         new_m = f_ * m + i_ * g_
+
+         c, self.convlstm_c = self.MIMS(diff_h, c, self.convlstm_c
+                                        if self.convlstm_c is None else self.convlstm_c.detach())
+
+         new_c = c + i * g
+         cell = torch.cat((new_c, new_m), 1)
+         new_h = o * torch.tanh(self.conv_last(cell))
+
+         return new_h, new_c, new_m
+
+
+ class MIMN(nn.Module):
+
+     def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
+         super(MIMN, self).__init__()
+
+         self.num_hidden = num_hidden
+         self.padding = filter_size // 2
+         self._forget_bias = 1.0
+
+         self.ct_weight = nn.Parameter(torch.zeros(num_hidden * 2, height, width))
+         self.oc_weight = nn.Parameter(torch.zeros(num_hidden, height, width))
+
+         if layer_norm:
+             self.conv_h_concat = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+             self.conv_x_concat = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+         else:
+             self.conv_h_concat = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_x_concat = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+         self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
+                                    stride=1, padding=0, bias=False)
+
+     def _init_state(self, inputs):
+         return torch.zeros_like(inputs)
+
+     def forward(self, x, h_t, c_t):
+         if h_t is None:
+             h_t = self._init_state(x)
+         if c_t is None:
+             c_t = self._init_state(x)
+
+         h_concat = self.conv_h_concat(h_t)
+         i_h, g_h, f_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
+
+         ct_activation = torch.mul(c_t.repeat(1, 2, 1, 1), self.ct_weight)
+         i_c, f_c = torch.split(ct_activation, self.num_hidden, dim=1)
+
+         i_ = i_h + i_c
+         f_ = f_h + f_c
+         g_ = g_h
+         o_ = o_h
+
+         if x is not None:
+             x_concat = self.conv_x_concat(x)
+             i_x, g_x, f_x, o_x = torch.split(x_concat, self.num_hidden, dim=1)
+
+             i_ = i_ + i_x
+             f_ = f_ + f_x
+             g_ = g_ + g_x
+             o_ = o_ + o_x
+
+         i_ = torch.sigmoid(i_)
+         f_ = torch.sigmoid(f_ + self._forget_bias)
+         c_new = f_ * c_t + i_ * torch.tanh(g_)
+
+         o_c = torch.mul(c_new, self.oc_weight)
+
+         h_new = torch.sigmoid(o_ + o_c) * torch.tanh(c_new)
+
+         return h_new, c_new
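A minimal usage sketch for MIMN (illustrative; passing None for the states lets the cell initialise them as zeros with the input's shape):

import torch
cell = MIMN(in_channel=32, num_hidden=32, height=16, width=16,
            filter_size=5, stride=1, layer_norm=True)
x = torch.randn(2, 32, 16, 16)
h, c = cell(x, None, None)   # states default to zeros_like(x)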
utilpack/mmvp_modules.py ADDED
@@ -0,0 +1,349 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class ResidualDenseBlock_4C(nn.Module):
+     def __init__(self, nf=64, gc=32, bias=True):
+         super(ResidualDenseBlock_4C, self).__init__()
+         # gc: growth channel, i.e. intermediate channels
+
+         self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
+         self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
+         self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
+         self.conv4 = nn.Conv2d(nf + 3 * gc, nf, 3, 1, 1, bias=bias)
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         # initialization
+         # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+     def forward(self, x):
+         x1 = self.lrelu(self.conv1(x))
+         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+         x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+         return x4 * 0.2 + x
+
+
+ class RRDB(nn.Module):
+     '''Residual in Residual Dense Block'''
+
+     def __init__(self, nf):
+         super(RRDB, self).__init__()
+         gc = nf // 2
+         self.RDB1 = ResidualDenseBlock_4C(nf, gc)
+         self.RDB2 = ResidualDenseBlock_4C(nf, gc)
+         self.RDB3 = ResidualDenseBlock_4C(nf, gc)
+
+     def forward(self, x):
+         out = self.RDB1(x)
+         out = self.RDB2(out)
+         out = self.RDB3(out)
+         return out * 0.2 + x
+
+
+ class Up(nn.Module):
+     """Upscaling then double conv"""
+
+     def __init__(self, in_channels, out_channels, bilinear=True, skip=True, scale=2, bn=True, motion=False):
+         super().__init__()
+         factor = scale
+         # if bilinear, use the normal convolutions to reduce the number of channels
+         if bilinear:
+             if skip:
+                 self.up = nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=True)
+                 self.conv = ConvLayer(in_channels, out_channels, bn=bn)
+             else:
+                 self.up = nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=True)
+                 self.conv = ConvLayer(in_channels, out_channels)
+         else:
+             if skip:
+                 self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=factor, stride=factor)
+                 self.conv = ConvLayer(out_channels * 2, out_channels, bn=bn, motion=motion)
+             else:
+                 self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=factor, stride=factor)
+                 self.conv = ConvLayer(out_channels, out_channels, bn=bn, motion=motion)
+
+     def forward(self, x1, x2=None):
+         x1 = self.up(x1)
+         if x2 is None:
+             return self.conv(x1)
+         # input is CHW
+         diffY = x2.size()[2] - x1.size()[2]
+         diffX = x2.size()[3] - x1.size()[3]
+
+         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+                         diffY // 2, diffY - diffY // 2])
+         # if you have padding issues, see
+         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+         x = torch.cat([x2, x1], dim=1)
+         return self.conv(x)
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, downsample=False,
+                  upsample=False, skip=False, factor=2, motion=False):
+         super().__init__()
+         self.upsample = upsample
+         self.maxpool = None
+         if downsample:
+             self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
+             self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2)
+             if factor == 4:
+                 self.maxpool = nn.MaxPool2d(2)
+         elif upsample:
+             self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=factor, stride=factor)
+             if motion:
+                 self.shortcut = nn.Sequential(nn.Upsample(scale_factor=factor,
+                                                           mode='bilinear',
+                                                           align_corners=True),
+                                               nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
+                                               nn.BatchNorm2d(out_channels))
+             else:
+                 self.shortcut = nn.Sequential(nn.Upsample(scale_factor=factor,
+                                                           mode='bilinear',
+                                                           align_corners=True),
+                                               nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1))
+         else:
+             self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+             self.shortcut = nn.Sequential()
+
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, input):
+         shortcut = self.shortcut(input)
+         input = nn.ReLU()(self.conv1(input))
+         input = nn.ReLU()(self.conv2(input))
+         input = input + shortcut
+         if self.maxpool is not None:
+             input = self.maxpool(input)
+         return nn.LeakyReLU()(input)
+
+
+ class ConvLayer(nn.Module):
+     """convolution => [BN] => ReLU"""
+
+     def __init__(self, in_channels, out_channels, mid_channels=None, bn=True, motion=False, dilation=1):
+         super().__init__()
+         if not mid_channels:
+             mid_channels = out_channels
+
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+         ) if motion else nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation, bias=False, dilation=dilation),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         return self.conv(x)
+
+
+ class Conv3D(nn.Module):
+     def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
+         super(Conv3D, self).__init__()
+         self.conv3d = nn.Conv3d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)
+         self.bn3d = nn.BatchNorm3d(out_channel)
+
+     def forward(self, x):
+         # input x: (batch, seq, c, h, w)
+         x = x.permute(0, 2, 1, 3, 4).contiguous()  # (batch, c, seq_len, h, w)
+         x = F.leaky_relu(self.bn3d(self.conv3d(x)))
+         x = x.permute(0, 2, 1, 3, 4).contiguous()  # (batch, seq_len, c, h, w)
+
+         return x
+
+
+ class MatrixPredictor3DConv(nn.Module):
+     def __init__(self, hidden_len=64):
+         super(MatrixPredictor3DConv, self).__init__()
+         self.unet_base = hidden_len  # 64
+         self.hidden_len = hidden_len  # 64
+         self.conv_pre_1 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3, stride=1, padding=1)
+         self.conv_pre_2 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3, stride=1, padding=1)
+
+         self.conv3d_1 = Conv3D(self.unet_base, self.unet_base, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1))
+         self.conv3d_2 = Conv3D(self.unet_base * 2, self.unet_base * 2, kernel_size=(3, 3, 3), stride=1, padding=(0, 1, 1))
+
+         self.conv1_1 = nn.Conv2d(hidden_len, self.unet_base, kernel_size=3, stride=2, padding=1)
+         self.conv2_1 = nn.Conv2d(self.unet_base, self.unet_base * 2, kernel_size=3, stride=2, padding=1)
+
+         self.conv3_1 = nn.Conv2d(self.unet_base * 3, self.unet_base, kernel_size=3, stride=1, padding=1)
+         self.conv4_1 = nn.Conv2d(self.unet_base, self.hidden_len, kernel_size=3, stride=1, padding=1)
+
+         self.bn_pre_1 = nn.BatchNorm2d(hidden_len)
+         self.bn_pre_2 = nn.BatchNorm2d(hidden_len)
+         self.bn1_1 = nn.BatchNorm2d(self.unet_base)
+         self.bn2_1 = nn.BatchNorm2d(self.unet_base * 2)
+         self.bn3_1 = nn.BatchNorm2d(self.unet_base)
+         self.bn4_1 = nn.BatchNorm2d(self.hidden_len)
+
+     def forward(self, x):
+         # x: [B, T, C, 32, 32] -> out: [B, C, 32, 32]
+         batch, seq, z, h, w = x.size()
+         x = x.reshape(-1, x.size(-3), x.size(-2), x.size(-1))
+         x = F.leaky_relu(self.bn_pre_1(self.conv_pre_1(x)))
+         x = F.leaky_relu(self.bn_pre_2(self.conv_pre_2(x)))
+         x_1 = F.leaky_relu(self.bn1_1(self.conv1_1(x)))
+
+         x_1 = x_1.view(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3)).contiguous()  # (batch, seq, c, h, w)
+         x_1 = self.conv3d_1(x_1)  # (batch, seq, c, h, w), 1st temporal conv
+         x_1 = x_1.view(-1, x_1.size(2), x_1.size(3), x_1.size(4)).contiguous()  # (batch * seq, c, h, w)
+         x_2 = F.leaky_relu(self.bn2_1(self.conv2_1(x_1)))  # (batch * seq, c, h // 2, w // 2)
+         x_2 = x_2.view(batch, -1, x_2.size(1), x_2.size(2), x_2.size(3)).contiguous()  # (batch, seq, c, h, w)
+         x_2 = self.conv3d_2(x_2)  # (batch, seq=1, c, h // 2, w // 2), 2nd temporal conv
+         x_2 = x_2.view(-1, x_2.size(2), x_2.size(3), x_2.size(4)).contiguous()  # (batch * seq, c, h//2, w//2), seq = 1
+
+         x_1 = x_1.view(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3))  # (batch, seq, c, h, w)
+         x_1 = x_1.permute(0, 2, 1, 3, 4).contiguous()  # (batch, c, seq, h, w)
+         x_1 = F.adaptive_max_pool3d(x_1, (1, None, None))  # (batch, c, 1, h, w)
+         x_1 = x_1.permute(0, 2, 1, 3, 4).contiguous()  # (batch, 1, c, h, w)
+         x_1 = x_1.view(-1, x_1.size(2), x_1.size(3), x_1.size(4)).contiguous()  # (batch*1, c, h, w)
+         x_3 = F.leaky_relu(self.bn3_1(self.conv3_1(torch.cat((F.interpolate(x_2, scale_factor=(2, 2)), x_1), dim=1))))
+         x = x.view(batch, -1, x.size(1), x.size(2), x.size(3))  # (batch, seq, 1, h, w)
+         x = F.leaky_relu(self.bn4_1(self.conv4_1(F.interpolate(x_3, scale_factor=(2, 2)))))
+         return x
+
+
+ class SimpleMatrixPredictor3DConv_direct(nn.Module):
+     def __init__(self, T, hidden_len=64, image_pred=False, aft_seq_length=10):
+         super(SimpleMatrixPredictor3DConv_direct, self).__init__()
+         self.unet_base = hidden_len  # 64
+         self.hidden_len = hidden_len  # 64
+         self.conv_pre_1 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3, stride=1, padding=1)
+         self.conv_pre_2 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3, stride=1, padding=1)
+         self.fut_len = aft_seq_length
+
+         self.conv3d_1 = Conv3D(self.unet_base, self.unet_base, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1))
+
+         if self.fut_len > 1:
+             self.temporal_layer = Conv3D(self.unet_base * 2, self.unet_base * 2, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1))
+         else:
+             self.temporal_layer = nn.Sequential(
+                 nn.Conv2d(self.unet_base * 2, self.unet_base * 2, kernel_size=3, stride=1, padding=1),
+                 nn.LeakyReLU())
+
+         input_len = T if image_pred else T - 1
+         self.conv_translate = nn.Sequential(
+             nn.Conv2d(self.unet_base * input_len, self.unet_base * self.fut_len, kernel_size=1, stride=1, padding=0),
+             nn.LeakyReLU())
+
+         self.conv1_1 = nn.Conv2d(hidden_len, self.unet_base, kernel_size=3, stride=2, padding=1)
+         self.conv2_1 = nn.Conv2d(self.unet_base, self.unet_base * 2, kernel_size=3, stride=2, padding=1)
+
+         self.conv3_1 = nn.Conv2d(self.unet_base * 3, self.unet_base, kernel_size=3, stride=1, padding=1)
+         self.conv4_1 = nn.Conv2d(self.unet_base, self.hidden_len, kernel_size=3, stride=1, padding=1)
+
+         self.bn_pre_1 = nn.BatchNorm2d(hidden_len)
+         self.bn_pre_2 = nn.BatchNorm2d(hidden_len)
+         self.bn1_1 = nn.BatchNorm2d(self.unet_base)
+         self.bn2_1 = nn.BatchNorm2d(self.unet_base * 2)
+         self.bn3_1 = nn.BatchNorm2d(self.unet_base)
+         self.bn4_1 = nn.BatchNorm2d(self.hidden_len)
+         self.bn_translate = nn.BatchNorm2d(self.unet_base * self.fut_len)
+
+     def forward(self, x):
+         # x: [B, T, C, 32, 32] -> out: [B, C, 32, 32]
+         batch, seq, z, h, w = x.size()
+         x = x.reshape(-1, x.size(-3), x.size(-2), x.size(-1))
+         x = F.leaky_relu(self.bn_pre_1(self.conv_pre_1(x)))
+         x = F.leaky_relu(self.bn_pre_2(self.conv_pre_2(x)))
+         x_1 = F.leaky_relu(self.bn1_1(self.conv1_1(x)))
+
+         x_1 = x_1.view(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3)).contiguous()  # (batch, seq, c, h, w)
+
+         x_1 = self.conv3d_1(x_1)  # (batch, seq, c, h, w), 1st temporal conv
+         batch, seq, c, h, w = x_1.shape
+         x_tmp = x_1.reshape(batch, -1, h, w)
+         x_tmp = self.bn_translate(self.conv_translate(x_tmp))
+         x_1 = x_tmp.reshape(batch, self.fut_len, c, h, w)
+         x_1 = x_1.view(-1, x_1.size(2), x_1.size(3), x_1.size(4)).contiguous()  # (batch * seq, c, h, w)
+         x_2 = F.leaky_relu(self.bn2_1(self.conv2_1(x_1)))  # (batch * seq, c, h // 2, w // 2)
+         if self.fut_len > 1:
+             x_2 = x_2.view(batch, -1, x_2.size(1), x_2.size(2), x_2.size(3)).contiguous()  # (batch, seq, c, h, w)
+             x_2 = self.temporal_layer(x_2)  # (batch, seq=fut_len, c, h // 2, w // 2)
+             x_2 = x_2.view(-1, x_2.size(2), x_2.size(3), x_2.size(4)).contiguous()  # (batch * seq, c, h//2, w//2)
+         else:
+             x_2 = self.temporal_layer(x_2)  # (batch * seq, c, h // 2, w // 2)
+
+         x_1 = x_1.view(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3))  # (batch, seq, c, h, w)
+         x_1 = x_1.reshape(-1, x_1.size(2), x_1.size(3), x_1.size(4))
+
+         x_3 = F.leaky_relu(self.bn3_1(self.conv3_1(torch.cat((F.interpolate(x_2, size=x_1.shape[2:]), x_1), dim=1))))
+         x = x.view(batch, -1, x.size(1), x.size(2), x.size(3))  # (batch, seq, 1, h, w)
+         x = F.leaky_relu(self.bn4_1(self.conv4_1(F.interpolate(x_3, size=x.shape[3:]))))
+
+         return x
+
+
+ class PredictModel(nn.Module):
+     def __init__(self, T, hidden_len=32, aft_seq_length=10, mx_h=32, mx_w=32, use_direct_predictor=True):
+         super(PredictModel, self).__init__()
+         self.mx_h = mx_h
+         self.mx_w = mx_w
+         self.hidden_len = hidden_len
+         self.fut_len = aft_seq_length
+         self.conv1 = nn.Conv2d(1, hidden_len, kernel_size=3, padding=1, bias=False)
+         self.fuse_conv = nn.Conv2d(hidden_len * 2, hidden_len, kernel_size=3, padding=1, bias=False)
+         if use_direct_predictor:
+             self.predictor = SimpleMatrixPredictor3DConv_direct(T=T, hidden_len=hidden_len, aft_seq_length=aft_seq_length)
+         else:
+             self.predictor = MatrixPredictor3DConv(hidden_len)
+         self.out_conv = nn.Conv2d(hidden_len, 1, kernel_size=3, padding=1, bias=False)
+         self.softmax = nn.Softmax(dim=-1)
+         self.sigmoid = nn.Sigmoid()
+
+     def res_interpolate(self, in_tensor, template_tensor):
+         '''
+         in_tensor: batch, c, h'w', H'W'
+         template_tensor: batch, c, hw, HW
+         out_tensor: batch, c, hw, HW
+         '''
+         out_tensor = F.interpolate(in_tensor, template_tensor.shape[-2:])  # (BThw, target_h, target_w)
+         return out_tensor
+
+     def forward(self, matrix_seq, softmax=False, res=None):
+         B, T, hw, window_size = matrix_seq.size()
+
+         matrix_seq = matrix_seq.reshape(-1, hw, self.mx_h, self.mx_w)  # (BT, hw, h, w)
+         matrix_seq = matrix_seq.reshape(B * T * hw, self.mx_h, self.mx_w).unsqueeze(1)  # (BThw, 1, h, w)
+
+         x = self.conv1(matrix_seq)
+         x = x.reshape(B, T, hw, -1, self.mx_h, self.mx_w)
+         x = x.permute(0, 2, 1, 3, 4, 5).reshape(B * hw, T, -1, self.mx_h, self.mx_w)
+         emb = self.predictor(x)
+
+         emb = emb.reshape(B * hw * self.fut_len, -1, self.mx_h, self.mx_w)
+         res_emb = emb.clone()
+         if res is not None:
+             template = emb.clone().reshape(B, hw, emb.shape[1], -1).permute(0, 2, 1, 3)
+             in_tensor = res.clone().reshape(B, hw // 4, emb.shape[1], -1).permute(0, 2, 1, 3)
+
+             res_tensor = self.res_interpolate(in_tensor, template).permute(0, 2, 1, 3).reshape(emb.shape)
+
+             emb = self.fuse_conv(torch.cat([emb, res_tensor], dim=1))
+
+         out = self.out_conv(emb)  # (Bhwt, 16, h//4, w//4)
+
+         out = out.reshape(B, hw, -1, self.mx_h, self.mx_w)
+         out = out.permute(0, 2, 1, 3, 4)
+         out = out.reshape(B, -1, hw, window_size)
+
+         if softmax:
+             out = out.view(B, out.shape[1], -1)
+             out = self.softmax(out)
+             out = out.reshape(B, -1, hw, window_size)
+
+         return out, res_emb
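A minimal usage sketch for PredictModel (illustrative sizes; window_size must equal mx_h * mx_w, and the sequence holds T - 1 inter-frame matrices when the model is built with T frames):

import torch
model = PredictModel(T=5, hidden_len=32, aft_seq_length=2, mx_h=8, mx_w=8)
matrix_seq = torch.randn(2, 4, 16, 64)     # (B, T - 1, hw, mx_h * mx_w)
out, res_emb = model(matrix_seq)           # out: (2, aft_seq_length, 16, 64)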
utilpack/phydnet_modules.py ADDED
@@ -0,0 +1,463 @@
+ import torch
+ import torch.nn as nn
+ from numpy import zeros, arange, newaxis
+ from numpy.linalg import inv
+ from scipy.special import factorial
+ from functools import reduce
+
+ __all__ = ['M2K', 'K2M']
+
+
+ class PhyCell_Cell(nn.Module):
+
+     def __init__(self, input_dim, F_hidden_dim, kernel_size, bias=1):
+         super(PhyCell_Cell, self).__init__()
+         self.input_dim = input_dim
+         self.F_hidden_dim = F_hidden_dim
+         self.kernel_size = kernel_size
+         self.padding = kernel_size[0] // 2, kernel_size[1] // 2
+         self.bias = bias
+
+         self.F = nn.Sequential()
+         self.F.add_module('conv1', nn.Conv2d(in_channels=input_dim, out_channels=F_hidden_dim,
+                                              kernel_size=self.kernel_size, stride=(1, 1), padding=self.padding))
+         self.F.add_module('bn1', nn.GroupNorm(7, F_hidden_dim))
+         self.F.add_module('conv2', nn.Conv2d(in_channels=F_hidden_dim, out_channels=input_dim,
+                                              kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)))
+
+         self.convgate = nn.Conv2d(in_channels=self.input_dim + self.input_dim,
+                                   out_channels=self.input_dim,
+                                   kernel_size=(3, 3),
+                                   padding=(1, 1), bias=self.bias)
+
+     def forward(self, x, hidden):  # x [batch_size, hidden_dim, height, width]
+         combined = torch.cat([x, hidden], dim=1)  # concatenate along channel axis
+         combined_conv = self.convgate(combined)
+         K = torch.sigmoid(combined_conv)
+         hidden_tilde = hidden + self.F(hidden)  # prediction
+         next_hidden = hidden_tilde + K * (x - hidden_tilde)  # correction, Hadamard product
+         return next_hidden
+
+
+ class PhyCell(nn.Module):
+
+     def __init__(self, input_shape, input_dim, F_hidden_dims, n_layers, kernel_size, device):
+         super(PhyCell, self).__init__()
+         self.input_shape = input_shape
+         self.input_dim = input_dim
+         self.F_hidden_dims = F_hidden_dims
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.H = []
+         self.device = device
+
+         cell_list = []
+         for i in range(0, self.n_layers):
+             cell_list.append(PhyCell_Cell(input_dim=input_dim,
+                                           F_hidden_dim=self.F_hidden_dims[i],
+                                           kernel_size=self.kernel_size))
+         self.cell_list = nn.ModuleList(cell_list)
+
+     def forward(self, input_, first_timestep=False):  # input_ [batch_size, 1, channels, width, height]
+         batch_size = input_.data.size()[0]
+         if first_timestep:
+             self.initHidden(batch_size)  # init hidden states at each forward start
+         for j, cell in enumerate(self.cell_list):
+             self.H[j] = self.H[j].to(input_.device)
+             if j == 0:  # bottom layer
+                 self.H[j] = cell(input_, self.H[j])
+             else:
+                 self.H[j] = cell(self.H[j - 1], self.H[j])
+         return self.H, self.H
+
+     def initHidden(self, batch_size):
+         self.H = []
+         for i in range(self.n_layers):
+             self.H.append(torch.zeros(
+                 batch_size, self.input_dim, self.input_shape[0], self.input_shape[1]).to(self.device))
+
+     def setHidden(self, H):
+         self.H = H
+
+
+ class PhyD_ConvLSTM_Cell(nn.Module):
+     def __init__(self, input_shape, input_dim, hidden_dim, kernel_size, bias=1):
+         """
+         input_shape: (int, int)
+             Height and width of input tensor as (height, width).
+         input_dim: int
+             Number of channels of input tensor.
+         hidden_dim: int
+             Number of channels of hidden state.
+         kernel_size: (int, int)
+             Size of the convolutional kernel.
+         bias: bool
+             Whether or not to add the bias.
+         """
+         super(PhyD_ConvLSTM_Cell, self).__init__()
+
+         self.height, self.width = input_shape
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.kernel_size = kernel_size
+         self.padding = kernel_size[0] // 2, kernel_size[1] // 2
+         self.bias = bias
+
+         self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
+                               out_channels=4 * self.hidden_dim,
+                               kernel_size=self.kernel_size,
+                               padding=self.padding, bias=self.bias)
+
+     # we implement an LSTM that processes only one timestep
+     def forward(self, x, hidden):  # x [batch, hidden_dim, width, height]
+         h_cur, c_cur = hidden
+
+         combined = torch.cat([x, h_cur], dim=1)  # concatenate along channel axis
+         combined_conv = self.conv(combined)
+         cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
+         i = torch.sigmoid(cc_i)
+         f = torch.sigmoid(cc_f)
+         o = torch.sigmoid(cc_o)
+         g = torch.tanh(cc_g)
+
+         c_next = f * c_cur + i * g
+         h_next = o * torch.tanh(c_next)
+         return h_next, c_next
+
+
+ class PhyD_ConvLSTM(nn.Module):
+
+     def __init__(self, input_shape, input_dim, hidden_dims, n_layers, kernel_size, device):
+         super(PhyD_ConvLSTM, self).__init__()
+         self.input_shape = input_shape
+         self.input_dim = input_dim
+         self.hidden_dims = hidden_dims
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.H, self.C = [], []
+         self.device = device
+
+         cell_list = []
+         for i in range(0, self.n_layers):
+             cur_input_dim = self.input_dim if i == 0 else self.hidden_dims[i - 1]
+             print('layer ', i, 'input dim ', cur_input_dim, ' hidden dim ', self.hidden_dims[i])
+             cell_list.append(PhyD_ConvLSTM_Cell(input_shape=self.input_shape,
+                                                 input_dim=cur_input_dim,
+                                                 hidden_dim=self.hidden_dims[i],
+                                                 kernel_size=self.kernel_size))
+         self.cell_list = nn.ModuleList(cell_list)
+
+     def forward(self, input_, first_timestep=False):  # input_ [batch_size, 1, channels, width, height]
+         batch_size = input_.data.size()[0]
+         if first_timestep:
+             self.initHidden(batch_size)  # init hidden states at each forward start
+         for j, cell in enumerate(self.cell_list):
+             self.H[j] = self.H[j].to(input_.device)
+             self.C[j] = self.C[j].to(input_.device)
+             if j == 0:  # bottom layer
+                 self.H[j], self.C[j] = cell(input_, (self.H[j], self.C[j]))
+             else:
+                 self.H[j], self.C[j] = cell(self.H[j - 1], (self.H[j], self.C[j]))
+         return (self.H, self.C), self.H  # (hidden, output)
+
+     def initHidden(self, batch_size):
+         self.H, self.C = [], []
+         for i in range(self.n_layers):
+             self.H.append(torch.zeros(
+                 batch_size, self.hidden_dims[i], self.input_shape[0], self.input_shape[1]).to(self.device))
+             self.C.append(torch.zeros(
+                 batch_size, self.hidden_dims[i], self.input_shape[0], self.input_shape[1]).to(self.device))
+
+     def setHidden(self, hidden):
+         H, C = hidden
+         self.H, self.C = H, C
+
+
+ class dcgan_conv(nn.Module):
+
+     def __init__(self, nin, nout, stride):
+         super(dcgan_conv, self).__init__()
+         self.main = nn.Sequential(
+             nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=(3, 3),
+                       stride=stride, padding=1),
+             nn.GroupNorm(16, nout),
+             nn.LeakyReLU(0.2, inplace=True),
+         )
+
+     def forward(self, input):
+         return self.main(input)
+
+
+ class dcgan_upconv(nn.Module):
+
+     def __init__(self, nin, nout, stride):
+         super(dcgan_upconv, self).__init__()
+         if stride == 2:
+             output_padding = 1
+         else:
+             output_padding = 0
+         self.main = nn.Sequential(
+             nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=(3, 3),
+                                stride=stride, padding=1, output_padding=output_padding),
+             nn.GroupNorm(16, nout),
+             nn.LeakyReLU(0.2, inplace=True),
+         )
+
+     def forward(self, input):
+         return self.main(input)
+
+
+ class encoder_E(nn.Module):
+
+     def __init__(self, nc=1, nf=32, patch_size=4):
+         super(encoder_E, self).__init__()
+         assert patch_size in [2, 4]
+         stride_2 = patch_size // 2
+         # input is (1) x 64 x 64
+         self.c1 = dcgan_conv(nc, nf, stride=2)  # (32) x 32 x 32
+         self.c2 = dcgan_conv(nf, nf, stride=1)  # (32) x 32 x 32
+         self.c3 = dcgan_conv(nf, 2 * nf, stride=stride_2)  # (64) x 16 x 16
+
+     def forward(self, input):
+         h1 = self.c1(input)
+         h2 = self.c2(h1)
+         h3 = self.c3(h2)
+         return h3
+
+
+ class decoder_D(nn.Module):
+
+     def __init__(self, nc=1, nf=32, patch_size=4):
+         super(decoder_D, self).__init__()
+         assert patch_size in [2, 4]
+         stride_2 = patch_size // 2
+         output_padding = 1 if stride_2 == 2 else 0
+         self.upc1 = dcgan_upconv(2 * nf, nf, stride=2)  # (32) x 32 x 32
+         self.upc2 = dcgan_upconv(nf, nf, stride=1)  # (32) x 32 x 32
+         self.upc3 = nn.ConvTranspose2d(in_channels=nf, out_channels=nc, kernel_size=(3, 3),
+                                        stride=stride_2, padding=1,
+                                        output_padding=output_padding)  # (nc) x 64 x 64
+
+     def forward(self, input):
+         d1 = self.upc1(input)
+         d2 = self.upc2(d1)
+         d3 = self.upc3(d2)
+         return d3
+
+
+ class encoder_specific(nn.Module):
+
+     def __init__(self, nc=64, nf=64):
+         super(encoder_specific, self).__init__()
+         self.c1 = dcgan_conv(nc, nf, stride=1)  # (64) x 16 x 16
+         self.c2 = dcgan_conv(nf, nf, stride=1)  # (64) x 16 x 16
+
+     def forward(self, input):
+         h1 = self.c1(input)
+         h2 = self.c2(h1)
+         return h2
+
+
+ class decoder_specific(nn.Module):
+
+     def __init__(self, nc=64, nf=64):
+         super(decoder_specific, self).__init__()
+         self.upc1 = dcgan_upconv(nf, nf, stride=1)  # (64) x 16 x 16
+         self.upc2 = dcgan_upconv(nf, nc, stride=1)  # (32) x 32 x 32
+
+     def forward(self, input):
+         d1 = self.upc1(input)
+         d2 = self.upc2(d1)
+         return d2
+
+
+ class PhyD_EncoderRNN(torch.nn.Module):
+
+     def __init__(self, phycell, convcell, in_channel=1, patch_size=4):
+         super(PhyD_EncoderRNN, self).__init__()
+         self.encoder_E = encoder_E(nc=in_channel, patch_size=patch_size)  # general encoder 64x64x1 -> 32x32x32
+         self.encoder_Ep = encoder_specific()  # specific image encoder 32x32x32 -> 16x16x64
+         self.encoder_Er = encoder_specific()
+         self.decoder_Dp = decoder_specific()  # specific image decoder 16x16x64 -> 32x32x32
+         self.decoder_Dr = decoder_specific()
+         self.decoder_D = decoder_D(nc=in_channel, patch_size=patch_size)  # general decoder 32x32x32 -> 64x64x1
+
+         self.phycell = phycell
+         self.convcell = convcell
+
+     def forward(self, input, first_timestep=False, decoding=False):
+         input = self.encoder_E(input)  # general encoder 64x64x1 -> 32x32x32
+
+         if decoding:  # input=None in decoding phase
+             input_phys = None
+         else:
+             input_phys = self.encoder_Ep(input)
+         input_conv = self.encoder_Er(input)
+
+         hidden1, output1 = self.phycell(input_phys, first_timestep)
+         hidden2, output2 = self.convcell(input_conv, first_timestep)
+
+         decoded_Dp = self.decoder_Dp(output1[-1])
+         decoded_Dr = self.decoder_Dr(output2[-1])
+
+         out_phys = torch.sigmoid(self.decoder_D(decoded_Dp))  # partial reconstructions for visualization
+         out_conv = torch.sigmoid(self.decoder_D(decoded_Dr))
+
+         concat = decoded_Dp + decoded_Dr
+         output_image = torch.sigmoid(self.decoder_D(concat))
+         return out_phys, hidden1, output_image, out_phys, out_conv
+
+
+ def _apply_axis_left_dot(x, mats):
+     assert x.dim() == len(mats) + 1
+     sizex = x.size()
+     k = x.dim() - 1
+     for i in range(k):
+         x = tensordot(mats[k - i - 1], x, dim=[1, k])
+     x = x.permute([k, ] + list(range(k))).contiguous()
+     x = x.view(sizex)
+     return x
+
+
+ def _apply_axis_right_dot(x, mats):
+     assert x.dim() == len(mats) + 1
+     sizex = x.size()
+     k = x.dim() - 1
+     x = x.permute(list(range(1, k + 1)) + [0, ])
+     for i in range(k):
+         x = tensordot(x, mats[i], dim=[0, 0])
+     x = x.contiguous()
+     x = x.view(sizex)
+     return x
+
+
+ class _MK(nn.Module):
+     def __init__(self, shape):
+         super(_MK, self).__init__()
+         self._size = torch.Size(shape)
+         self._dim = len(shape)
+         M = []
+         invM = []
+         assert len(shape) > 0
+         j = 0
+         for l in shape:
+             M.append(zeros((l, l)))
+             for i in range(l):
+                 M[-1][i] = ((arange(l) - (l - 1) // 2) ** i) / factorial(i)
+             invM.append(inv(M[-1]))
+             self.register_buffer('_M' + str(j), torch.from_numpy(M[-1]))
+             self.register_buffer('_invM' + str(j), torch.from_numpy(invM[-1]))
+             j += 1
+
+     @property
+     def M(self):
+         return list(self._buffers['_M' + str(j)] for j in range(self.dim()))
+
+     @property
+     def invM(self):
+         return list(self._buffers['_invM' + str(j)] for j in range(self.dim()))
+
+     def size(self):
+         return self._size
+
+     def dim(self):
+         return self._dim
+
+     def _packdim(self, x):
+         assert x.dim() >= self.dim()
+         if x.dim() == self.dim():
+             x = x[newaxis, :]
+         x = x.contiguous()
+         x = x.view([-1, ] + list(x.size()[-self.dim():]))
+         return x
+
+     def forward(self):
+         pass
+
+
+ class M2K(_MK):
+     """
+     convert moment matrix to convolution kernel
+     Arguments:
+         shape (tuple of int): kernel shape
+     Usage:
+         m2k = M2K([5,5])
+         m = torch.randn(5,5,dtype=torch.float64)
+         k = m2k(m)
+     """
+     def __init__(self, shape):
+         super(M2K, self).__init__(shape)
+
+     def forward(self, m):
+         """
+         m (Tensor): torch.size=[...,*self.shape]
+         """
+         sizem = m.size()
+         m = self._packdim(m)
+         m = _apply_axis_left_dot(m, self.invM)
+         m = m.view(sizem)
+         return m
+
+
+ class K2M(_MK):
+     """
+     convert convolution kernel to moment matrix
+     Arguments:
+         shape (tuple of int): kernel shape
+     Usage:
+         k2m = K2M([5,5])
+         k = torch.randn(5,5,dtype=torch.float64)
+         m = k2m(k)
+     """
+     def __init__(self, shape):
+         super(K2M, self).__init__(shape)
+
+     def forward(self, k):
+         """
+         k (Tensor): torch.size=[...,*self.shape]
+         """
+         sizek = k.size()
+         k = self._packdim(k)
+         k = _apply_axis_left_dot(k, self.M)
+         k = k.view(sizek)
+         return k
+
+
+ def tensordot(a, b, dim):
+     """
+     tensordot in PyTorch; see numpy.tensordot
+     """
+     l = lambda x, y: x * y
+     if isinstance(dim, int):
+         a = a.contiguous()
+         b = b.contiguous()
+         sizea = a.size()
+         sizeb = b.size()
+         sizea0 = sizea[:-dim]
+         sizea1 = sizea[-dim:]
+         sizeb0 = sizeb[:dim]
+         sizeb1 = sizeb[dim:]
+         N = reduce(l, sizea1, 1)
+         assert reduce(l, sizeb0, 1) == N
+     else:
+         adims = dim[0]
+         bdims = dim[1]
+         adims = [adims, ] if isinstance(adims, int) else adims
+         bdims = [bdims, ] if isinstance(bdims, int) else bdims
+         adims_ = set(range(a.dim())).difference(set(adims))
+         adims_ = list(adims_)
+         adims_.sort()
+         perma = adims_ + adims
+         bdims_ = set(range(b.dim())).difference(set(bdims))
+         bdims_ = list(bdims_)
+         bdims_.sort()
+         permb = bdims + bdims_
+         a = a.permute(*perma).contiguous()
+         b = b.permute(*permb).contiguous()
+
+         sizea = a.size()
+         sizeb = b.size()
+         sizea0 = sizea[:-len(adims)]
+         sizea1 = sizea[-len(adims):]
+         sizeb0 = sizeb[:len(bdims)]
+         sizeb1 = sizeb[len(bdims):]
+         N = reduce(l, sizea1, 1)
+         assert reduce(l, sizeb0, 1) == N
+     a = a.view([-1, N])
+     b = b.view([N, -1])
+     c = a @ b
+     return c.view(sizea0 + sizeb1)
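A minimal usage sketch for PhyCell (illustrative; F_hidden_dims entries must be divisible by the GroupNorm group count of 7):

import torch
device = torch.device('cpu')
phycell = PhyCell(input_shape=(16, 16), input_dim=64, F_hidden_dims=[49],
                  n_layers=1, kernel_size=(7, 7), device=device)
x = torch.randn(2, 64, 16, 16)
hidden, output = phycell(x, first_timestep=True)   # both are the per-layer H list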
utilpack/predrnn_modules.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ import torch.nn as nn
+
+
+ class SpatioTemporalLSTMCell(nn.Module):
+
+     def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
+         super(SpatioTemporalLSTMCell, self).__init__()
+
+         self.num_hidden = num_hidden
+         self.padding = filter_size // 2
+         self._forget_bias = 1.0
+         if layer_norm:
+             self.conv_x = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 7, height, width])
+             )
+             self.conv_h = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+             self.conv_m = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 3, height, width])
+             )
+             self.conv_o = nn.Sequential(
+                 nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden, height, width])
+             )
+         else:
+             self.conv_x = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_h = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_m = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_o = nn.Sequential(
+                 nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+         self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
+                                    stride=1, padding=0, bias=False)
+
+     def forward(self, x_t, h_t, c_t, m_t):
+         x_concat = self.conv_x(x_t)
+         h_concat = self.conv_h(h_t)
+         m_concat = self.conv_m(m_t)
+         i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = \
+             torch.split(x_concat, self.num_hidden, dim=1)
+         i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
+         i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)
+
+         i_t = torch.sigmoid(i_x + i_h)
+         f_t = torch.sigmoid(f_x + f_h + self._forget_bias)
+         g_t = torch.tanh(g_x + g_h)
+
+         c_new = f_t * c_t + i_t * g_t
+
+         i_t_prime = torch.sigmoid(i_x_prime + i_m)
+         f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)
+         g_t_prime = torch.tanh(g_x_prime + g_m)
+
+         m_new = f_t_prime * m_t + i_t_prime * g_t_prime
+
+         mem = torch.cat((c_new, m_new), 1)
+         o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))
+         h_new = o_t * torch.tanh(self.conv_last(mem))
+
+         return h_new, c_new, m_new
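A minimal usage sketch for SpatioTemporalLSTMCell (illustrative; h and c are the per-layer states, m is the spatiotemporal memory passed across layers):

import torch
cell = SpatioTemporalLSTMCell(in_channel=64, num_hidden=64, height=16, width=16,
                              filter_size=5, stride=1, layer_norm=True)
x = torch.randn(2, 64, 16, 16)
h = c = m = torch.zeros(2, 64, 16, 16)
h_new, c_new, m_new = cell(x, h, c, m)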
utilpack/predrnnpp_modules.py ADDED
@@ -0,0 +1,169 @@
+ import torch
+ import torch.nn as nn
+
+
+ class CausalLSTMCell(nn.Module):
+
+     def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
+         super(CausalLSTMCell, self).__init__()
+
+         self.num_hidden = num_hidden
+         self.padding = filter_size // 2
+         self._forget_bias = 1.0
+         if layer_norm:
+             self.conv_x = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 7, height, width])
+             )
+             self.conv_h = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+             self.conv_c = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 3, height, width])
+             )
+             self.conv_m = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 3, height, width])
+             )
+             self.conv_o = nn.Sequential(
+                 nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden, height, width])
+             )
+             self.conv_c2m = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+             self.conv_om = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden, height, width])
+             )
+         else:
+             self.conv_x = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_h = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_c = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_m = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_o = nn.Sequential(
+                 nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_c2m = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_om = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+         self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
+                                    stride=1, padding=0, bias=False)
+
+     def forward(self, x_t, h_t, c_t, m_t):
+         x_concat = self.conv_x(x_t)
+         h_concat = self.conv_h(h_t)
+         c_concat = self.conv_c(c_t)
+         m_concat = self.conv_m(m_t)
+         i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = \
+             torch.split(x_concat, self.num_hidden, dim=1)
+         i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
+         i_m, f_m, m_m = torch.split(m_concat, self.num_hidden, dim=1)
+         i_c, f_c, g_c = torch.split(c_concat, self.num_hidden, dim=1)
+
+         i_t = torch.sigmoid(i_x + i_h + i_c)
+         f_t = torch.sigmoid(f_x + f_h + f_c + self._forget_bias)
+         g_t = torch.tanh(g_x + g_h + g_c)
+
+         c_new = f_t * c_t + i_t * g_t
+
+         c2m = self.conv_c2m(c_new)
+         i_c, g_c, f_c, o_c = torch.split(c2m, self.num_hidden, dim=1)
+
+         i_t_prime = torch.sigmoid(i_x_prime + i_m + i_c)
+         f_t_prime = torch.sigmoid(f_x_prime + f_m + f_c + self._forget_bias)
+         g_t_prime = torch.tanh(g_x_prime + g_c)
+
+         m_new = f_t_prime * torch.tanh(m_m) + i_t_prime * g_t_prime
+         o_m = self.conv_om(m_new)
+
+         o_t = torch.tanh(o_x + o_h + o_c + o_m)
+         mem = torch.cat((c_new, m_new), 1)
+         h_new = o_t * torch.tanh(self.conv_last(mem))
+
+         return h_new, c_new, m_new
+
+
+ class GHU(nn.Module):
+     def __init__(self, in_channel, num_hidden, height, width, filter_size,
+                  stride, layer_norm, initializer=0.001):
+         super(GHU, self).__init__()
+
+         self.filter_size = filter_size
+         self.padding = filter_size // 2
+         self.num_hidden = num_hidden
+         self.layer_norm = layer_norm
+
+         if layer_norm:
+             self.z_concat = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 2, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 # the conv outputs num_hidden * 2 channels, so the norm shape must match
+                 nn.LayerNorm([num_hidden * 2, height, width])
+             )
+             self.x_concat = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 2, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 2, height, width])
+             )
+         else:
+             self.z_concat = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 2, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.x_concat = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 2, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+
+         if initializer != -1:
+             self.initializer = initializer
+             self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Conv2d):
+             nn.init.uniform_(m.weight, -self.initializer, self.initializer)
+
+     def _init_state(self, inputs):
+         return torch.zeros_like(inputs)
+
+     def forward(self, x, z):
+         if z is None:
+             z = self._init_state(x)
+         z_concat = self.z_concat(z)
+         x_concat = self.x_concat(x)
+
+         gates = x_concat + z_concat
+         p, u = torch.split(gates, self.num_hidden, dim=1)
+         p = torch.tanh(p)
+         u = torch.sigmoid(u)
+         z_new = u * p + (1 - u) * z
+         return z_new
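
Note: a minimal sketch of how CausalLSTMCell and GHU compose in one PredRNN++ step (all shapes are illustrative; in the full model the GHU sits between the first two cell layers):

    import torch
    from utilpack.predrnnpp_modules import CausalLSTMCell, GHU

    cell = CausalLSTMCell(in_channel=64, num_hidden=64, height=16, width=16,
                          filter_size=5, stride=1, layer_norm=True)
    ghu = GHU(in_channel=64, num_hidden=64, height=16, width=16,
              filter_size=5, stride=1, layer_norm=True)

    x = torch.randn(2, 64, 16, 16)
    h = c = m = torch.zeros(2, 64, 16, 16)
    z = ghu(x, None)            # gradient highway state, zero-initialized on the first call
    h, c, m = cell(z, h, c, m)  # causal LSTM update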
utilpack/predrnnv2_modules.py ADDED
@@ -0,0 +1,82 @@
+ import torch
+ import torch.nn as nn
+
+
+ class SpatioTemporalLSTMCellv2(nn.Module):
+
+     def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
+         super(SpatioTemporalLSTMCellv2, self).__init__()
+
+         self.num_hidden = num_hidden
+         self.padding = filter_size // 2
+         self._forget_bias = 1.0
+         if layer_norm:
+             self.conv_x = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 7, height, width])
+             )
+             self.conv_h = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 4, height, width])
+             )
+             self.conv_m = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden * 3, height, width])
+             )
+             self.conv_o = nn.Sequential(
+                 nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+                 nn.LayerNorm([num_hidden, height, width])
+             )
+         else:
+             self.conv_x = nn.Sequential(
+                 nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_h = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_m = nn.Sequential(
+                 nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+             self.conv_o = nn.Sequential(
+                 nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
+                           stride=stride, padding=self.padding, bias=False),
+             )
+         self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
+                                    stride=1, padding=0, bias=False)
+
+     def forward(self, x_t, h_t, c_t, m_t):
+         x_concat = self.conv_x(x_t)
+         h_concat = self.conv_h(h_t)
+         m_concat = self.conv_m(m_t)
+         i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = \
+             torch.split(x_concat, self.num_hidden, dim=1)
+         i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
+         i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)
+
+         i_t = torch.sigmoid(i_x + i_h)
+         f_t = torch.sigmoid(f_x + f_h + self._forget_bias)
+         g_t = torch.tanh(g_x + g_h)
+
+         delta_c = i_t * g_t
+         c_new = f_t * c_t + delta_c
+
+         i_t_prime = torch.sigmoid(i_x_prime + i_m)
+         f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)
+         g_t_prime = torch.tanh(g_x_prime + g_m)
+
+         delta_m = i_t_prime * g_t_prime
+         m_new = f_t_prime * m_t + delta_m
+
+         mem = torch.cat((c_new, m_new), 1)
+         o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))
+         h_new = o_t * torch.tanh(self.conv_last(mem))
+
+         return h_new, c_new, m_new, delta_c, delta_m
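
Note: unlike the v1 cell, this cell also returns the memory increments delta_c and delta_m, which PredRNNv2 regularizes so the two memories stay decoupled. A hedged sketch of that decoupling term (cosine-similarity form, as in the PredRNNv2 paper; the exact reduction used by this repo's training script is not shown here):

    import torch

    def decouple_loss(delta_c, delta_m, eps=1e-6):
        # penalize overlap between the temporal (delta_c) and the
        # spatiotemporal (delta_m) memory increments
        dc = delta_c.flatten(2)  # (B, C, H*W)
        dm = delta_m.flatten(2)
        cos = (dc * dm).sum(-1) / (dc.norm(dim=-1) * dm.norm(dim=-1) + eps)
        return torch.abs(cos).mean()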
utilpack/simvp_modules.py ADDED
@@ -0,0 +1,586 @@
+ import math
+ import torch
+ import torch.nn as nn
+
+ from timm.layers import DropPath, trunc_normal_
+ from timm.models.convnext import ConvNeXtBlock
+ from timm.models.mlp_mixer import MixerBlock
+ from timm.models.swin_transformer import SwinTransformerBlock, window_partition, window_reverse
+ from timm.models.vision_transformer import Block as ViTBlock
+
+ from .layers import (HorBlock, ChannelAggregationFFN, MultiOrderGatedAggregation,
+                      PoolFormerBlock, CBlock, SABlock, MixMlp, VANBlock)
+
+
+ class BasicConv2d(nn.Module):
+
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size=3,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  upsampling=False,
+                  act_norm=False,
+                  act_inplace=True):
+         super(BasicConv2d, self).__init__()
+         self.act_norm = act_norm
+         if upsampling is True:
+             self.conv = nn.Sequential(*[
+                 nn.Conv2d(in_channels, out_channels * 4, kernel_size=kernel_size,
+                           stride=1, padding=padding, dilation=dilation),
+                 nn.PixelShuffle(2)
+             ])
+         else:
+             self.conv = nn.Conv2d(
+                 in_channels, out_channels, kernel_size=kernel_size,
+                 stride=stride, padding=padding, dilation=dilation)
+
+         self.norm = nn.GroupNorm(2, out_channels)
+         self.act = nn.SiLU(inplace=act_inplace)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Conv2d):
+             trunc_normal_(m.weight, std=.02)
+             nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         y = self.conv(x)
+         if self.act_norm:
+             y = self.act(self.norm(y))
+         return y
+
+
+ class ConvSC(nn.Module):
+
+     def __init__(self,
+                  C_in,
+                  C_out,
+                  kernel_size=3,
+                  downsampling=False,
+                  upsampling=False,
+                  act_norm=True,
+                  act_inplace=True):
+         super(ConvSC, self).__init__()
+
+         stride = 2 if downsampling is True else 1
+         padding = (kernel_size - stride + 1) // 2
+
+         self.conv = BasicConv2d(C_in, C_out, kernel_size=kernel_size, stride=stride,
+                                 upsampling=upsampling, padding=padding,
+                                 act_norm=act_norm, act_inplace=act_inplace)
+
+     def forward(self, x):
+         y = self.conv(x)
+         return y
+
+
+ class GroupConv2d(nn.Module):
+
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size=3,
+                  stride=1,
+                  padding=0,
+                  groups=1,
+                  act_norm=False,
+                  act_inplace=True):
+         super(GroupConv2d, self).__init__()
+         self.act_norm = act_norm
+         if in_channels % groups != 0:
+             groups = 1
+         self.conv = nn.Conv2d(
+             in_channels, out_channels, kernel_size=kernel_size,
+             stride=stride, padding=padding, groups=groups)
+         self.norm = nn.GroupNorm(groups, out_channels)
+         self.activate = nn.LeakyReLU(0.2, inplace=act_inplace)
+
+     def forward(self, x):
+         y = self.conv(x)
+         if self.act_norm:
+             y = self.activate(self.norm(y))
+         return y
+
+
+ class gInception_ST(nn.Module):
+     """An IncepU block for SimVP"""
+
+     def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8):
+         super(gInception_ST, self).__init__()
+         self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
+
+         layers = []
+         for ker in incep_ker:
+             layers.append(GroupConv2d(
+                 C_hid, C_out, kernel_size=ker, stride=1,
+                 padding=ker // 2, groups=groups, act_norm=True))
+         self.layers = nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         y = 0
+         for layer in self.layers:
+             y += layer(x)
+         return y
+
+
+ class AttentionModule(nn.Module):
+     """Large Kernel Attention for SimVP"""
+
+     def __init__(self, dim, kernel_size, dilation=3):
+         super().__init__()
+         d_k = 2 * dilation - 1
+         d_p = (d_k - 1) // 2
+         dd_k = kernel_size // dilation + ((kernel_size // dilation) % 2 - 1)
+         dd_p = (dilation * (dd_k - 1) // 2)
+
+         self.conv0 = nn.Conv2d(dim, dim, d_k, padding=d_p, groups=dim)
+         self.conv_spatial = nn.Conv2d(
+             dim, dim, dd_k, stride=1, padding=dd_p, groups=dim, dilation=dilation)
+         self.conv1 = nn.Conv2d(dim, 2 * dim, 1)
+
+     def forward(self, x):
+         attn = self.conv0(x)            # depth-wise conv
+         attn = self.conv_spatial(attn)  # depth-wise dilation convolution
+
+         f_g = self.conv1(attn)
+         split_dim = f_g.shape[1] // 2
+         f_x, g_x = torch.split(f_g, split_dim, dim=1)
+         return torch.sigmoid(g_x) * f_x
+
+
+ class SpatialAttention(nn.Module):
+     """A Spatial Attention block for SimVP"""
+
+     def __init__(self, d_model, kernel_size=21, attn_shortcut=True):
+         super().__init__()
+
+         self.proj_1 = nn.Conv2d(d_model, d_model, 1)  # 1x1 conv
+         self.activation = nn.GELU()
+         self.spatial_gating_unit = AttentionModule(d_model, kernel_size)
+         self.proj_2 = nn.Conv2d(d_model, d_model, 1)  # 1x1 conv
+         self.attn_shortcut = attn_shortcut
+
+     def forward(self, x):
+         if self.attn_shortcut:
+             shortcut = x.clone()
+         x = self.proj_1(x)
+         x = self.activation(x)
+         x = self.spatial_gating_unit(x)
+         x = self.proj_2(x)
+         if self.attn_shortcut:
+             x = x + shortcut
+         return x
+
+
+ class GASubBlock(nn.Module):
+     """A GABlock (gSTA) for SimVP"""
+
+     def __init__(self, dim, kernel_size=21, mlp_ratio=4.,
+                  drop=0., drop_path=0.1, init_value=1e-2, act_layer=nn.GELU):
+         super().__init__()
+         self.norm1 = nn.BatchNorm2d(dim)
+         self.attn = SpatialAttention(dim, kernel_size)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+         self.norm2 = nn.BatchNorm2d(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = MixMlp(
+             in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+         self.layer_scale_1 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
+         self.layer_scale_2 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'layer_scale_1', 'layer_scale_2'}
+
+     def forward(self, x):
+         x = x + self.drop_path(
+             self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
+         x = x + self.drop_path(
+             self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
+         return x
+
+
+ class ConvMixerSubBlock(nn.Module):
+     """A block of ConvMixer."""
+
+     def __init__(self, dim, kernel_size=9, activation=nn.GELU):
+         super().__init__()
+         # spatial mixing
+         self.conv_dw = nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same")
+         self.act_1 = activation()
+         self.norm_1 = nn.BatchNorm2d(dim)
+         # channel mixing
+         self.conv_pw = nn.Conv2d(dim, dim, kernel_size=1)
+         self.act_2 = activation()
+         self.norm_2 = nn.BatchNorm2d(dim)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.BatchNorm2d):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return dict()
+
+     def forward(self, x):
+         x = x + self.norm_1(self.act_1(self.conv_dw(x)))
+         x = self.norm_2(self.act_2(self.conv_pw(x)))
+         return x
+
+
+ class ConvNeXtSubBlock(ConvNeXtBlock):
+     """A block of ConvNeXt."""
+
+     def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
+         super().__init__(dim, mlp_ratio=mlp_ratio,
+                          drop_path=drop_path, ls_init_value=1e-6, conv_mlp=True)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'gamma'}
+
+     def forward(self, x):
+         x = x + self.drop_path(
+             self.gamma.reshape(1, -1, 1, 1) * self.mlp(self.norm(self.conv_dw(x))))
+         return x
+
+
+ class HorNetSubBlock(HorBlock):
+     """A block of HorNet."""
+
+     def __init__(self, dim, mlp_ratio=4., drop_path=0.1, init_value=1e-6):
+         super().__init__(dim, mlp_ratio=mlp_ratio, drop_path=drop_path, init_value=init_value)
+         self.apply(self._init_weights)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'gamma1', 'gamma2'}
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+
+ class MLPMixerSubBlock(MixerBlock):
+     """A block of MLP-Mixer."""
+
+     def __init__(self, dim, input_resolution=None, mlp_ratio=4., drop=0., drop_path=0.1):
+         seq_len = input_resolution[0] * input_resolution[1]
+         super().__init__(dim, seq_len=seq_len,
+                          mlp_ratio=(0.5, mlp_ratio), drop_path=drop_path, drop=drop)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return dict()
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         x = x.flatten(2).transpose(1, 2)
+         x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
+         x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
+         return x.reshape(B, H, W, C).permute(0, 3, 1, 2)
+
+
+ class MogaSubBlock(nn.Module):
+     """A block of MogaNet."""
+
+     def __init__(self, embed_dims, mlp_ratio=4., drop_rate=0., drop_path_rate=0., init_value=1e-5,
+                  attn_dw_dilation=[1, 2, 3], attn_channel_split=[1, 3, 4]):
+         super(MogaSubBlock, self).__init__()
+         self.out_channels = embed_dims
+         # spatial attention
+         self.norm1 = nn.BatchNorm2d(embed_dims)
+         self.attn = MultiOrderGatedAggregation(
+             embed_dims, attn_dw_dilation=attn_dw_dilation, attn_channel_split=attn_channel_split)
+         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+         # channel MLP
+         self.norm2 = nn.BatchNorm2d(embed_dims)
+         mlp_hidden_dims = int(embed_dims * mlp_ratio)
+         self.mlp = ChannelAggregationFFN(
+             embed_dims=embed_dims, mlp_hidden_dims=mlp_hidden_dims, ffn_drop=drop_rate)
+         # init layer scale
+         self.layer_scale_1 = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
+         self.layer_scale_2 = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'layer_scale_1', 'layer_scale_2', 'sigma'}
+
+     def forward(self, x):
+         x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
+         x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
+         return x
+
+
+ class PoolFormerSubBlock(PoolFormerBlock):
+     """A block of PoolFormer."""
+
+     def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
+         super().__init__(dim, pool_size=3, mlp_ratio=mlp_ratio, drop_path=drop_path,
+                          drop=drop, init_value=1e-5)
+         self.apply(self._init_weights)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'layer_scale_1', 'layer_scale_2'}
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+
+ class SwinSubBlock(SwinTransformerBlock):
+     """A block of Swin Transformer."""
+
+     def __init__(self, dim, input_resolution=None, layer_i=0, mlp_ratio=4., drop=0., drop_path=0.1):
+         window_size = 7 if input_resolution[0] % 7 == 0 else max(4, input_resolution[0] // 16)
+         window_size = min(8, window_size)
+         shift_size = 0 if (layer_i % 2 == 0) else window_size // 2
+         super().__init__(dim, input_resolution, num_heads=8, window_size=window_size,
+                          shift_size=shift_size, mlp_ratio=mlp_ratio,
+                          drop_path=drop_path, attn_drop=drop, proj_drop=drop, qkv_bias=True)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {}
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         x = x.flatten(2).transpose(1, 2)
+         x = self.norm1(x)
+         x = x.view(B, H, W, C)
+         x = super().forward(x)
+
+         return x.reshape(B, H, W, C).permute(0, 3, 1, 2)
+
+
+ def UniformerSubBlock(embed_dims, mlp_ratio=4., drop=0., drop_path=0.,
+                       init_value=1e-6, block_type='Conv'):
+     """Build a block of Uniformer."""
+
+     assert block_type in ['Conv', 'MHSA']
+     if block_type == 'Conv':
+         return CBlock(dim=embed_dims, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
+     else:
+         return SABlock(dim=embed_dims, num_heads=8, mlp_ratio=mlp_ratio, qkv_bias=True,
+                        drop=drop, drop_path=drop_path, init_value=init_value)
+
+
+ class VANSubBlock(VANBlock):
+     """A block of VAN."""
+
+     def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., init_value=1e-2, act_layer=nn.GELU):
+         super().__init__(dim=dim, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path,
+                          init_value=init_value, act_layer=act_layer)
+         self.apply(self._init_weights)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'layer_scale_1', 'layer_scale_2'}
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+
+ class ViTSubBlock(ViTBlock):
+     """A block of Vision Transformer."""
+
+     def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
+         super().__init__(dim=dim, num_heads=8, mlp_ratio=mlp_ratio, qkv_bias=True,
+                          attn_drop=drop, proj_drop=0, drop_path=drop_path,
+                          act_layer=nn.GELU, norm_layer=nn.LayerNorm)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {}
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         x = x.flatten(2).transpose(1, 2)
+         x = x + self.drop_path(self.attn(self.norm1(x)))
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x.reshape(B, H, W, C).permute(0, 3, 1, 2)
+
+
+ class TemporalAttention(nn.Module):
+     """A Temporal Attention block for the Temporal Attention Unit"""
+
+     def __init__(self, d_model, kernel_size=21, attn_shortcut=True):
+         super().__init__()
+
+         self.proj_1 = nn.Conv2d(d_model, d_model, 1)  # 1x1 conv
+         self.activation = nn.GELU()
+         self.spatial_gating_unit = TemporalAttentionModule(d_model, kernel_size)
+         self.proj_2 = nn.Conv2d(d_model, d_model, 1)  # 1x1 conv
+         self.attn_shortcut = attn_shortcut
+
+     def forward(self, x):
+         if self.attn_shortcut:
+             shortcut = x.clone()
+         x = self.proj_1(x)
+         x = self.activation(x)
+         x = self.spatial_gating_unit(x)
+         x = self.proj_2(x)
+         if self.attn_shortcut:
+             x = x + shortcut
+         return x
+
+
+ class TemporalAttentionModule(nn.Module):
+     """Large Kernel Attention with a squeeze-and-excitation branch, for TAU"""
+
+     def __init__(self, dim, kernel_size, dilation=3, reduction=16):
+         super().__init__()
+         d_k = 2 * dilation - 1
+         d_p = (d_k - 1) // 2
+         dd_k = kernel_size // dilation + ((kernel_size // dilation) % 2 - 1)
+         dd_p = (dilation * (dd_k - 1) // 2)
+
+         self.conv0 = nn.Conv2d(dim, dim, d_k, padding=d_p, groups=dim)
+         self.conv_spatial = nn.Conv2d(
+             dim, dim, dd_k, stride=1, padding=dd_p, groups=dim, dilation=dilation)
+         self.conv1 = nn.Conv2d(dim, dim, 1)
+
+         self.reduction = max(dim // reduction, 4)
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.fc = nn.Sequential(
+             nn.Linear(dim, dim // self.reduction, bias=False),  # reduction
+             nn.ReLU(True),
+             nn.Linear(dim // self.reduction, dim, bias=False),  # expansion
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         u = x.clone()
+         attn = self.conv0(x)            # depth-wise conv
+         attn = self.conv_spatial(attn)  # depth-wise dilation convolution
+         f_x = self.conv1(attn)          # 1x1 conv
+         # append an SE operation
+         b, c, _, _ = x.size()
+         se_atten = self.avg_pool(x).view(b, c)
+         se_atten = self.fc(se_atten).view(b, c, 1, 1)
+         return se_atten * f_x * u
+
+
+ class TAUSubBlock(GASubBlock):
+     """A TAUBlock (tau) for the Temporal Attention Unit"""
+
+     def __init__(self, dim, kernel_size=21, mlp_ratio=4.,
+                  drop=0., drop_path=0.1, init_value=1e-2, act_layer=nn.GELU):
+         super().__init__(dim=dim, kernel_size=kernel_size, mlp_ratio=mlp_ratio,
+                          drop=drop, drop_path=drop_path, init_value=init_value,
+                          act_layer=act_layer)
+
+         self.attn = TemporalAttention(dim, kernel_size)
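
Note: every *SubBlock above is a shape-preserving transform over (B, C, H, W) feature maps, so SimVP's translator can swap them freely. A quick sanity check (a sketch; the 2x64x16x16 shape is illustrative, and the .layers imports must resolve within this package):

    import torch
    from utilpack.simvp_modules import GASubBlock, TAUSubBlock

    x = torch.randn(2, 64, 16, 16)
    assert GASubBlock(dim=64, kernel_size=21)(x).shape == x.shape
    assert TAUSubBlock(dim=64, kernel_size=21)(x).shape == x.shape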
utilpack/swinlstm_modules.py ADDED
@@ -0,0 +1,317 @@
+ import torch
+ import torch.nn as nn
+ from timm.models.swin_transformer import SwinTransformerBlock, window_reverse, PatchEmbed, PatchMerging, window_partition
+ from timm.layers import to_2tuple
+
+
+ class SwinLSTMCell(nn.Module):
+     def __init__(self, dim, input_resolution, num_heads, window_size, depth,
+                  mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., norm_layer=nn.LayerNorm, flag=None):
+         """
+         Args:
+             flag: 0 for UpSample, 1 for DownSample, 2 for STconvert
+         """
+         super(SwinLSTMCell, self).__init__()
+
+         self.STBs = nn.ModuleList(
+             STB(i, dim=dim, input_resolution=input_resolution, depth=depth,
+                 num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
+                 qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
+                 drop_path=drop_path, norm_layer=norm_layer, flag=flag)
+             for i in range(depth))
+
+     def forward(self, xt, hidden_states):
+         """
+         Args:
+             xt: input at time step t
+             hidden_states: [hx, cx], the hidden states from time step t-1
+         """
+         if hidden_states is None:
+             B, L, C = xt.shape
+             hx = torch.zeros(B, L, C).to(xt.device)
+             cx = torch.zeros(B, L, C).to(xt.device)
+         else:
+             hx, cx = hidden_states
+
+         outputs = []
+         for index, layer in enumerate(self.STBs):
+             if index == 0:
+                 x = layer(xt, hx)
+             elif index % 2 == 0:
+                 x = layer(outputs[-1], xt)
+             else:
+                 x = layer(outputs[-1], None)
+             outputs.append(x)
+
+         o_t = outputs[-1]
+         Ft = torch.sigmoid(o_t)
+         cell = torch.tanh(o_t)
+
+         Ct = Ft * (cx + cell)
+         Ht = Ft * torch.tanh(Ct)
+
+         return Ht, (Ht, Ct)
+
+
+ class STB(SwinTransformerBlock):
+     def __init__(self, index, dim, input_resolution, depth, num_heads, window_size,
+                  mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., norm_layer=nn.LayerNorm, flag=None):
+         if flag == 0:
+             drop_path = drop_path[depth - index - 1]
+         elif flag == 1:
+             drop_path = drop_path[index]
+         # flag == 2: drop_path is already a scalar and is used as-is
+         super(STB, self).__init__(dim=dim, input_resolution=input_resolution,
+                                   num_heads=num_heads, window_size=window_size,
+                                   shift_size=0 if (index % 2 == 0) else window_size // 2,
+                                   mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
+                                   drop=drop, attn_drop=attn_drop,
+                                   drop_path=drop_path,
+                                   norm_layer=norm_layer)
+         self.red = nn.Linear(2 * dim, dim)
+
+     def forward(self, x, hx=None):
+         H, W = self.input_resolution
+         B, L, C = x.shape
+         assert L == H * W, "input feature has wrong size"
+
+         shortcut = x
+         x = self.norm1(x)
+         if hx is not None:
+             hx = self.norm1(hx)
+             x = torch.cat((x, hx), -1)
+             x = self.red(x)
+         x = x.view(B, H, W, C)
+
+         # cyclic shift
+         if self.shift_size > 0:
+             shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+         else:
+             shifted_x = x
+
+         # partition windows
+         x_windows = window_partition(shifted_x, self.window_size)  # num_win*B, window_size, window_size, C
+         x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # num_win*B, window_size*window_size, C
+
+         # W-MSA/SW-MSA
+         attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_win*B, window_size*window_size, C
+
+         # merge windows
+         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+         shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+
+         # reverse cyclic shift
+         if self.shift_size > 0:
+             x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+         else:
+             x = shifted_x
+         x = x.view(B, H * W, C)
+
+         # FFN
+         x = shortcut + self.drop_path(x)
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+         return x
+
+
+ class PatchInflated(nn.Module):
+     r""" Tensor to Patch Inflating
+
+     Args:
+         in_chans (int): Number of input image channels.
+         embed_dim (int): Number of linear projection output channels.
+         input_resolution (tuple[int]): Input resolution.
+     """
+
+     def __init__(self, in_chans, embed_dim, input_resolution, stride=2, padding=1, output_padding=1):
+         super(PatchInflated, self).__init__()
+
+         stride = to_2tuple(stride)
+         padding = to_2tuple(padding)
+         output_padding = to_2tuple(output_padding)
+         self.input_resolution = input_resolution
+
+         self.Conv = nn.ConvTranspose2d(in_channels=embed_dim, out_channels=in_chans, kernel_size=(3, 3),
+                                        stride=stride, padding=padding, output_padding=output_padding)
+
+     def forward(self, x):
+         H, W = self.input_resolution
+         B, L, C = x.shape
+         assert L == H * W, "input feature has wrong size"
+         assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+         x = x.view(B, H, W, C)
+         x = x.permute(0, 3, 1, 2)
+         x = self.Conv(x)
+
+         return x
+
+
+ class PatchExpanding(nn.Module):
+     r""" Patch Expanding Layer.
+
+     Args:
+         input_resolution (tuple[int]): Resolution of input feature.
+         dim (int): Number of input channels.
+         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+     """
+
+     def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
+         super(PatchExpanding, self).__init__()
+         self.input_resolution = input_resolution
+         self.dim = dim
+         self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
+         self.norm = norm_layer(dim // dim_scale)
+
+     def forward(self, x):
+         H, W = self.input_resolution
+         x = self.expand(x)
+         B, L, C = x.shape
+         assert L == H * W, "input feature has wrong size"
+
+         x = x.view(B, H, W, C)
+         x = x.reshape(B, H, W, 2, 2, C // 4)
+         x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H * 2, W * 2, C // 4)
+         x = x.view(B, -1, C // 4)
+         x = self.norm(x)
+
+         return x
+
+
+ class UpSample(nn.Module):
+     def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_upsample, num_heads, window_size,
+                  mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                  drop_path_rate=0.1, norm_layer=nn.LayerNorm, flag=0):
+         super(UpSample, self).__init__()
+
+         self.img_size = img_size
+         self.num_layers = len(depths_upsample)
+         self.embed_dim = embed_dim
+         self.mlp_ratio = mlp_ratio
+         self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
+                                       embed_dim=embed_dim, norm_layer=nn.LayerNorm)
+         patches_resolution = self.patch_embed.grid_size
+         self.Unembed = PatchInflated(in_chans=in_chans, embed_dim=embed_dim, input_resolution=patches_resolution)
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_upsample))]
+
+         self.layers = nn.ModuleList()
+         self.upsample = nn.ModuleList()
+
+         for i_layer in range(self.num_layers):
+             resolution1 = (patches_resolution[0] // (2 ** (self.num_layers - i_layer)))
+             resolution2 = (patches_resolution[1] // (2 ** (self.num_layers - i_layer)))
+
+             dimension = int(embed_dim * 2 ** (self.num_layers - i_layer))
+             upsample = PatchExpanding(input_resolution=(resolution1, resolution2), dim=dimension)
+
+             layer = SwinLSTMCell(dim=dimension, input_resolution=(resolution1, resolution2),
+                                  depth=depths_upsample[(self.num_layers - 1 - i_layer)],
+                                  num_heads=num_heads[(self.num_layers - 1 - i_layer)],
+                                  window_size=window_size,
+                                  mlp_ratio=self.mlp_ratio,
+                                  qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                  drop=drop_rate, attn_drop=attn_drop_rate,
+                                  drop_path=dpr[sum(depths_upsample[:(self.num_layers - 1 - i_layer)]):
+                                                sum(depths_upsample[:(self.num_layers - 1 - i_layer) + 1])],
+                                  norm_layer=norm_layer, flag=flag)
+
+             self.layers.append(layer)
+             self.upsample.append(upsample)
+
+     def forward(self, x, y):
+         hidden_states_up = []
+
+         for index, layer in enumerate(self.layers):
+             x, hidden_state = layer(x, y[index])
+             x = self.upsample[index](x)
+             hidden_states_up.append(hidden_state)
+
+         x = torch.sigmoid(self.Unembed(x))
+
+         return hidden_states_up, x
+
+
+ class DownSample(nn.Module):
+     def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_downsample, num_heads, window_size,
+                  mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+                  norm_layer=nn.LayerNorm, flag=1):
+         super(DownSample, self).__init__()
+
+         self.num_layers = len(depths_downsample)
+         self.embed_dim = embed_dim
+         self.mlp_ratio = mlp_ratio
+         self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
+                                       embed_dim=embed_dim, norm_layer=nn.LayerNorm)
+         patches_resolution = self.patch_embed.grid_size
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_downsample))]
+
+         self.layers = nn.ModuleList()
+         self.downsample = nn.ModuleList()
+
+         for i_layer in range(self.num_layers):
+             downsample = PatchMerging(input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                                                         patches_resolution[1] // (2 ** i_layer)),
+                                       dim=int(embed_dim * 2 ** i_layer))
+
+             layer = SwinLSTMCell(dim=int(embed_dim * 2 ** i_layer),
+                                  input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                                                    patches_resolution[1] // (2 ** i_layer)),
+                                  depth=depths_downsample[i_layer],
+                                  num_heads=num_heads[i_layer],
+                                  window_size=window_size,
+                                  mlp_ratio=self.mlp_ratio,
+                                  qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                  drop=drop_rate, attn_drop=attn_drop_rate,
+                                  drop_path=dpr[sum(depths_downsample[:i_layer]):sum(depths_downsample[:i_layer + 1])],
+                                  norm_layer=norm_layer, flag=flag)
+
+             self.layers.append(layer)
+             self.downsample.append(downsample)
+
+     def forward(self, x, y):
+         x = self.patch_embed(x)
+
+         hidden_states_down = []
+
+         for index, layer in enumerate(self.layers):
+             x, hidden_state = layer(x, y[index])
+             x = self.downsample[index](x)
+             hidden_states_down.append(hidden_state)
+
+         return hidden_states_down, x
+
+
+ class STconvert(nn.Module):
+     def __init__(self, img_size, patch_size, in_chans, embed_dim, depths, num_heads,
+                  window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0.,
+                  attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, flag=2):
+         super(STconvert, self).__init__()
+
+         self.embed_dim = embed_dim
+         self.mlp_ratio = mlp_ratio
+         self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
+                                       in_chans=in_chans, embed_dim=embed_dim,
+                                       norm_layer=norm_layer)
+         patches_resolution = self.patch_embed.grid_size
+
+         self.patch_inflated = PatchInflated(in_chans=in_chans, embed_dim=embed_dim,
+                                             input_resolution=patches_resolution)
+
+         self.layer = SwinLSTMCell(dim=embed_dim,
+                                   input_resolution=(patches_resolution[0], patches_resolution[1]),
+                                   depth=depths, num_heads=num_heads,
+                                   window_size=window_size, mlp_ratio=mlp_ratio,
+                                   qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                   drop=drop_rate, attn_drop=attn_drop_rate,
+                                   drop_path=drop_path_rate, norm_layer=norm_layer,
+                                   flag=flag)
+
+     def forward(self, x, h=None):
+         x = self.patch_embed(x)
+
+         x, hidden_state = self.layer(x, h)
+
+         x = torch.sigmoid(self.patch_inflated(x))
+
+         return x, hidden_state
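
Note: a minimal end-to-end sketch of the single-scale STconvert path (sizes are illustrative; depths and num_heads are plain ints here, as the constructor above expects for one stage, and minor adaptation may be needed depending on the installed timm version, since the window helpers changed signature around timm 0.9):

    import torch
    from utilpack.swinlstm_modules import STconvert

    model = STconvert(img_size=64, patch_size=2, in_chans=1, embed_dim=96,
                      depths=4, num_heads=4, window_size=4)
    frame = torch.randn(2, 1, 64, 64)
    out, hidden = model(frame, h=None)     # out: sigmoid-activated reconstruction
    out, hidden = model(frame, h=hidden)   # carry the recurrent state across steps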
utilpack/wast_modules.py ADDED
@@ -0,0 +1,577 @@
1
+ import torch, pywt
2
+ import torch.nn as nn
3
+ from einops import rearrange
4
+ from functools import partial
5
+ from itertools import accumulate
6
+ from timm.layers import DropPath, activations
7
+ from timm.models._efficientnet_blocks import SqueezeExcite, InvertedResidual
8
+
9
+ # version adaptation for PyTorch > 1.7.1
10
+ IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.'))) > (1, 7, 1)
11
+ if IS_HIGH_VERSION:
12
+ import torch.fft
13
+
14
+ class HighFocalFrequencyLoss(nn.Module):
15
+ """ Example:
16
+ fake = torch.randn(4, 3, 128, 64)
17
+ real = torch.randn(4, 3, 128, 64)
18
+ hffl = HighFocalFrequencyLoss()
19
+
20
+ loss = hffl(fake, real)
21
+ print(loss)
22
+ """
23
+
24
+ def __init__(self, loss_weight=0.001, level=1, tau=0.1, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=True, batch_matrix=False):
25
+ super(HighFocalFrequencyLoss, self).__init__()
26
+ self.loss_weight = loss_weight
27
+ self.alpha = alpha
28
+ self.patch_factor = patch_factor
29
+ self.ave_spectrum = ave_spectrum
30
+ self.log_matrix = log_matrix
31
+ self.batch_matrix = batch_matrix
32
+ self.level = level
33
+ self.tau = tau
34
+ self.DWT = WaveletTransform2D().cuda()
35
+
36
+ def tensor2freq(self, x):
37
+ # crop image patches
38
+ patch_factor = self.patch_factor
39
+ _, _, h, w = x.shape
40
+ assert h % patch_factor == 0 and w % patch_factor == 0, (
41
+ 'Patch factor should be divisible by image height and width')
42
+ patch_list = []
43
+ patch_h = h // patch_factor
44
+ patch_w = w // patch_factor
45
+ for i in range(patch_factor):
46
+ for j in range(patch_factor):
47
+ patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
48
+
49
+ # stack to patch tensor
50
+ y = torch.stack(patch_list, 1)
51
+
52
+ # perform 2D DFT (real-to-complex, orthonormalization)
53
+ if IS_HIGH_VERSION:
54
+ freq = torch.fft.fft2(y, norm='ortho')
55
+ freq = torch.stack([freq.real, freq.imag], -1)
56
+ else:
57
+ freq = torch.rfft(y, 2, onesided=False, normalized=True)
58
+ return freq
59
+
60
+ def build_freq_mask(self, shape):
61
+ H, W = shape[-2:]
62
+ radius = self.tau * max(H, W)
63
+ Y, X = torch.meshgrid(torch.arange(H), torch.arange(W))
64
+
65
+ mask = torch.ones_like(X, dtype=torch.float32).cuda()
66
+
67
+ centers = [(0, 0), (0, W - 1), (H - 1, 0), (H - 1, W - 1)]
68
+
69
+ for center in centers:
70
+ distance = torch.sqrt((X - center[1]) ** 2 + (Y - center[0]) ** 2)
71
+ mask[distance <= radius] = 0
72
+ return mask
73
+
74
+ def loss_formulation(self, recon_freq, real_freq, matrix=None):
75
+ # spectrum weight matrix
76
+ if matrix is not None:
77
+ # if the matrix is predefined
78
+ weight_matrix = matrix.detach()
79
+ else:
80
+ # if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance
81
+ matrix_tmp = (recon_freq - real_freq) ** 2
82
+ matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
83
+
84
+ # whether to adjust the spectrum weight matrix by logarithm
85
+ if self.log_matrix:
86
+ matrix_tmp = torch.log(matrix_tmp + 1.0)
87
+
88
+ # whether to calculate the spectrum weight matrix using batch-based statistics
89
+ if self.batch_matrix:
90
+ matrix_tmp = matrix_tmp / matrix_tmp.max()
91
+ else:
92
+ matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
93
+
94
+ matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
95
+ matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
96
+ weight_matrix = matrix_tmp.clone().detach()
97
+
98
+ assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
99
+ 'The values of spectrum weight matrix should be in the range [0, 1], '
100
+ 'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
101
+
102
+ # frequency distance using (squared) Euclidean distance
103
+ tmp = (recon_freq - real_freq) ** 2
104
+ freq_distance = tmp[..., 0] + tmp[..., 1]
105
+
106
+ # dynamic spectrum weighting (Hadamard product)
107
+ mask = self.build_freq_mask(weight_matrix.shape)
108
+ loss = weight_matrix * freq_distance * mask
109
+ return torch.mean(loss)
110
+
111
+ def frequency_loss(self, pred, target, matrix=None):
112
+ """Forward function to calculate focal frequency loss.
113
+
114
+ Args:
115
+ pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor.
116
+ target (torch.Tensor): of shape (N, C, H, W). Target tensor.
117
+ matrix (torch.Tensor, optional): Element-wise spectrum weight matrix.
118
+ Default: None (If set to None: calculated online, dynamic).
119
+ """
120
+ pred_freq = self.tensor2freq(pred)
121
+ target_freq = self.tensor2freq(target)
122
+
123
+ # whether to use minibatch average spectrum
124
+ if self.ave_spectrum:
125
+ pred_freq = torch.mean(pred_freq, 0, keepdim=True)
126
+ target_freq = torch.mean(target_freq, 0, keepdim=True)
127
+
128
+ return self.loss_formulation(pred_freq, target_freq, matrix)
129
+
130
+ def forward(self, pred, target, matrix=None, **kwargs):
131
+ pred = rearrange(pred, 'b t c h w -> (b t) c h w') if kwargs["reshape"] is True else pred
132
+ target = rearrange(target, 'b t c h w -> (b t) c h w') if kwargs["reshape"] is True else target
133
+
134
+ loss = 0
135
+ for level in range(self.level):
136
+ pred, _, _, _ = self.DWT(pred)
137
+ target, _, _, _ = self.DWT(target)
138
+ loss += self.frequency_loss(pred, target, matrix)
139
+ return loss * self.loss_weight
140
+
141
+
142
+ class WaveletTransform2D(nn.Module):
143
+ """Compute a two-dimensional wavelet transform.
144
+ loss = nn.MSELoss()
145
+ data = torch.rand(1, 3, 128, 256)
146
+ DWT = WaveletTransform2D()
147
+ IDWT = WaveletTransform2D(inverse=True)
148
+
149
+ LL, LH, HL, HH = DWT(data)
150
+ recdata = IDWT([LL, LH, HL, HH])
151
+ print(loss(data, recdata))
152
+ """
153
+ def __init__(self, inverse=False, wavelet="haar", mode="constant"):
154
+ super(WaveletTransform2D, self).__init__()
155
+ self.mode = mode
156
+ wavelet = pywt.Wavelet(wavelet)
157
+
158
+ if isinstance(wavelet, tuple):
159
+ dec_lo, dec_hi, rec_lo, rec_hi = wavelet
160
+ else:
161
+ dec_lo, dec_hi, rec_lo, rec_hi = wavelet.filter_bank
162
+
163
+ self.inverse = inverse
164
+ if inverse is False:
165
+ dec_lo = torch.tensor(dec_lo).flip(-1).unsqueeze(0)
166
+ dec_hi = torch.tensor(dec_hi).flip(-1).unsqueeze(0)
167
+ self.build_filters(dec_lo, dec_hi)
168
+ else:
169
+ rec_lo = torch.tensor(rec_lo).unsqueeze(0)
170
+ rec_hi = torch.tensor(rec_hi).unsqueeze(0)
171
+ self.build_filters(rec_lo, rec_hi)
172
+
173
+ def build_filters(self, lo, hi):
174
+ # construct 2d filter
175
+ self.dim_size = lo.shape[-1]
176
+ ll = self.outer(lo, lo)
177
+ lh = self.outer(hi, lo)
178
+ hl = self.outer(lo, hi)
179
+ hh = self.outer(hi, hi)
180
+ filters = torch.stack([ll, lh, hl, hh],dim=0)
181
+ filters = filters.unsqueeze(1)
182
+ self.register_buffer('filters', filters) # [4, 1, height, width]
183
+
184
+ def outer(self, a: torch.Tensor, b: torch.Tensor):
185
+ """Torch implementation of numpy's outer for 1d vectors."""
186
+ a_flat = torch.reshape(a, [-1])
187
+ b_flat = torch.reshape(b, [-1])
188
+ a_mul = torch.unsqueeze(a_flat, dim=-1)
189
+ b_mul = torch.unsqueeze(b_flat, dim=0)
190
+ return a_mul * b_mul
191
+
192
+ def get_pad(self, data_len: int, filter_len: int):
193
+ padr = (2 * filter_len - 3) // 2
194
+ padl = (2 * filter_len - 3) // 2
195
+ # pad to even singal length.
196
+ if data_len % 2 != 0:
197
+ padr += 1
198
+ return padr, padl
199
+
200
+ def adaptive_pad(self, data):
201
+ padb, padt = self.get_pad(data.shape[-2], self.dim_size)
202
+ padr, padl = self.get_pad(data.shape[-1], self.dim_size)
203
+
204
+ data_pad = torch.nn.functional.pad(data, [padl, padr, padt, padb], mode=self.mode)
205
+ return data_pad
206
+
207
+ def forward(self, data):
208
+ if self.inverse is False:
209
+ b, c, h, w = data.shape
210
+ dec_res = []
211
+ data = self.adaptive_pad(data)
212
+ for filter in self.filters:
213
+ dec_res.append(torch.nn.functional.conv2d(data, filter.repeat(c, 1, 1, 1), stride=2, groups=c))
214
+ return dec_res
215
+ else:
216
+ b, c, h, w = data[0].shape
217
+ data = torch.stack(data, dim=2).reshape(b, -1, h, w)
218
+ rec_res = torch.nn.functional.conv_transpose2d(data, self.filters.repeat(c, 1, 1, 1), stride=2, groups=c)
219
+ return rec_res
220
+
221
+
222
+ class WaveletTransform3D(nn.Module):
223
+ """Compute a three-dimensional wavelet transform.
224
+ Example:
225
+ loss = nn.MSELoss()
226
+ data = torch.rand(1, 3, 10, 128, 256)
227
+ DWT = WaveletTransform3D()
228
+ IDWT = WaveletTransform3D(inverse=True)
229
+
230
+ LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = DWT(data)
231
+ recdata = IDWT([LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH])
232
+ print(loss(data, recdata))
233
+
234
+ LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = DWT_3D(data)
235
+ recdata = IDWT_3D(LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH)
236
+ print(loss(data, recdata))
237
+ """
238
+ def __init__(self, inverse=False, wavelet="haar", mode="constant"):
239
+ super(WaveletTransform3D, self).__init__()
240
+ self.mode = mode
241
+ wavelet = pywt.Wavelet(wavelet)
242
+
243
+ if isinstance(wavelet, tuple):
244
+ dec_lo, dec_hi, rec_lo, rec_hi = wavelet
245
+ else:
246
+ dec_lo, dec_hi, rec_lo, rec_hi = wavelet.filter_bank
247
+
248
+ self.inverse = inverse
249
+ if inverse is False:
250
+ dec_lo = torch.tensor(dec_lo).flip(-1).unsqueeze(0)
251
+ dec_hi = torch.tensor(dec_hi).flip(-1).unsqueeze(0)
252
+ self.build_filters(dec_lo, dec_hi)
253
+ else:
254
+ rec_lo = torch.tensor(rec_lo).unsqueeze(0)
255
+ rec_hi = torch.tensor(rec_hi).unsqueeze(0)
256
+ self.build_filters(rec_lo, rec_hi)
257
+
258
+ def build_filters(self, lo, hi):
259
+ # construct 3d filter
260
+ self.dim_size = lo.shape[-1]
261
+ size = [self.dim_size] * 3
262
+ lll = self.outer(lo, self.outer(lo, lo)).reshape(size)
263
+ llh = self.outer(lo, self.outer(lo, hi)).reshape(size)
264
+ lhl = self.outer(lo, self.outer(hi, lo)).reshape(size)
265
+ lhh = self.outer(lo, self.outer(hi, hi)).reshape(size)
266
+ hll = self.outer(hi, self.outer(lo, lo)).reshape(size)
267
+ hlh = self.outer(hi, self.outer(lo, hi)).reshape(size)
268
+ hhl = self.outer(hi, self.outer(hi, lo)).reshape(size)
269
+ hhh = self.outer(hi, self.outer(hi, hi)).reshape(size)
270
+ filters = torch.stack([lll, llh, lhl, lhh, hll, hlh, hhl, hhh], dim=0)
271
+ filters = filters.unsqueeze(1)
272
+ self.register_buffer('filters', filters) # [8, 1, length, height, width]
273
+
274
+ def outer(self, a: torch.Tensor, b: torch.Tensor):
275
+ """Torch implementation of numpy's outer for 1d vectors."""
276
+ a_flat = torch.reshape(a, [-1])
277
+ b_flat = torch.reshape(b, [-1])
278
+ a_mul = torch.unsqueeze(a_flat, dim=-1)
279
+ b_mul = torch.unsqueeze(b_flat, dim=0)
280
+ return a_mul * b_mul
281
+
282
+ def get_pad(self, data_len: int, filter_len: int):
283
+ padr = (2 * filter_len - 3) // 2
284
+ padl = (2 * filter_len - 3) // 2
285
+ # pad to even singal length.
286
+ if data_len % 2 != 0:
287
+ padr += 1
288
+ return padr, padl
289
+
290
+ def adaptive_pad(self, data):
291
+ pad_back, pad_front = self.get_pad(data.shape[-3], self.dim_size)
292
+ pad_bottom, pad_top = self.get_pad(data.shape[-2], self.dim_size)
293
+ pad_right, pad_left = self.get_pad(data.shape[-1], self.dim_size)
294
+ data_pad = torch.nn.functional.pad(
295
+ data, [pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back], mode=self.mode)
296
+ return data_pad
297
+
298
+ def forward(self, data):
299
+ if self.inverse is False:
300
+ b, c, t, h, w = data.shape
301
+ dec_res = []
302
+ data = self.adaptive_pad(data)
303
+ for filter in self.filters:
304
+ dec_res.append(torch.nn.functional.conv3d(data, filter.repeat(c, 1, 1, 1, 1), stride=2, groups=c))
305
+ return dec_res
306
+ else:
307
+ b, c, t, h, w = data[0].shape
308
+ data = torch.stack(data, dim=2).reshape(b, -1, t, h, w)
309
+ rec_res = torch.nn.functional.conv_transpose3d(data, self.filters.repeat(c, 1, 1, 1, 1), stride=2, groups=c)
310
+ return rec_res
311
+
312
+
313
+ class FrequencyAttention(nn.Module):
314
+ def __init__(self, in_dim, out_dim, reduction=32):
315
+ super(FrequencyAttention, self).__init__()
316
+ self.avgpool_h = nn.AdaptiveAvgPool2d((None, 1))
317
+ self.avgpool_w = nn.AdaptiveAvgPool2d((1, None))
318
+
319
+ hidden_dim = max(8, in_dim // reduction)
320
+
321
+ self.conv1 = nn.Conv2d(in_dim, hidden_dim, kernel_size=1, stride=1, padding=0)
322
+ self.bn1 = nn.BatchNorm2d(hidden_dim)
323
+ self.act = activations.HardSwish(inplace=True)
324
+
325
+ self.conv_h = nn.Conv2d(hidden_dim, out_dim, kernel_size=1, stride=1, padding=0)
326
+ self.conv_w = nn.Conv2d(hidden_dim, out_dim, kernel_size=1, stride=1, padding=0)
327
+
328
+ def forward(self, x):
329
+ identity = x
330
+
331
+ n, c, h, w = x.size()
332
+ x_h = self.avgpool_h(x) # b c h 1
333
+ x_w = self.avgpool_w(x).permute(0, 1, 3, 2) # b c w 1
334
+
335
+ y = torch.cat([x_h, x_w], dim=2) # b c (h+w) 1
336
+ y = self.conv1(y)
337
+ y = self.bn1(y)
338
+ y = self.act(y)
339
+
340
+ x_h, x_w = torch.split(y, [h, w], dim=2)
341
+ x_w = x_w.permute(0, 1, 3, 2)
342
+
343
+ a_h = self.conv_h(x_h).sigmoid()
344
+ a_w = self.conv_w(x_w).sigmoid()
345
+
346
+ out = identity * a_w * a_h
347
+
348
+ return out
+
+
+ class TF_AwareBlock(nn.Module):
+     def __init__(self, dim, mlp_ratio=4., drop=0., ls_init_value=1e-2, drop_path=0.1, large_kernel=51, small_kernel=5):
+         super().__init__()
+
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm1 = nn.BatchNorm2d(dim)
+         self.norm2 = nn.BatchNorm2d(dim)
+
+         self.lk1 = nn.Sequential(
+             nn.Conv2d(dim, dim, kernel_size=(large_kernel, 5), groups=dim, padding="same"),
+             nn.BatchNorm2d(dim)
+         )
+
+         self.lk2 = nn.Sequential(
+             nn.Conv2d(dim, dim, kernel_size=(5, large_kernel), groups=dim, padding="same"),
+             nn.BatchNorm2d(dim)
+         )
+
+         self.sk = nn.Sequential(
+             nn.Conv2d(dim, dim, kernel_size=(small_kernel, small_kernel), groups=dim, padding="same"),
+             nn.BatchNorm2d(dim)
+         )
+
+         self.low_frequency_attn = FrequencyAttention(in_dim=dim, out_dim=dim, reduction=4)
+         self.high_frequency_attn = FrequencyAttention(in_dim=dim, out_dim=dim, reduction=4)
+
+         self.temporal_mixer = InvertedResidual(in_chs=dim, out_chs=dim, dw_kernel_size=7, exp_ratio=mlp_ratio,
+                                                se_layer=partial(SqueezeExcite, rd_ratio=0.25), noskip=True)
+
+         self.layer_scale_1 = nn.Parameter(ls_init_value * torch.ones(dim), requires_grad=True)
+         self.layer_scale_2 = nn.Parameter(ls_init_value * torch.ones(dim), requires_grad=True)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'layer_scale_1', 'layer_scale_2'}
+
+     def forward(self, x):
+         attn = self.norm1(x)
+         x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * (self.low_frequency_attn(self.lk1(attn) + self.lk2(attn)) + self.high_frequency_attn(self.sk(attn))))
+         # Each residual branch gets its own learnable scale; layer_scale_2
+         # (registered above and in no_weight_decay) scales the temporal mixer.
+         x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.temporal_mixer(self.norm2(x)))
+         return x
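+     # Sketch (hypothetical sizes): the block is shape-preserving; the two
+     # stripe kernels (large_kernel x 5 and 5 x large_kernel) gather long-range
+     # low-frequency context while the small kernel feeds the high-frequency gate:
+     #   blk = TF_AwareBlock(dim=80, large_kernel=51, small_kernel=5)
+     #   y = blk(torch.rand(2, 80, 32, 32))  # -> (2, 80, 32, 32)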
+
+
+ class TF_AwareBlocks(nn.Module):
+     def __init__(self, dim, num_blocks, drop_path, use_bottleneck=None, use_hid=False, mlp_ratio=4., drop=0., ls_init_value=1e-2, large_kernel=51, small_kernel=5):
+         super().__init__()
+         assert len(drop_path) == num_blocks, "drop_path list doesn't match num_blocks"
+         self.use_hid = use_hid
+         self.use_bottleneck = use_bottleneck
+
+         blocks = []
+         for i in range(num_blocks):
+             block = TF_AwareBlock(dim, mlp_ratio, drop, ls_init_value, drop_path[i], large_kernel, small_kernel)
+             blocks.append(block)
+         self.blocks = nn.Sequential(*blocks)
+         self.concat_block = nn.Conv2d(dim * 2, dim, 3, 1, 1) if use_hid else None
+
+         # The in-bottleneck wavelet decomposition is only built when
+         # use_bottleneck == "decompose"; any other truthy value runs the
+         # bottleneck path directly on the input.
+         self.DWT = WaveletTransform3D(inverse=False) if use_bottleneck == "decompose" else None
+         self.IDWT = WaveletTransform3D(inverse=True) if use_bottleneck == "decompose" else None
+
+     def forward(self, x, skip=None):  # b, c, t, h, w
+         if self.concat_block is not None and self.use_bottleneck is None:
+             # Decoder-side translator: fuse the hidden skip before mixing.
+             b, c, t, h, w = x.shape
+             x = rearrange(x, 'b c t h w -> b (c t) h w')
+             x = self.concat_block(torch.cat([x, skip], dim=1))
+             x = self.blocks(x)
+             x = rearrange(x, 'b (c t) h w -> b c t h w', t=t)
+             return x
+         elif self.concat_block is None and self.use_bottleneck is None:
+             # Encoder-side translator: also emit the mixed features as a skip.
+             b, c, t, h, w = x.shape
+             x = rearrange(x, 'b c t h w -> b (c t) h w')
+             x = skip = self.blocks(x)
+             x = rearrange(x, 'b (c t) h w -> b c t h w', t=t)
+             return x, skip
+         elif self.use_bottleneck is not None:
+             # Bottleneck translator, optionally operating on wavelet sub-bands.
+             if self.use_bottleneck == "decompose":
+                 LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = self.DWT(x)
+             else:
+                 LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = x, None, None, None, None, None, None, None
+             b, c, t, h, w = LLL.shape
+             LLL = rearrange(LLL, 'b c t h w -> b (c t) h w')
+             LLL = self.blocks(LLL)
+             LLL = rearrange(LLL, 'b (c t) h w -> b c t h w', t=t)
+             x = self.IDWT([LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH]) if self.use_bottleneck == "decompose" else LLL
+             return x
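+     # Usage sketch (hypothetical sizes; time is folded into channels, so dim
+     # must equal c * t):
+     #   enc = TF_AwareBlocks(dim=80, num_blocks=2, drop_path=[0.0, 0.1])
+     #   y, skip = enc(torch.rand(2, 20, 4, 32, 32))
+     #   dec = TF_AwareBlocks(dim=80, num_blocks=2, drop_path=[0.0, 0.1], use_hid=True)
+     #   z = dec(y, skip)  # -> (2, 20, 4, 32, 32)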
+
+
+ class Wavelet_3D_Embedding(nn.Module):
+     def __init__(self, in_dim, out_dim, emb_dim=None):
+         super().__init__()
+         emb_dim = in_dim if emb_dim is None else emb_dim
+         self.conv_0 = nn.Sequential(nn.Conv3d(in_dim, in_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
+                                     nn.BatchNorm3d(in_dim),
+                                     nn.GELU())
+         self.conv_1 = nn.Sequential(nn.Conv3d(in_dim, out_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
+                                     nn.BatchNorm3d(out_dim),
+                                     nn.GELU())
+
+         self.conv_emb = nn.Conv3d(emb_dim * 4, out_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
+
+         self.DWT = WaveletTransform3D(inverse=False)
+
+     def forward(self, x, x_emb=None):
+         # Embedding branch: group the raw input's eight sub-bands by temporal
+         # frequency (four each) and stack the two groups along time.
+         LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = self.DWT(x_emb)
+         lo_temp = torch.cat([LLL, LHL, HLL, HHL], dim=1)
+         hi_temp = torch.cat([LLH, LHH, HLH, HHH], dim=1)
+         x_emb = torch.cat([lo_temp, hi_temp], dim=2)
+         x_emb = self.conv_emb(x_emb)
+         # Downsampling branch: keep the spatially-low sub-bands as features
+         # and route the spatially-high sub-bands out as a reconstruction skip.
+         x = self.conv_0(x)
+         LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = self.DWT(x)
+         spatio_lo_coeffs = torch.cat([LLL, LLH], dim=2)
+         spatio_hi_coeffs = torch.cat([LHL, LHH, HLL, HLH, HHL, HHH], dim=1)
+         x = self.conv_1(spatio_lo_coeffs)
+         return (x + x_emb), spatio_hi_coeffs
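+     # Shape sketch: for x of (b, in_dim, t, h, w) the DWT halves every axis;
+     # stacking two sub-band groups back along time restores t, so the output
+     # is (b, out_dim, t, h/2, w/2) and the high-frequency skip is
+     # (b, 6 * in_dim, t/2, h/2, w/2): space is halved, time is kept.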
+
+
+ class Wavelet_3D_Reconstruction(nn.Module):
+     def __init__(self, in_dim, out_dim, hi_dim):
+         super().__init__()
+         self.conv_0 = nn.Sequential(nn.Conv3d(in_dim, out_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
+                                     nn.BatchNorm3d(out_dim),
+                                     nn.GELU())
+
+         self.conv_hi = nn.Sequential(nn.Conv3d(int(hi_dim * 6), int(out_dim * 6), kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=6),
+                                      nn.BatchNorm3d(out_dim * 6),
+                                      nn.GELU())
+
+         self.IDWT = WaveletTransform3D(inverse=True)
+
+     def forward(self, x, skip_hi=None):
+         # Split the two spatially-low sub-bands (stacked along time in the
+         # embedding) back apart, project the six high-frequency skip sub-bands,
+         # and invert the wavelet transform.
+         LLL, LLH = torch.chunk(self.conv_0(x), chunks=2, dim=2)
+         LHL, LHH, HLL, HLH, HHL, HHH = torch.chunk(self.conv_hi(skip_hi), chunks=6, dim=1)
+         x = self.IDWT([LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH])
+         return x
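+     # This mirrors Wavelet_3D_Embedding: given x of (b, in_dim, t, h/2, w/2)
+     # and skip_hi of (b, 6 * hi_dim, t/2, h/2, w/2), the IDWT returns a tensor
+     # of shape (b, out_dim, t, h, w).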
+
+
+ class WaST_level1(nn.Module):
+     def __init__(self, in_shape, encoder_dim, block_list=[2, 2, 2], drop_path_rate=0.1, mlp_ratio=4., **kwargs):
+         super().__init__()
+         frame, in_dim, H, W = in_shape
+         self.block_list = block_list
+         # Stochastic-depth rates increase linearly over all blocks, then are
+         # sliced into one sub-list per translator stage.
+         dp_list = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.block_list))]
+         indexes = list(accumulate(block_list))
+         dp_list = [dp_list[start:end] for start, end in zip([0] + indexes, indexes)]
+
+         self.conv_in = nn.Sequential(
+             nn.Conv3d(
+                 in_dim,
+                 encoder_dim,
+                 kernel_size=(3, 3, 3),
+                 stride=(1, 1, 1),
+                 padding=(1, 1, 1),
+             ),
+             nn.BatchNorm3d(encoder_dim),
+             nn.GELU()
+         )
+         self.translator1 = TF_AwareBlocks(dim=encoder_dim * frame, num_blocks=block_list[0], drop_path=dp_list[0], mlp_ratio=mlp_ratio, large_kernel=51, small_kernel=5)
+
+         self.wavelet_embed1 = Wavelet_3D_Embedding(in_dim=encoder_dim, out_dim=encoder_dim * 2, emb_dim=in_dim)  # wavelet_recon2: hi_dim = in_dim
+
+         # use_bottleneck=True selects the bottleneck code path without the
+         # extra wavelet decomposition, which TF_AwareBlocks only builds for
+         # use_bottleneck == "decompose".
+         self.bottleneck_translator = TF_AwareBlocks(dim=encoder_dim * 2 * frame, num_blocks=block_list[1], drop_path=dp_list[1], use_bottleneck=True, mlp_ratio=mlp_ratio, large_kernel=21, small_kernel=5)
+
+         self.wavelet_recon1 = Wavelet_3D_Reconstruction(in_dim=encoder_dim * 2, out_dim=encoder_dim, hi_dim=encoder_dim)
+         self.translator2 = TF_AwareBlocks(dim=encoder_dim * frame, num_blocks=block_list[2], drop_path=dp_list[2], use_hid=True, mlp_ratio=mlp_ratio, large_kernel=51, small_kernel=5)
+
+         self.conv_out = nn.Sequential(
+             nn.BatchNorm3d(encoder_dim),
+             nn.GELU(),
+             nn.Conv3d(
+                 encoder_dim,
+                 in_dim,
+                 kernel_size=(3, 3, 3),
+                 stride=(1, 1, 1),
+                 padding=(1, 1, 1))
+         )
+
+     def update_drop_path(self, drop_path_rate):
+         # Re-derive the linear stochastic-depth schedule and write the new
+         # per-block probabilities into the existing DropPath modules.
+         dp_list = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.block_list))]
+         indexes = list(accumulate(self.block_list))
+         dp_lists = [dp_list[start:end] for start, end in zip([0] + indexes, indexes)]
+         dp_apply_blocks = [self.translator1.blocks, self.bottleneck_translator.blocks, self.translator2.blocks]
+         for translators, dp_list_translators in zip(dp_apply_blocks, dp_lists):
+             for translator, dp_list_translator in zip(translators, dp_list_translators):
+                 translator.drop_path.drop_prob = dp_list_translator
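+     # Hypothetical schedule sketch: callers could anneal stochastic depth
+     # during training, e.g. once per epoch:
+     #   model.update_drop_path(0.1 * epoch / max_epoch)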
+
+     def forward(self, x):
+         x = rearrange(x, 'b t c h w -> b c t h w')
+
+         ori_img = x
+         x = self.conv_in(x)
+
+         # Encoder: translate, then wavelet-downsample, keeping the translator
+         # features and the high-frequency coefficients as skips.
+         x, tskip1 = self.translator1(x)
+         x, skip1 = self.wavelet_embed1(x, x_emb=ori_img)
+
+         x = self.bottleneck_translator(x)
+
+         # Decoder: wavelet-reconstruct with the saved skips, then translate.
+         x = self.wavelet_recon1(x, skip1)
+         x = self.translator2(x, tskip1)
+
+         x = self.conv_out(x)
+
+         x = rearrange(x, 'b c t h w -> b t c h w')
+         return x
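+     # End to end the model preserves the sequence shape, as the demo below
+     # shows: a (1, 4, 2, 32, 32) input yields a (1, 4, 2, 32, 32) output.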
+
+
+ if __name__ == "__main__":
+     from fvcore.nn import FlopCountAnalysis, flop_count_table
+     # import os
+     # os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+
+     model = WaST_level1(in_shape=(4, 2, 32, 32), encoder_dim=20, block_list=[2, 8, 2]).cuda()
+     print(model)
+     dummy_tensor = torch.rand(1, 4, 2, 32, 32).cuda()
+     output = model(dummy_tensor)
+     print(f"input shape is {dummy_tensor.shape}, output shape is {output.shape}...")
+     flops = FlopCountAnalysis(model, dummy_tensor)
+     print(flop_count_table(flops))
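+     # CPU-only variant of the same smoke test (a sketch for machines without
+     # a GPU; drop the .cuda() calls):
+     #   model = WaST_level1(in_shape=(4, 2, 32, 32), encoder_dim=20, block_list=[2, 8, 2])
+     #   out = model(torch.rand(1, 4, 2, 32, 32))
+     #   assert out.shape == (1, 4, 2, 32, 32)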