Upload 50 files
Browse files- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/upload.iml +12 -0
- .idea/workspace.xml +42 -0
- model_LARRES.py +229 -0
- model_convlstm.py +186 -0
- modules.py +66 -0
- test2015.h5 +3 -0
- test2020.h5 +3 -0
- train2015.h5 +3 -0
- train2020.h5 +3 -0
- train_simvp2.py +85 -0
- utilpack/__init__.py +32 -0
- utilpack/__pycache__/__init__.cpython-312.pyc +0 -0
- utilpack/__pycache__/convlstm_modules.cpython-312.pyc +0 -0
- utilpack/__pycache__/e3dlstm_modules.cpython-312.pyc +0 -0
- utilpack/__pycache__/mau_modules.cpython-312.pyc +0 -0
- utilpack/__pycache__/mim_modules.cpython-312.pyc +0 -0
- utilpack/__pycache__/mmvp_modules.cpython-312.pyc +0 -0
- utilpack/__pycache__/phydnet_modules.cpython-312.pyc +0 -0
- utilpack/__pycache__/predrnn_modules.cpython-312.pyc +0 -0
- utilpack/__pycache__/predrnnpp_modules.cpython-312.pyc +0 -0
- utilpack/__pycache__/predrnnv2_modules.cpython-312.pyc +0 -0
- utilpack/__pycache__/simvp_modules.cpython-312.pyc +0 -0
- utilpack/__pycache__/swinlstm_modules.cpython-312.pyc +0 -0
- utilpack/convlstm_modules.py +58 -0
- utilpack/e3dlstm_modules.py +119 -0
- utilpack/layers/__init__.py +10 -0
- utilpack/layers/__pycache__/__init__.cpython-312.pyc +0 -0
- utilpack/layers/__pycache__/hornet.cpython-312.pyc +0 -0
- utilpack/layers/__pycache__/moganet.cpython-312.pyc +0 -0
- utilpack/layers/__pycache__/poolformer.cpython-312.pyc +0 -0
- utilpack/layers/__pycache__/uniformer.cpython-312.pyc +0 -0
- utilpack/layers/__pycache__/van.cpython-312.pyc +0 -0
- utilpack/layers/hornet.py +112 -0
- utilpack/layers/moganet.py +140 -0
- utilpack/layers/poolformer.py +97 -0
- utilpack/layers/uniformer.py +156 -0
- utilpack/layers/van.py +119 -0
- utilpack/mau_modules.py +66 -0
- utilpack/mim_modules.py +211 -0
- utilpack/mmvp_modules.py +349 -0
- utilpack/phydnet_modules.py +463 -0
- utilpack/predrnn_modules.py +79 -0
- utilpack/predrnnpp_modules.py +169 -0
- utilpack/predrnnv2_modules.py +82 -0
- utilpack/simvp_modules.py +586 -0
- utilpack/swinlstm_modules.py +317 -0
- utilpack/wast_modules.py +577 -0
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/misc.xml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="Black">
|
| 4 |
+
<option name="sdkName" value="Python 3.12" />
|
| 5 |
+
</component>
|
| 6 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12" project-jdk-type="Python SDK" />
|
| 7 |
+
</project>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/upload.iml" filepath="$PROJECT_DIR$/.idea/upload.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/upload.iml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="inheritedJdk" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
<component name="PyDocumentationSettings">
|
| 9 |
+
<option name="format" value="PLAIN" />
|
| 10 |
+
<option name="myDocStringFormat" value="Plain" />
|
| 11 |
+
</component>
|
| 12 |
+
</module>
|
.idea/workspace.xml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ChangeListManager">
|
| 4 |
+
<list default="true" id="9591cce3-c276-4022-8bb6-62a293d16241" name="更改" comment="" />
|
| 5 |
+
<option name="SHOW_DIALOG" value="false" />
|
| 6 |
+
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
| 7 |
+
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
| 8 |
+
<option name="LAST_RESOLUTION" value="IGNORE" />
|
| 9 |
+
</component>
|
| 10 |
+
<component name="ProjectColorInfo"><![CDATA[{
|
| 11 |
+
"associatedIndex": 3
|
| 12 |
+
}]]></component>
|
| 13 |
+
<component name="ProjectId" id="2ssyDJvvBdJ2oeAvJVoyJMXumP7" />
|
| 14 |
+
<component name="ProjectViewState">
|
| 15 |
+
<option name="hideEmptyMiddlePackages" value="true" />
|
| 16 |
+
<option name="showLibraryContents" value="true" />
|
| 17 |
+
</component>
|
| 18 |
+
<component name="PropertiesComponent"><![CDATA[{
|
| 19 |
+
"keyToString": {
|
| 20 |
+
"RunOnceActivity.ShowReadmeOnStart": "true",
|
| 21 |
+
"last_opened_file_path": "C:/Users/Administrator/Desktop/upload"
|
| 22 |
+
}
|
| 23 |
+
}]]></component>
|
| 24 |
+
<component name="SharedIndexes">
|
| 25 |
+
<attachedChunks>
|
| 26 |
+
<set>
|
| 27 |
+
<option value="bundled-python-sdk-98f27166c754-ba05f1cad1b1-com.jetbrains.pycharm.community.sharedIndexes.bundled-PC-242.21829.153" />
|
| 28 |
+
</set>
|
| 29 |
+
</attachedChunks>
|
| 30 |
+
</component>
|
| 31 |
+
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="应用程序级" UseSingleDictionary="true" transferred="true" />
|
| 32 |
+
<component name="TaskManager">
|
| 33 |
+
<task active="true" id="Default" summary="默认任务">
|
| 34 |
+
<changelist id="9591cce3-c276-4022-8bb6-62a293d16241" name="更改" comment="" />
|
| 35 |
+
<created>1739258456359</created>
|
| 36 |
+
<option name="number" value="Default" />
|
| 37 |
+
<option name="presentableId" value="Default" />
|
| 38 |
+
<updated>1739258456359</updated>
|
| 39 |
+
</task>
|
| 40 |
+
<servers />
|
| 41 |
+
</component>
|
| 42 |
+
</project>
|
model_LARRES.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from modules import ConvSC, Inception
|
| 4 |
+
|
| 5 |
+
from utilpack import (ConvNeXtSubBlock, ConvMixerSubBlock, GASubBlock, gInception_ST,
|
| 6 |
+
HorNetSubBlock, MLPMixerSubBlock, MogaSubBlock, PoolFormerSubBlock,
|
| 7 |
+
SwinSubBlock, UniformerSubBlock, VANSubBlock, ViTSubBlock, TAUSubBlock)
|
| 8 |
+
|
| 9 |
+
def stride_generator(N, reverse=False):
    """Return the first N strides of the alternating pattern [1, 2, 1, 2, ...].

    Used to build the encoder (forward order) and decoder (reversed order)
    stride schedules of the SimVP-style models in this file.
    """
    pattern = [1, 2] * 10  # long enough for any practical depth (N <= 20)
    strides = pattern[:N]
    return strides[::-1] if reverse else strides
|
| 13 |
+
|
| 14 |
+
class Encoder(nn.Module):
    """Spatial encoder: a stack of ConvSC layers with strides from stride_generator.

    Returns both the final latent map and the output of the first layer,
    which the Decoder consumes as a skip connection.
    """

    def __init__(self, C_in, C_hid, N_S):
        super(Encoder, self).__init__()
        strides = stride_generator(N_S)
        layers = [ConvSC(C_in, C_hid, stride=strides[0])]
        layers.extend(ConvSC(C_hid, C_hid, stride=s) for s in strides[1:])
        self.enc = nn.Sequential(*layers)

    def forward(self, x):  # x: (B*T, C_in, H, W)
        enc1 = self.enc[0](x)  # kept as the skip connection
        latent = enc1
        for stage in self.enc[1:]:
            latent = stage(latent)
        return latent, enc1
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Decoder(nn.Module):
    """Spatial decoder mirroring Encoder with transposed ConvSC layers.

    The final layer receives the concatenation of the current feature map
    and the encoder's first-layer skip (hence 2*C_hid input channels); a
    1x1 convolution then reads out C_out channels.
    """

    def __init__(self, C_hid, C_out, N_S):
        super(Decoder, self).__init__()
        strides = stride_generator(N_S, reverse=True)
        layers = [ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]]
        layers.append(ConvSC(2 * C_hid, C_hid, stride=strides[-1], transpose=True))
        self.dec = nn.Sequential(*layers)
        self.readout = nn.Conv2d(C_hid, C_out, 1)

    def forward(self, hid, enc1=None):
        for stage in self.dec[:-1]:
            hid = stage(hid)
        # fuse the encoder skip just before the final up-convolution
        out = self.dec[-1](torch.cat([hid, enc1], dim=1))
        return self.readout(out)
|
| 47 |
+
|
| 48 |
+
class Mid_Xnet(nn.Module):
    """Temporal translator: a UNet-like stack of Inception blocks.

    The encoder half collects intermediate feature maps as skips; the
    decoder half consumes them via channel concatenation. Operates on
    (B, T*C, H, W) after folding time into the channel dimension.
    """

    def __init__(self, channel_in, channel_hid, N_T, incep_ker=[3, 5, 7, 11], groups=8):
        super(Mid_Xnet, self).__init__()
        self.N_T = N_T

        enc_layers = [Inception(channel_in, channel_hid // 2, channel_hid,
                                incep_ker=incep_ker, groups=groups)]
        enc_layers += [Inception(channel_hid, channel_hid // 2, channel_hid,
                                 incep_ker=incep_ker, groups=groups)
                       for _ in range(N_T - 1)]

        # decoder: every block after the first sees 2*channel_hid channels
        # because of the skip concatenation; the last projects back to
        # channel_in.
        dec_layers = [Inception(channel_hid, channel_hid // 2, channel_hid,
                                incep_ker=incep_ker, groups=groups)]
        dec_layers += [Inception(2 * channel_hid, channel_hid // 2, channel_hid,
                                 incep_ker=incep_ker, groups=groups)
                       for _ in range(N_T - 2)]
        dec_layers.append(Inception(2 * channel_hid, channel_hid // 2, channel_in,
                                    incep_ker=incep_ker, groups=groups))

        self.enc = nn.Sequential(*enc_layers)
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        B, T, C, H, W = x.shape
        z = x.reshape(B, T * C, H, W)

        # encoder pass; remember intermediate maps for the skip connections
        skips = []
        for i in range(self.N_T):
            z = self.enc[i](z)
            if i < self.N_T - 1:
                skips.append(z)

        # decoder pass; skips are consumed in reverse order
        z = self.dec[0](z)
        for i in range(1, self.N_T):
            z = self.dec[i](torch.cat([z, skips[-i]], dim=1))

        return z.reshape(B, T, C, H, W)
|
| 85 |
+
|
| 86 |
+
class MetaBlock(nn.Module):
    """The hidden Translator of MetaFormer for SimVP.

    Selects one MetaFormer-style sub-block by name and, when the channel
    counts differ, projects the result with a 1x1 convolution.

    Args:
        in_channels: input feature channels.
        out_channels: output feature channels.
        input_resolution: (H, W) of the feature map; required by the
            'mlp'/'mlpmixer' and 'swin' variants.
        model_type: sub-block family name (case-insensitive); None
            selects 'gsta'.
        mlp_ratio: MLP expansion ratio inside the sub-block.
        drop: dropout rate.
        drop_path: stochastic-depth rate.
        layer_i: index of this block in the stack ('swin' and 'uniformer'
            vary behaviour per layer).

    Raises:
        ValueError: if ``model_type`` is not a recognised sub-block name.
    """

    def __init__(self, in_channels, out_channels, input_resolution=None, model_type=None,
                 mlp_ratio=8., drop=0.0, drop_path=0.0, layer_i=0):
        super(MetaBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        model_type = model_type.lower() if model_type is not None else 'gsta'

        if model_type == 'gsta':
            self.block = GASubBlock(
                in_channels, kernel_size=21, mlp_ratio=mlp_ratio,
                drop=drop, drop_path=drop_path, act_layer=nn.GELU)
        elif model_type == 'convmixer':
            self.block = ConvMixerSubBlock(in_channels, kernel_size=11, activation=nn.GELU)
        elif model_type == 'convnext':
            self.block = ConvNeXtSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type == 'hornet':
            self.block = HorNetSubBlock(in_channels, mlp_ratio=mlp_ratio, drop_path=drop_path)
        elif model_type in ['mlp', 'mlpmixer']:
            self.block = MLPMixerSubBlock(
                in_channels, input_resolution, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type in ['moga', 'moganet']:
            self.block = MogaSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop_rate=drop, drop_path_rate=drop_path)
        elif model_type == 'poolformer':
            self.block = PoolFormerSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type == 'swin':
            self.block = SwinSubBlock(
                in_channels, input_resolution, layer_i=layer_i, mlp_ratio=mlp_ratio,
                drop=drop, drop_path=drop_path)
        elif model_type == 'uniformer':
            block_type = 'MHSA' if in_channels == out_channels and layer_i > 0 else 'Conv'
            self.block = UniformerSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop,
                drop_path=drop_path, block_type=block_type)
        elif model_type == 'van':
            self.block = VANSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path, act_layer=nn.GELU)
        elif model_type == 'vit':
            self.block = ViTSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        else:
            # BUG FIX: was `assert False and "Invalid model_type in SimVP"`.
            # `False and str` evaluates to bare False (the message was never
            # shown), and asserts disappear entirely under `python -O`.
            raise ValueError(f"Invalid model_type '{model_type}' in SimVP")

        if in_channels != out_channels:
            self.reduction = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        z = self.block(x)
        # project only when a channel change was requested at construction
        return z if self.in_channels == self.out_channels else self.reduction(z)
|
| 141 |
+
|
| 142 |
+
class MidMetaNet(nn.Module):
    """The hidden Translator of MetaFormer for SimVP.

    A straight stack of N2 MetaBlocks operating on (B, T*C, H, W): the
    first projects channel_in -> channel_hid, the middle blocks keep
    channel_hid, and the last projects back to channel_in.
    """

    def __init__(self, channel_in, channel_hid, N2,
                 input_resolution=None, model_type=None,
                 mlp_ratio=4., drop=0.0, drop_path=0.1):
        super(MidMetaNet, self).__init__()
        assert N2 >= 2 and mlp_ratio > 1
        self.N2 = N2
        # linearly increasing stochastic-depth decay rule
        dpr = [rate.item() for rate in torch.linspace(1e-2, drop_path, self.N2)]

        blocks = [MetaBlock(
            channel_in, channel_hid, input_resolution, model_type,
            mlp_ratio, drop, drop_path=dpr[0], layer_i=0)]
        for i in range(1, N2 - 1):
            blocks.append(MetaBlock(
                channel_hid, channel_hid, input_resolution, model_type,
                mlp_ratio, drop, drop_path=dpr[i], layer_i=i))
        blocks.append(MetaBlock(
            channel_hid, channel_in, input_resolution, model_type,
            mlp_ratio, drop, drop_path=drop_path, layer_i=N2 - 1))
        self.enc = nn.Sequential(*blocks)

    def forward(self, x):
        B, T, C, H, W = x.shape
        z = x.reshape(B, T * C, H, W)
        for i in range(self.N2):
            z = self.enc[i](z)
        return z.reshape(B, T, C, H, W)
|
| 179 |
+
|
| 180 |
+
class SimVP(nn.Module):
    """SimVP video predictor: spatial Encoder -> MetaFormer translator -> Decoder.

    Input/output geometry is hard-coded to (B, 36, 1, 72, 72); the
    translator works on encoded latents with time folded into channels.
    """

    def __init__(self, hid_S=32, hid_T=256, N_S=2, N_T=8, incep_ker=[3, 5, 7, 11], groups=4):
        super(SimVP, self).__init__()
        T, C, H, W = 36, 1, 72, 72  # hard-coded dataset geometry
        self.enc = Encoder(C, hid_S, N_S)
        self.hid = MidMetaNet(T * hid_S, hid_T, N_T,
                              input_resolution=(H, W), model_type="vit",
                              mlp_ratio=8, drop=0.0, drop_path=0.1)
        self.dec = Decoder(hid_S, C, N_S)

    def forward(self, x_raw):
        B, T, C, H, W = x_raw.shape
        frames = x_raw.view(B * T, C, H, W)  # fold time into the batch axis

        embed, skip = self.enc(frames)
        _, C_, H_, W_ = embed.shape

        latents = embed.view(B, T, C_, H_, W_)
        hid = self.hid(latents).reshape(B * T, C_, H_, W_)

        out = self.dec(hid, skip)
        return out.reshape(B, T, C, H, W)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class larres(nn.Module):
    """LARRES predictor: spatial Encoder -> Inception-UNet translator -> Decoder.

    Same skeleton as SimVP but with the Mid_Xnet (Inception) temporal
    translator. Input/output geometry is hard-coded to (B, 36, 1, 72, 72).
    """

    def __init__(self, hid_S=32, hid_T=256, N_S=2, N_T=8, incep_ker=[3, 5, 7, 11], groups=4):
        super(larres, self).__init__()
        T, C, H, W = 36, 1, 72, 72  # hard-coded dataset geometry
        self.enc = Encoder(C, hid_S, N_S)
        self.hid = Mid_Xnet(T * hid_S, hid_T, N_T, incep_ker, groups)
        self.dec = Decoder(hid_S, C, N_S)

    def forward(self, x_raw):
        B, T, C, H, W = x_raw.shape
        frames = x_raw.view(B * T, C, H, W)  # fold time into the batch axis

        embed, skip = self.enc(frames)
        _, C_, H_, W_ = embed.shape

        latents = embed.view(B, T, C_, H_, W_)
        hid = self.hid(latents).reshape(B * T, C_, H_, W_)

        out = self.dec(hid, skip)
        return out.reshape(B, T, C, H, W)
|
model_convlstm.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch import nn, Tensor
|
| 4 |
+
import numpy as np
|
| 5 |
+
import h5py
|
| 6 |
+
from torch.utils.data import DataLoader, Dataset
|
| 7 |
+
from torch.utils.data import Subset
|
| 8 |
+
from sklearn.model_selection import train_test_split
|
| 9 |
+
|
| 10 |
+
#Obtained from: https://holmdk.github.io/2020/04/02/video_prediction.html
|
| 11 |
+
class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell.

    Adapted from: https://holmdk.github.io/2020/04/02/video_prediction.html
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Parameters
        ----------
        input_dim : int
            Number of channels of the input tensor.
        hidden_dim : int
            Number of channels of the hidden state.
        kernel_size : (int, int)
            Size of the convolutional kernel.
        bias : bool
            Whether or not to add the bias.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # 'same' padding keeps the spatial size unchanged
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # one convolution produces all four gate pre-activations at once
        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state

        stacked = torch.cat([input_tensor, h_cur], dim=1)  # channel-wise concat
        gates = self.conv(stacked)
        gate_i, gate_f, gate_o, gate_g = torch.chunk(gates, 4, dim=1)

        i = torch.sigmoid(gate_i)
        f = torch.sigmoid(gate_f)
        o = torch.sigmoid(gate_o)
        g = torch.tanh(gate_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero-initialised (h, c) on the same device as the weights."""
        height, width = image_size
        dev = self.conv.weight.device
        shape = (batch_size, self.hidden_dim, height, width)
        return torch.zeros(*shape, device=dev), torch.zeros(*shape, device=dev)
|
| 64 |
+
|
| 65 |
+
def process_highdim_array(arr):
    """Reshape the last two spatial dims of a map array from (71, 73) to (72, 72).

    The last longitude column is dropped, (..., 71, 73) -> (..., 71, 72),
    then one zero-filled latitude row is appended, -> (..., 72, 72).

    BUG FIX: the original hard-coded a 4-pair pad spec, so it crashed on the
    5-D (1, 60, 1, 71, 73) input its own docstring advertised. The pad spec
    is now built from ``arr.ndim`` and works for any number of leading dims.

    Args:
        arr (ndarray): array whose last two dimensions are (71, 73).

    Returns:
        ndarray: same leading dims, last two dims reshaped to (72, 72).

    Raises:
        ValueError: if the last two dimensions are not (71, 73).
    """
    if arr.shape[-2:] != (71, 73):
        raise ValueError("输入数组的最后两个维度必须是 (71, 73)")

    # drop the last column of the final axis: (..., 71, 72)
    arr_trimmed = arr[..., :-1]

    # pad one zero row on the second-to-last axis only: (..., 72, 72)
    pad_spec = [(0, 0)] * (arr_trimmed.ndim - 2) + [(0, 1), (0, 0)]
    return np.pad(arr_trimmed, pad_spec, mode='constant', constant_values=0)
|
| 87 |
+
|
| 88 |
+
class ionexDataset(Dataset):
    """Sliding-window dataset over a time series of ionosphere maps.

    Each sample is a pair (x, y): x covers ``nstepsin`` consecutive steps
    and y the following ``nstepsout`` steps; windows start every ``stride``
    steps. Both halves pass through ``process_highdim_array`` to reshape
    the spatial dims from (71, 73) to (72, 72).
    """

    def __init__(self, npy_data, nstepsin=36, nstepsout=12, stride=12):
        self.data = npy_data.astype(np.float32)
        self.nstepsin = nstepsin
        self.nstepsout = nstepsout
        self.stride = stride
        # start indices chosen so every window fits entirely inside the data
        self.idx = np.arange(0, len(self.data) - nstepsout - nstepsin + 1, stride)

    def __getitem__(self, index):
        i = self.idx[index]
        end_ix = i + self.nstepsin
        # self.idx construction guarantees the window fits; raise instead of
        # the original `return None, None`, which the default DataLoader
        # collate could not have handled anyway.
        if end_ix + self.nstepsout > len(self.data):
            raise IndexError(f"window starting at {i} runs past the data")
        seq_x = self.data[i:end_ix]
        seq_y = self.data[end_ix:end_ix + self.nstepsout]
        return process_highdim_array(seq_x), process_highdim_array(seq_y)

    def __len__(self):
        return len(self.idx)

    def split_train_val(self, val_split=0.25):
        """Randomly split sample indices into train/val Subsets."""
        train_idx, val_idx = train_test_split(list(range(len(self))), test_size=val_split)
        return Subset(self, train_idx), Subset(self, val_idx)
|
| 113 |
+
|
| 114 |
+
# ---- experiment configuration --------------------------------------------
nstepsin = 36    # input steps per sample
nstepsout = 12   # predicted steps per sample
stride = 12      # offset between consecutive window starts
max_epochs = 200
# batch_size=2

# BUG FIX: the original opened 'train2015.h5', rebound `f` to 'test2015.h5'
# before closing, and only called f.close() once — leaking the first handle.
# Context managers close both files deterministically.
with h5py.File('train2015.h5', 'r') as f:
    # NOTE(review): dataset key '2020' read from train2015.h5 — looks like a
    # copy/paste from the 2020 configuration below; confirm the key.
    train_npy = np.array(f['2020']) / 10
with h5py.File('test2015.h5', 'r') as f:
    test_npy = np.array(f['2015']) / 10

# --- alternative dataset configurations kept for reference ----------------
# f = h5py.File('train2015.h5', 'r')
# train_npy=np.array(f['2020'])/10
# f1=h5py.File('c1pg2015.h5', 'r')
# f = h5py.File('test2015.h5', 'r')
# test_npy=np.array(f['2015'])/10-np.array(f1['2015'])/10

# f = h5py.File('train2020.h5', 'r')
# train_npy=np.array(f['2020'])/10
# f = h5py.File('test2020.h5', 'r')
# test_npy=np.array(f['2020'])/10

# f = h5py.File('train2020.h5', 'r')
# train_npy=np.array(f['2020'])/10
# f1=h5py.File('c1pg2020.h5', 'r')
# f = h5py.File('test2020.h5', 'r')
# test_npy=np.array(f['2020'])/10-np.array(f1['2020'])/10

print("Training data:", train_npy.shape)
print("Testing data:", test_npy.shape)
|
| 145 |
+
|
| 146 |
+
class EncoderDecoderConvLSTM(nn.Module):
    """Three-layer ConvLSTM encoder + three-layer ConvLSTM decoder.

    The encoder consumes the input sequence step by step; the decoder is
    initialised from the encoder's final states and unrolled for
    ``nstepsout`` steps, each mapped to a 1-channel frame by a 1x1 conv.
    """

    def __init__(self, nf, in_chan, out_chan, nstepsout=12):
        super().__init__()
        self.nstepsout = nstepsout
        self.encoder_1_convlstm = ConvLSTMCell(input_dim=in_chan, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        self.encoder_2_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        self.encoder_3_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        self.decoder_1_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        self.decoder_2_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        self.decoder_3_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        # NOTE(review): out_chan is accepted but unused; the readout conv is
        # hard-coded to 1 output channel — confirm intent.
        self.conv2d = nn.Conv2d(in_channels=nf, out_channels=1, kernel_size=(1, 1))

    def forward(self, x, future_seq=0, hidden_state=None):
        b, seq_len, _, h, w = x.size()

        # encoder: zero-initialised states, one step per input frame
        # (h3/c3 previously came from decoder_3's init_hidden; same dims,
        # but encoder_3 is the consistent owner of that state)
        h1, c1 = self.encoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h2, c2 = self.encoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h3, c3 = self.encoder_3_convlstm.init_hidden(batch_size=b, image_size=(h, w))

        for t in range(seq_len):
            h1, c1 = self.encoder_1_convlstm(input_tensor=x[:, t, :, :], cur_state=[h1, c1])
            h2, c2 = self.encoder_2_convlstm(input_tensor=h1, cur_state=[h2, c2])
            h3, c3 = self.encoder_3_convlstm(input_tensor=h2, cur_state=[h3, c3])

        # decoder states start from the encoder's final states
        h4, c4 = h1, c1
        h5, c5 = h2, c2
        h6, c6 = h3, c3

        outputs = []
        for t in range(self.nstepsout):
            # note that h3 is held fixed during prediction and fed to every step
            h4, c4 = self.decoder_1_convlstm(input_tensor=h3, cur_state=[h4, c4])
            h5, c5 = self.decoder_2_convlstm(input_tensor=h4, cur_state=[h5, c5])
            h6, c6 = self.decoder_3_convlstm(input_tensor=h5, cur_state=[h6, c6])
            # BUG FIX: the original read out h4, silently discarding decoder
            # layers 2 and 3 (their outputs were computed and never used);
            # the deepest decoder state h6 is the intended readout.
            outputs.append(self.conv2d(h6))
        return torch.stack(outputs, 1)
|
modules.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BasicConv2d(nn.Module):
    """2D (optionally transposed) convolution with optional GroupNorm + LeakyReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, transpose=False, act_norm=False):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        if transpose:
            # output_padding = stride // 2 recovers the exact doubled size
            # for a stride-2 upsampling layer
            self.conv = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, output_padding=stride // 2)
        else:
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding)
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.conv(x)
        return self.act(self.norm(out)) if self.act_norm else out
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ConvSC(nn.Module):
    """Single 3x3 BasicConv2d stage; transposition is disabled when stride == 1."""

    def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
        super(ConvSC, self).__init__()
        # a stride-1 layer never upsamples, so force a plain convolution
        use_transpose = transpose and stride != 1
        self.conv = BasicConv2d(C_in, C_out, kernel_size=3, stride=stride,
                                padding=1, transpose=use_transpose, act_norm=act_norm)

    def forward(self, x):
        return self.conv(x)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class GroupConv2d(nn.Module):
    """Grouped 2D convolution with optional GroupNorm + LeakyReLU.

    Falls back to groups=1 when in_channels is not divisible by groups.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        if in_channels % groups != 0:
            groups = 1  # grouped conv requires divisibility; degrade gracefully
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=groups)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.activate = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.conv(x)
        if self.act_norm:
            out = self.activate(self.norm(out))
        return out
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Inception(nn.Module):
    """Inception-style block: a 1x1 bottleneck followed by parallel grouped
    convolutions of different kernel sizes whose outputs are summed."""

    def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8):
        super(Inception, self).__init__()
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
        branches = [GroupConv2d(C_hid, C_out, kernel_size=k, stride=1,
                                padding=k // 2, groups=groups, act_norm=True)
                    for k in incep_ker]
        self.layers = nn.Sequential(*branches)

    def forward(self, x):
        x = self.conv1(x)
        # sum the parallel branch outputs (same as the original accumulator)
        return sum(branch(x) for branch in self.layers)
|
test2015.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a8f012ef4a1fc40d1c993cea1eff972ea56cbda86fd3a433431ea71d82259e09
|
| 3 |
+
size 181614368
|
test2020.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f0a1c3c6d19a81a998ab4381bf189ba0ac7b8c6378008ad7b3d1465ffa20edd1
|
| 3 |
+
size 182111936
|
train2015.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f17b0b8430647d2cc21a1bc2af719e3a2370bed8d93ac70626fcc08fcc2e546c
|
| 3 |
+
size 545336576
|
train2020.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f39605870ce1fc5930277d5d225a29b3aaaff8fc53c4a00c9c6149740d91ebf
|
| 3 |
+
size 544839008
|
train_simvp2.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import torch
import torch.nn.functional as F
from torch import nn, Tensor
import numpy as np
import h5py
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
import torch.optim as optim

from model_convlstm import ionexDataset, train_npy, nstepsin, nstepsout, stride, EncoderDecoderConvLSTM, max_epochs
from model_LARRES import larres

# Build the IONEX sliding-window dataset and split off 20% for validation.
ionexData = ionexDataset(train_npy, nstepsin=nstepsin, nstepsout=nstepsout, stride=stride)
train_data, val_data = ionexData.split_train_val(val_split=0.2)

train_loader = DataLoader(train_data, batch_size=16, num_workers=0)
val_loader = DataLoader(val_data, batch_size=16, num_workers=0)

# Sanity-check one batch so shape problems surface before training starts.
for X, y in train_loader:
    print(f"Shape of X: {X.shape} {X.dtype} [N, C, H, W]")
    print(f"Shape of Y: {y.shape} {y.dtype}")
    break
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = larres().to(device)
# model.load_state_dict(torch.load("best_model.pth"))
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.L1Loss()  # MAE on the predicted residual

best_val_loss = float('inf')

for epoch in range(max_epochs):
    # ---- training ----
    model.train()
    train_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # The model predicts the residual w.r.t. the last 12 input frames
        # (channels 24:36 of the input), so supervise on that difference.
        target_last = target - data[:, 24:36, :, :, :]
        # Only the first 71 rows are scored — presumably the valid latitude
        # band of the TEC map; TODO confirm against the dataset definition.
        loss = criterion(output[:, :12, :, :71, :], target_last[:, :12, :, :71, :])
        # Accumulate the detached scalar: summing the loss *tensor* would keep
        # every batch's autograd graph alive and leak (GPU) memory.
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    # Report the mean batch loss, matching the validation metric below.
    train_loss /= max(len(train_loader), 1)
    print(f'Epoch {epoch + 1}/{max_epochs}, Train Loss: {train_loss:.4f}')

    # ---- validation ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            target_last = target - data[:, 24:36, :, :, :]
            loss = criterion(output[:, :12, :, :71, :], target_last[:, :12, :, :71, :])
            val_loss += loss.item()

    val_loss /= max(len(val_loader), 1)
    print(f'Epoch {epoch + 1}/{max_epochs}, Val Loss: {val_loss:.4f}')

    # Keep only the checkpoint with the best validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print('Best model saved!')

print('Training completed.')
|
utilpack/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) CAIRI AI Lab. All rights reserved
|
| 2 |
+
|
| 3 |
+
from .convlstm_modules import ConvLSTMCell
|
| 4 |
+
from .e3dlstm_modules import Eidetic3DLSTMCell, tf_Conv3d
|
| 5 |
+
from .mim_modules import MIMBlock, MIMN
|
| 6 |
+
from .mau_modules import MAUCell
|
| 7 |
+
from .phydnet_modules import PhyCell, PhyD_ConvLSTM, PhyD_EncoderRNN, K2M
|
| 8 |
+
from .predrnn_modules import SpatioTemporalLSTMCell
|
| 9 |
+
from .predrnnpp_modules import CausalLSTMCell, GHU
|
| 10 |
+
from .predrnnv2_modules import SpatioTemporalLSTMCellv2
|
| 11 |
+
from .simvp_modules import (BasicConv2d, ConvSC, GroupConv2d,
|
| 12 |
+
ConvNeXtSubBlock, ConvMixerSubBlock, GASubBlock, gInception_ST,
|
| 13 |
+
HorNetSubBlock, MLPMixerSubBlock, MogaSubBlock, PoolFormerSubBlock,
|
| 14 |
+
SwinSubBlock, UniformerSubBlock, VANSubBlock, ViTSubBlock, TAUSubBlock)
|
| 15 |
+
from .mmvp_modules import (ResBlock, RRDB, ResidualDenseBlock_4C, Up, Conv3D, ConvLayer,
|
| 16 |
+
MatrixPredictor3DConv, SimpleMatrixPredictor3DConv_direct, PredictModel)
|
| 17 |
+
from .swinlstm_modules import UpSample, DownSample, STconvert
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
'ConvLSTMCell', 'CausalLSTMCell', 'GHU', 'SpatioTemporalLSTMCell', 'SpatioTemporalLSTMCellv2',
|
| 21 |
+
'MIMBlock', 'MIMN', 'Eidetic3DLSTMCell', 'tf_Conv3d',
|
| 22 |
+
'PhyCell', 'PhyD_ConvLSTM', 'PhyD_EncoderRNN', 'K2M', 'MAUCell',
|
| 23 |
+
'BasicConv2d', 'ConvSC', 'GroupConv2d',
|
| 24 |
+
'ConvNeXtSubBlock', 'ConvMixerSubBlock', 'GASubBlock', 'gInception_ST',
|
| 25 |
+
'HorNetSubBlock', 'MLPMixerSubBlock', 'MogaSubBlock', 'PoolFormerSubBlock',
|
| 26 |
+
'SwinSubBlock', 'UniformerSubBlock', 'VANSubBlock', 'ViTSubBlock', 'TAUSubBlock',
|
| 27 |
+
'ResBlock', 'RRDB', 'ResidualDenseBlock_4C', 'Up', 'Conv3D', 'ConvLayer',
|
| 28 |
+
'MatrixPredictor3DConv', 'SimpleMatrixPredictor3DConv_direct', 'PredictModel',
|
| 29 |
+
'UpSample', 'DownSample', 'STconvert'
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
]
|
utilpack/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.73 kB). View file
|
|
|
utilpack/__pycache__/convlstm_modules.cpython-312.pyc
ADDED
|
Binary file (3.37 kB). View file
|
|
|
utilpack/__pycache__/e3dlstm_modules.cpython-312.pyc
ADDED
|
Binary file (6.96 kB). View file
|
|
|
utilpack/__pycache__/mau_modules.cpython-312.pyc
ADDED
|
Binary file (4.09 kB). View file
|
|
|
utilpack/__pycache__/mim_modules.cpython-312.pyc
ADDED
|
Binary file (10.8 kB). View file
|
|
|
utilpack/__pycache__/mmvp_modules.cpython-312.pyc
ADDED
|
Binary file (25.7 kB). View file
|
|
|
utilpack/__pycache__/phydnet_modules.cpython-312.pyc
ADDED
|
Binary file (27.2 kB). View file
|
|
|
utilpack/__pycache__/predrnn_modules.cpython-312.pyc
ADDED
|
Binary file (4.65 kB). View file
|
|
|
utilpack/__pycache__/predrnnpp_modules.cpython-312.pyc
ADDED
|
Binary file (9.27 kB). View file
|
|
|
utilpack/__pycache__/predrnnv2_modules.cpython-312.pyc
ADDED
|
Binary file (4.7 kB). View file
|
|
|
utilpack/__pycache__/simvp_modules.cpython-312.pyc
ADDED
|
Binary file (37.5 kB). View file
|
|
|
utilpack/__pycache__/swinlstm_modules.cpython-312.pyc
ADDED
|
Binary file (16.6 kB). View file
|
|
|
utilpack/convlstm_modules.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ConvLSTMCell(nn.Module):
    """Single ConvLSTM cell: convolutional input/hidden projections feeding the
    standard i/f/g/o LSTM gate equations.

    Args:
        in_channel: channels of the input frame x_t.
        num_hidden: channels of the hidden/cell state.
        height, width: spatial size (needed only for the LayerNorm shape).
        filter_size: square conv kernel size.
        stride: conv stride.
        layer_norm: if True, a LayerNorm over (C, H, W) follows each conv.
    """

    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
        super(ConvLSTMCell, self).__init__()

        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        # NOTE(review): stored but never applied in forward() — the forget gate
        # gets no bias offset here; confirm whether that is intended.
        self._forget_bias = 1.0

        def _conv_block(in_ch, out_ch):
            # Conv2d optionally followed by LayerNorm; kept inside nn.Sequential
            # in both cases so state_dict keys match either configuration.
            ops = [nn.Conv2d(in_ch, out_ch, kernel_size=filter_size,
                             stride=stride, padding=self.padding, bias=False)]
            if layer_norm:
                ops.append(nn.LayerNorm([out_ch, height, width]))
            return nn.Sequential(*ops)

        self.conv_x = _conv_block(in_channel, num_hidden * 4)   # input -> 4 gates
        self.conv_h = _conv_block(num_hidden, num_hidden * 4)   # hidden -> 4 gates
        self.conv_o = _conv_block(num_hidden * 2, num_hidden)   # defined but unused in forward()
        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
                                   stride=1, padding=0, bias=False)  # unused in forward()

    def forward(self, x_t, h_t, c_t):
        """One recurrence step; returns the new hidden and cell states."""
        gates_x = self.conv_x(x_t)
        gates_h = self.conv_h(h_t)
        i_x, f_x, g_x, o_x = torch.split(gates_x, self.num_hidden, dim=1)
        i_h, f_h, g_h, o_h = torch.split(gates_h, self.num_hidden, dim=1)

        in_gate = torch.sigmoid(i_x + i_h)
        forget_gate = torch.sigmoid(f_x + f_h)
        cell_update = torch.tanh(g_x + g_h)

        c_new = forget_gate * c_t + in_gate * cell_update
        out_gate = torch.sigmoid(o_x + o_h)
        h_new = out_gate * torch.tanh(c_new)
        return h_new, c_new
|
utilpack/e3dlstm_modules.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class tf_Conv3d(nn.Module):
    """Conv3d whose output is resized back to the input's (D, H, W), mimicking
    TensorFlow-style 'same' spatial behavior regardless of kernel/stride."""

    def __init__(self, in_channels, out_channels, *vargs, **kwargs):
        super(tf_Conv3d, self).__init__()
        # All extra conv arguments (kernel_size, stride, padding, ...) pass through.
        self.conv3d = nn.Conv3d(in_channels, out_channels, *vargs, **kwargs)

    def forward(self, input):
        out = self.conv3d(input)
        # Nearest-neighbour resize restores the input's last three dims.
        return F.interpolate(out, size=input.shape[-3:], mode="nearest")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Eidetic3DLSTMCell(nn.Module):
    """Eidetic 3D LSTM cell (E3D-LSTM).

    Combines a recall-gate attention over past cell states ("eidetic" memory)
    with a global spatiotemporal memory shared across layers. All convolutions
    are 3D (window_length x height x width) via tf_Conv3d.

    Args:
        in_channel: channels of the input tensor x_t.
        num_hidden: channels of the hidden/cell/global-memory states.
        window_length: temporal depth of the 3D state tensors.
        height, width: spatial size (used for LayerNorm shapes).
        filter_size: 3-tuple (t, h, w) conv kernel size.
        stride: conv stride.
        layer_norm: if True, each conv is followed by a LayerNorm.
    """

    def __init__(self, in_channel, num_hidden, window_length,
                 height, width, filter_size, stride, layer_norm):
        super(Eidetic3DLSTMCell, self).__init__()

        # Normalizes the recall-augmented cell state before the input update.
        self._norm_c_t = nn.LayerNorm([num_hidden, window_length, height, width])
        self.num_hidden = num_hidden
        # No temporal padding; 'same' padding spatially (odd kernels assumed).
        self.padding = (0, filter_size[1] // 2, filter_size[2] // 2)
        # Added to the global-memory forget gate pre-activation (encourages remembering).
        self._forget_bias = 1.0
        if layer_norm:
            # conv_x produces 7 gate maps: i, g, r, o plus the three
            # global-memory gates (temp_i, temp_g, temp_f).
            self.conv_x = nn.Sequential(
                tf_Conv3d(in_channel, num_hidden * 7, kernel_size=filter_size,
                          stride=stride, padding=self.padding, bias=False),
                nn.LayerNorm([num_hidden * 7, window_length, height, width])
            )
            # conv_h produces the 4 hidden-state gate maps: i, g, r, o.
            self.conv_h = nn.Sequential(
                tf_Conv3d(num_hidden, num_hidden * 4, kernel_size=filter_size,
                          stride=stride, padding=self.padding, bias=False),
                nn.LayerNorm([num_hidden * 4, window_length, height, width])
            )
            # conv_gm produces the 4 global-memory maps: i_m, f_m, g_m, m_m.
            self.conv_gm = nn.Sequential(
                tf_Conv3d(num_hidden, num_hidden * 4, kernel_size=filter_size,
                          stride=stride, padding=self.padding, bias=False),
                nn.LayerNorm([num_hidden * 4, window_length, height, width])
            )
            # Output-gate contributions from the new cell / new global memory.
            self.conv_new_cell = nn.Sequential(
                tf_Conv3d(num_hidden, num_hidden, kernel_size=filter_size,
                          stride=stride, padding=self.padding, bias=False),
                nn.LayerNorm([num_hidden, window_length, height, width])
            )
            self.conv_new_gm = nn.Sequential(
                tf_Conv3d(num_hidden, num_hidden, kernel_size=filter_size,
                          stride=stride, padding=self.padding, bias=False),
                nn.LayerNorm([num_hidden, window_length, height, width])
            )
        else:
            # Same convolutions without LayerNorm.
            self.conv_x = nn.Sequential(
                tf_Conv3d(in_channel, num_hidden * 7, kernel_size=filter_size,
                          stride=stride, padding=self.padding, bias=False),
            )
            self.conv_h = nn.Sequential(
                tf_Conv3d(num_hidden, num_hidden * 4, kernel_size=filter_size,
                          stride=stride, padding=self.padding, bias=False),
            )
            self.conv_gm = nn.Sequential(
                tf_Conv3d(num_hidden, num_hidden * 4, kernel_size=filter_size,
                          stride=stride, padding=self.padding, bias=False),
            )
            self.conv_new_cell = nn.Sequential(
                tf_Conv3d(num_hidden, num_hidden, kernel_size=filter_size,
                          stride=stride, padding=self.padding, bias=False),
            )
            self.conv_new_gm = nn.Sequential(
                tf_Conv3d(num_hidden, num_hidden, kernel_size=filter_size,
                          stride=stride, padding=self.padding, bias=False),
            )
        # 1x1x1 fusion of (new_cell, new_global_memory) -> hidden channels.
        self.conv_last = tf_Conv3d(num_hidden * 2, num_hidden, kernel_size=1,
                                   stride=1, padding=0, bias=False)

    def _attn(self, in_query, in_keys, in_values):
        """Channel-wise dot-product attention of the recall gate (query) over
        the stacked historical cell states (keys == values).

        All tensors are flattened to (batch, positions, channels); softmax is
        over the key positions.
        """
        batch, num_channels, _, width, height = in_query.shape
        query = in_query.reshape(batch, -1, num_channels)
        keys = in_keys.reshape(batch, -1, num_channels)
        values = in_values.reshape(batch, -1, num_channels)
        attn = torch.einsum('bxc,byc->bxy', query, keys)
        attn = torch.softmax(attn, dim=2)
        attn = torch.einsum("bxy,byc->bxc", attn, values)
        # Reshape back; -1 recovers the temporal dim of the query.
        return attn.reshape(batch, num_channels, -1, width, height)

    def forward(self, x_t, h_t, c_t, global_memory, eidetic_cell):
        """One recurrence step.

        Args:
            x_t: input tensor.
            h_t, c_t: previous hidden and cell states.
            global_memory: shared spatiotemporal memory.
            eidetic_cell: concatenation of historical cell states attended by
                the recall gate.

        Returns:
            (output hidden state, new cell state, global memory).
        """
        h_concat = self.conv_h(h_t)
        i_h, g_h, r_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)

        x_concat = self.conv_x(x_t)
        i_x, g_x, r_x, o_x, temp_i_x, temp_g_x, temp_f_x = \
            torch.split(x_concat, self.num_hidden, dim=1)

        i_t = torch.sigmoid(i_x + i_h)
        r_t = torch.sigmoid(r_x + r_h)  # recall gate: queries the eidetic memory
        g_t = torch.tanh(g_x + g_h)

        # Recall: attend over historical cell states, add to the current cell,
        # then normalize before applying the input update.
        new_cell = c_t + self._attn(r_t, eidetic_cell, eidetic_cell)
        new_cell = self._norm_c_t(new_cell) + i_t * g_t

        # Global-memory update uses its own i/f/g gates plus a candidate m_m.
        new_global_memory = self.conv_gm(global_memory)
        i_m, f_m, g_m, m_m = torch.split(new_global_memory, self.num_hidden, dim=1)

        temp_i_t = torch.sigmoid(temp_i_x + i_m)
        temp_f_t = torch.sigmoid(temp_f_x + f_m + self._forget_bias)
        temp_g_t = torch.tanh(temp_g_x + g_m)
        new_global_memory = temp_f_t * torch.tanh(m_m) + temp_i_t * temp_g_t

        o_c = self.conv_new_cell(new_cell)
        o_m = self.conv_new_gm(new_global_memory)

        # NOTE(review): pre-activation is tanh here (sigmoid is applied at use
        # below) — unusual for an output gate; confirm against the reference.
        output_gate = torch.tanh(o_x + o_h + o_c + o_m)

        memory = torch.cat((new_cell, new_global_memory), 1)
        memory = self.conv_last(memory)

        output = torch.tanh(memory) * torch.sigmoid(output_gate)

        # NOTE(review): returns the *incoming* global_memory, not
        # new_global_memory — the updated global memory is discarded by the
        # caller; verify this is intended.
        return output, new_cell, global_memory
|
utilpack/layers/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .hornet import HorBlock
|
| 2 |
+
from .moganet import ChannelAggregationFFN, MultiOrderGatedAggregation, MultiOrderDWConv
|
| 3 |
+
from .poolformer import PoolFormerBlock
|
| 4 |
+
from .uniformer import CBlock, SABlock
|
| 5 |
+
from .van import DWConv, MixMlp, VANBlock
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'HorBlock', 'ChannelAggregationFFN', 'MultiOrderGatedAggregation', 'MultiOrderDWConv',
|
| 9 |
+
'PoolFormerBlock', 'CBlock', 'SABlock', 'DWConv', 'MixMlp', 'VANBlock',
|
| 10 |
+
]
|
utilpack/layers/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (582 Bytes). View file
|
|
|
utilpack/layers/__pycache__/hornet.cpython-312.pyc
ADDED
|
Binary file (7.78 kB). View file
|
|
|
utilpack/layers/__pycache__/moganet.cpython-312.pyc
ADDED
|
Binary file (8.2 kB). View file
|
|
|
utilpack/layers/__pycache__/poolformer.cpython-312.pyc
ADDED
|
Binary file (6 kB). View file
|
|
|
utilpack/layers/__pycache__/uniformer.cpython-312.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
utilpack/layers/__pycache__/van.cpython-312.pyc
ADDED
|
Binary file (8.25 kB). View file
|
|
|
utilpack/layers/hornet.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# refer to the code from HorNet, Thanks!
|
| 2 |
+
# https://github.com/raoyongming/HorNet
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from timm.layers import DropPath
|
| 8 |
+
import torch.fft
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_dwconv(dim, kernel, bias):
    """Return a depthwise (groups=dim) 2D conv; padding keeps spatial size for odd kernels."""
    return nn.Conv2d(
        dim, dim,
        kernel_size=kernel,
        padding=(kernel - 1) // 2,
        bias=bias,
        groups=dim,
    )
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class gnconv(nn.Module):
    """Recursive gated convolution (g^n conv) from HorNet.

    The input is projected to 2*dim channels, split into a base part and a
    stack of progressively wider parts (channel widths double each order),
    and the orders are multiplied together through pointwise convs.
    Requires dim to be divisible by 2**(order-1).
    """

    def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):
        super().__init__()
        self.order = order
        # Widths [dim/2^(order-1), ..., dim/2, dim]; their sum plus dims[0]
        # equals 2*dim, matching proj_in's output.
        self.dims = [dim // 2 ** i for i in range(order)]
        self.dims.reverse()
        self.proj_in = nn.Conv2d(dim, 2 * dim, 1)

        total = sum(self.dims)
        if gflayer is None:
            self.dwconv = get_dwconv(total, 7, True)
        else:
            # Optional global-filter layer replacing the depthwise conv.
            self.dwconv = gflayer(total, h=h, w=w)

        self.proj_out = nn.Conv2d(dim, dim, 1)

        # Pointwise convs lifting each order's width to the next.
        self.pws = nn.ModuleList(
            [nn.Conv2d(self.dims[i], self.dims[i + 1], 1) for i in range(order - 1)]
        )

        self.scale = s
        print('[gnconv]', order, 'order with dims=', self.dims, 'scale=%.4f'%self.scale)

    def forward(self, x, mask=None, dummy=False):
        # Split projection into the first-order operand and the gate stack.
        pwa, abc = torch.split(self.proj_in(x), (self.dims[0], sum(self.dims)), dim=1)

        gates = torch.split(self.dwconv(abc) * self.scale, self.dims, dim=1)

        # Recursive gating: multiply, widen, multiply, ...
        out = pwa * gates[0]
        for idx, pw in enumerate(self.pws):
            out = pw(out) * gates[idx + 1]

        return self.proj_out(out)
|
| 52 |
+
|
| 53 |
+
class LayerNorm(nn.Module):
    r"""LayerNorm over the channel dimension for either data layout.

    channels_last expects (batch, height, width, channels) and uses
    F.layer_norm; channels_first expects (batch, channels, height, width)
    and normalizes manually over dim 1.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: standardize across channels, then affine per channel.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class HorBlock(nn.Module):
    """HorNet block: gnconv token mixing followed by a pointwise MLP,
    each branch residual with LayerScale (gamma1/gamma2) and DropPath."""

    def __init__(self, dim, order=4, mlp_ratio=4, drop_path=0., init_value=1e-6, gnconv=gnconv):
        super().__init__()

        hidden = int(mlp_ratio * dim)
        self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first')
        self.gnconv = gnconv(dim, order)  # depthwise conv
        self.norm2 = LayerNorm(dim, eps=1e-6)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden, dim)
        self.gamma1 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
        self.gamma2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        # Token-mixing branch with per-channel LayerScale.
        x = x + self.drop_path(self.gamma1.view(C, 1, 1) * self.gnconv(self.norm1(x)))

        residual = x
        y = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        y = self.pwconv2(self.act(self.pwconv1(self.norm2(y))))
        if self.gamma2 is not None:
            y = self.gamma2 * y
        y = y.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        return residual + self.drop_path(y)
|
utilpack/layers/moganet.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# refer to the code from MogaNet, Thanks!
|
| 2 |
+
# https://github.com/Westlake-AI/MogaNet/blob/main/models/moganet.py
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ChannelAggregationFFN(nn.Module):
    """MogaNet FFN with channel aggregation: 1x1 expand -> depthwise conv ->
    activation -> channel-decompose re-weighting -> 1x1 project."""

    def __init__(self, embed_dims, mlp_hidden_dims, kernel_size=3, act_layer=nn.GELU, ffn_drop=0.):
        super(ChannelAggregationFFN, self).__init__()
        self.embed_dims = embed_dims
        self.mlp_hidden_dims = mlp_hidden_dims

        self.fc1 = nn.Conv2d(
            in_channels=embed_dims, out_channels=self.mlp_hidden_dims, kernel_size=1)
        self.dwconv = nn.Conv2d(
            in_channels=self.mlp_hidden_dims, out_channels=self.mlp_hidden_dims,
            kernel_size=kernel_size, padding=kernel_size // 2,
            bias=True, groups=self.mlp_hidden_dims)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(
            in_channels=mlp_hidden_dims, out_channels=embed_dims, kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)

        # Aggregation branch: collapse hidden channels to one map, then add a
        # learnable-scale residual against it.
        self.decompose = nn.Conv2d(
            in_channels=self.mlp_hidden_dims, out_channels=1, kernel_size=1)
        self.sigma = nn.Parameter(
            1e-5 * torch.ones((1, mlp_hidden_dims, 1, 1)), requires_grad=True)
        self.decompose_act = act_layer()

    def feat_decompose(self, x):
        # x + sigma * (x - act(decompose(x))): re-weights each channel against
        # the aggregated single-channel map.
        return x + self.sigma * (x - self.decompose_act(self.decompose(x)))

    def forward(self, x):
        # Expansion + depthwise mixing.
        x = self.drop(self.act(self.dwconv(self.fc1(x))))
        # Channel aggregation, then projection back to embed_dims.
        x = self.feat_decompose(x)
        return self.drop(self.fc2(x))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class MultiOrderDWConv(nn.Module):
    """MogaNet multi-order depthwise conv: a base DW conv over all channels,
    two further dilated DW branches over channel slices, fused by a 1x1 conv.

    embed_dims must be divisible by sum(channel_split).
    """

    def __init__(self, embed_dims, dw_dilation=[1, 2, 3], channel_split=[1, 3, 4]):
        super(MultiOrderDWConv, self).__init__()
        assert len(dw_dilation) == len(channel_split) == 3
        assert 1 <= min(dw_dilation) and max(dw_dilation) <= 3
        assert embed_dims % sum(channel_split) == 0

        self.split_ratio = [i / sum(channel_split) for i in channel_split]
        self.embed_dims_1 = int(self.split_ratio[1] * embed_dims)
        self.embed_dims_2 = int(self.split_ratio[2] * embed_dims)
        # Remainder goes to the untouched base slice.
        self.embed_dims_0 = embed_dims - self.embed_dims_1 - self.embed_dims_2
        self.embed_dims = embed_dims

        def _dw(channels, kernel, dilation):
            # 'same'-size depthwise conv for the given kernel/dilation.
            return nn.Conv2d(
                in_channels=channels, out_channels=channels, kernel_size=kernel,
                padding=(1 + (kernel - 1) * dilation) // 2,
                groups=channels, stride=1, dilation=dilation,
            )

        self.DW_conv0 = _dw(embed_dims, 5, dw_dilation[0])         # basic DW conv
        self.DW_conv1 = _dw(self.embed_dims_1, 5, dw_dilation[1])  # DW conv 1
        self.DW_conv2 = _dw(self.embed_dims_2, 7, dw_dilation[2])  # DW conv 2
        # Channel-mixing 1x1 conv.
        self.PW_conv = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

    def forward(self, x):
        base = self.DW_conv0(x)
        mid = self.DW_conv1(
            base[:, self.embed_dims_0: self.embed_dims_0 + self.embed_dims_1, ...])
        high = self.DW_conv2(
            base[:, self.embed_dims - self.embed_dims_2:, ...])
        fused = torch.cat([base[:, :self.embed_dims_0, ...], mid, high], dim=1)
        return self.PW_conv(fused)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class MultiOrderGatedAggregation(nn.Module):
    """MogaNet spatial block: feature decomposition, then SiLU-gated fusion of
    a 1x1 gate branch with a multi-order DW-conv value branch."""

    def __init__(self, embed_dims, attn_dw_dilation=[1, 2, 3], attn_channel_split=[1, 3, 4], attn_shortcut=True):
        super(MultiOrderGatedAggregation, self).__init__()
        self.embed_dims = embed_dims
        self.attn_shortcut = attn_shortcut
        self.proj_1 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.gate = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.value = MultiOrderDWConv(
            embed_dims=embed_dims, dw_dilation=attn_dw_dilation, channel_split=attn_channel_split)
        self.proj_2 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

        self.act_value = nn.SiLU()  # activation after decomposition
        self.act_gate = nn.SiLU()   # applied to both gate and value at fusion
        # Learnable scale for the global-mean decomposition residual.
        self.sigma = nn.Parameter(1e-5 * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)

    def feat_decompose(self, x):
        x = self.proj_1(x)
        # Global context: [B, C, H, W] -> [B, C, 1, 1].
        pooled = F.adaptive_avg_pool2d(x, output_size=1)
        return self.act_value(x + self.sigma * (x - pooled))

    def forward(self, x):
        shortcut = x.clone() if self.attn_shortcut else None
        # 1x1 projection + decomposition.
        x = self.feat_decompose(x)
        # Gated aggregation of the two branches.
        x = self.proj_2(self.act_gate(self.gate(x)) * self.act_gate(self.value(x)))
        if self.attn_shortcut:
            x = x + shortcut
        return x
utilpack/layers/poolformer.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# refer to the code from PoolFormer, Thanks!
|
| 2 |
+
# https://github.com/sail-sg/poolformer/blob/main/models/poolformer.py
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from timm.layers import DropPath, trunc_normal_
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class GroupNorm(nn.GroupNorm):
    """
    nn.GroupNorm pinned to a single group — i.e. normalization over the whole
    (C, H, W) volume of a [B, C, H, W] tensor.
    """

    def __init__(self, num_channels, **kwargs):
        # Group count is fixed at 1; everything else passes through.
        super().__init__(1, num_channels, **kwargs)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Pooling(nn.Module):
    """
    PoolFormer token mixer: average pooling minus identity.
    Yields zero on spatially-constant inputs, so it acts as a pure
    neighborhood-difference operator.
    """

    def __init__(self, pool_size=3):
        super().__init__()
        # count_include_pad=False keeps border averages unbiased.
        self.pool = nn.AvgPool2d(
            pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, x):
        return self.pool(x) - x
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Mlp(nn.Module):
    """
    MLP implemented with 1x1 convolutions.
    Input: tensor with shape [B, C, H, W].
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal conv weights, zero biases.
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class PoolFormerBlock(nn.Module):
    """One PoolFormer block: pooling token mixer + channel MLP.

    Both sub-blocks are pre-normalized, scaled by learnable per-channel
    LayerScale factors (https://arxiv.org/abs/2103.17239) and wrapped in
    stochastic-depth residual branches (https://arxiv.org/abs/1603.09382).

    Args:
        dim: embedding dimension.
        pool_size: pooling window of the token mixer.
        mlp_ratio: hidden/input ratio of the MLP.
        drop: dropout rate inside the MLP.
        drop_path: stochastic-depth rate.
        init_value: initial LayerScale value.
        act_layer: activation module class.
        norm_layer: normalization module class.
    """

    def __init__(self, dim, pool_size=3, mlp_ratio=4., drop=0., drop_path=0.,
                 init_value=1e-5, act_layer=nn.GELU, norm_layer=GroupNorm):
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.token_mixer = Pooling(pool_size=pool_size)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)
        # Stochastic depth and LayerScale help when stacking many blocks.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        # Broadcast the per-channel scales over the spatial dimensions.
        scale1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
        x = x + self.drop_path(scale1 * self.token_mixer(self.norm1(x)))
        scale2 = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
        x = x + self.drop_path(scale2 * self.mlp(self.norm2(x)))
        return x
|
utilpack/layers/uniformer.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# refer to the code from Uniformer, Thanks!
|
| 2 |
+
# https://github.com/Sense-X/UniFormer/blob/main/image_classification/models/uniformer.py
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from timm.layers import DropPath, trunc_normal_
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Mlp(nn.Module):
    """Token-wise feed-forward network (two linear layers + activation).

    Args:
        in_features: input feature size.
        hidden_features: hidden size (defaults to in_features).
        out_features: output size (defaults to in_features).
        act_layer: activation module class.
        drop: dropout applied after each projection.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CMlp(nn.Module):
    """Convolutional MLP: the Linear layers of `Mlp` replaced by 1x1 convs.

    Works on [B, C, H, W] tensors and mixes channels only.

    Args:
        in_features: input channel count.
        hidden_features: hidden channel count (defaults to in_features).
        out_features: output channel count (defaults to in_features).
        act_layer: activation module class.
        drop: dropout applied after each projection.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class Attention(nn.Module):
    """Standard multi-head self-attention over a token sequence.

    Args:
        dim: embedding size.
        num_heads: number of attention heads (dim must be divisible).
        qkv_bias: add bias to the fused QKV projection.
        qk_scale: override for the 1/sqrt(head_dim) attention scale.
        attn_drop: dropout on the attention matrix.
        proj_drop: dropout on the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # qk_scale allows loading weights trained with a different scale.
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # Fused projection, then split into per-head Q, K, V.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class CBlock(nn.Module):
    """UniFormer convolutional block (local token mixing).

    Uses a depth-wise conv positional encoding, a 1x1 -> 5x5 depth-wise
    -> 1x1 "attention" stack, and a convolutional MLP; each branch is
    residual with optional stochastic depth.

    Note: num_heads/qkv_bias/qk_scale/attn_drop/norm_layer are accepted
    for interface parity with SABlock but are unused here.
    """

    def __init__(self, dim, num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = nn.BatchNorm2d(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        # Depth-wise 5x5 conv plays the role of local "attention".
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Kaiming-style fan-out init, corrected for grouped convs.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}

    def forward(self, x):
        x = x + self.pos_embed(x)
        mixed = self.conv2(self.attn(self.conv1(self.norm1(x))))
        x = x + self.drop_path(mixed)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class SABlock(nn.Module):
    """UniFormer self-attention block (global token mixing).

    Adds a depth-wise conv positional encoding on the [B, C, H, W]
    feature map, then flattens to tokens for multi-head attention and an
    MLP, both scaled by learnable LayerScale factors (gamma_1/gamma_2)
    and wrapped in residual stochastic-depth branches.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_value=1e-6, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        # LayerScale: tiny initial weights stabilize deep stacks.
        self.gamma_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # LayerScale parameters are excluded from weight decay.
        return {'gamma_1', 'gamma_2'}

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        # [B, C, H, W] -> [B, H*W, C] token sequence.
        tokens = x.flatten(2).transpose(1, 2)
        tokens = tokens + self.drop_path(self.gamma_1 * self.attn(self.norm1(tokens)))
        tokens = tokens + self.drop_path(self.gamma_2 * self.mlp(self.norm2(tokens)))
        return tokens.transpose(1, 2).reshape(B, C, H, W)
|
utilpack/layers/van.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# refer to the code from VAN, Thanks!
|
| 2 |
+
# https://github.com/Visual-Attention-Network/VAN-Classification
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from timm.layers import DropPath, trunc_normal_
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DWConv(nn.Module):
    """3x3 depth-wise convolution: one independent filter per channel."""

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        # groups=dim makes the convolution depth-wise.
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        return self.dwconv(x)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MixMlp(nn.Module):
    """Convolutional feed-forward network with a depth-wise 3x3 in between.

    1x1 conv expand -> depth-wise 3x3 (spatial mixing) -> activation ->
    1x1 conv reduce, with dropout after each projection.

    Args:
        in_features: input channel count.
        hidden_features: hidden channel count (defaults to in_features).
        out_features: output channel count (defaults to in_features).
        act_layer: activation module class.
        drop: dropout probability.
    """

    def __init__(self,
                 in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Kaiming fan-out, corrected for grouped (depth-wise) convs.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        hidden = self.act(self.dwconv(self.fc1(x)))
        hidden = self.drop(hidden)
        return self.drop(self.fc2(hidden))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class LKA(nn.Module):
    """Large-Kernel Attention from VAN.

    Decomposes a large receptive field into a 5x5 depth-wise conv, a
    7x7 dilated (rate 3) depth-wise conv, and a 1x1 conv, then gates the
    input with the resulting attention map.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        identity = x.clone()
        weight = self.conv1(self.conv_spatial(self.conv0(x)))
        # Element-wise gating rather than softmax attention.
        return identity * weight
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class Attention(nn.Module):
    """VAN attention module: 1x1 -> GELU -> LKA gating -> 1x1.

    Args:
        d_model: channel dimension.
        attn_shortcut: if True, add the input back as a residual.
    """

    def __init__(self, d_model, attn_shortcut=True):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LKA(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        residual = x.clone() if self.attn_shortcut else None
        x = self.spatial_gating_unit(self.activation(self.proj_1(x)))
        x = self.proj_2(x)
        if self.attn_shortcut:
            x = x + residual
        return x
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class VANBlock(nn.Module):
    """One VAN block: LKA attention + convolutional MLP.

    Each branch is BatchNorm pre-normalized, scaled by a learnable
    per-channel LayerScale factor, and added back through a
    stochastic-depth residual connection.
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., init_value=1e-2, act_layer=nn.GELU, attn_shortcut=True):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = Attention(dim, attn_shortcut=attn_shortcut)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MixMlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.layer_scale_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        # Broadcast LayerScale over spatial dims before the residual add.
        scale1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
        x = x + self.drop_path(scale1 * self.attn(self.norm1(x)))
        scale2 = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
        x = x + self.drop_path(scale2 * self.mlp(self.norm2(x)))
        return x
|
utilpack/mau_modules.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MAUCell(nn.Module):
    """Motion-Aware Unit cell (MAU, NeurIPS 2021).

    Fuses the current temporal state with an attention-weighted history
    of temporal states, then exchanges information between the temporal
    (T) and spatial (S) streams through learned gates.

    Args:
        in_channel: channels of the temporal input T_t (and of each
            entry of ``t_att``).
        num_hidden: hidden channels of the spatial stream.
        height, width: spatial size (needed for the LayerNorm shapes).
        filter_size: square convolution kernel size.
        stride: convolution stride.
        tau: length of the attention window over past states.
        cell_mode: 'residual' adds S_t back to the new spatial state,
            'normal' does not.

    Raises:
        AssertionError: if ``cell_mode`` is not one of the supported modes.
    """

    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, tau, cell_mode):
        super(MAUCell, self).__init__()

        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        self.cell_mode = cell_mode
        # Normalizer for the attention logits: total state dimensionality.
        self.d = num_hidden * height * width
        self.tau = tau
        self.states = ['residual', 'normal']
        # FIX: idiomatic membership test and an informative message
        # (previously a bare `raise AssertionError` with no context).
        if self.cell_mode not in self.states:
            raise AssertionError(
                f"cell_mode must be one of {self.states}, got {cell_mode!r}")
        self.conv_t = nn.Sequential(
            nn.Conv2d(in_channel, 3 * num_hidden, kernel_size=filter_size,
                      stride=stride, padding=self.padding),
            nn.LayerNorm([3 * num_hidden, height, width])
        )
        self.conv_t_next = nn.Sequential(
            nn.Conv2d(in_channel, num_hidden, kernel_size=filter_size,
                      stride=stride, padding=self.padding),
            nn.LayerNorm([num_hidden, height, width])
        )
        self.conv_s = nn.Sequential(
            nn.Conv2d(num_hidden, 3 * num_hidden, kernel_size=filter_size,
                      stride=stride, padding=self.padding),
            nn.LayerNorm([3 * num_hidden, height, width])
        )
        self.conv_s_next = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden, kernel_size=filter_size,
                      stride=stride, padding=self.padding),
            nn.LayerNorm([num_hidden, height, width])
        )
        # Softmax over the tau (history) dimension.
        self.softmax = nn.Softmax(dim=0)

    def forward(self, T_t, S_t, t_att, s_att):
        """One recurrence step.

        Args:
            T_t: current temporal state, [B, in_channel, H, W].
            S_t: current spatial state, [B, num_hidden, H, W].
            t_att: stacked past temporal states, [tau, B, C, H, W].
            s_att: stacked past spatial states, [tau, B, num_hidden, H, W].

        Returns:
            (T_new, S_new): updated temporal and spatial states.
        """
        s_next = self.conv_s_next(S_t)
        t_next = self.conv_t_next(T_t)
        # Scaled dot-product attention score for each of the tau past states.
        weights_list = []
        for i in range(self.tau):
            weights_list.append((s_att[i] * s_next).sum(dim=(1, 2, 3)) / math.sqrt(self.d))
        weights_list = torch.stack(weights_list, dim=0)
        weights_list = torch.reshape(weights_list, (*weights_list.shape, 1, 1, 1))
        weights_list = self.softmax(weights_list)
        # Attention-weighted trend over the temporal history.
        T_trend = t_att * weights_list
        T_trend = T_trend.sum(dim=0)
        # Gate between the current state and its historical trend.
        t_att_gate = torch.sigmoid(t_next)
        T_fusion = T_t * t_att_gate + (1 - t_att_gate) * T_trend
        T_concat = self.conv_t(T_fusion)
        S_concat = self.conv_s(S_t)
        t_g, t_t, t_s = torch.split(T_concat, self.num_hidden, dim=1)
        s_g, s_t, s_s = torch.split(S_concat, self.num_hidden, dim=1)
        # Cross-stream gated mixing of temporal and spatial candidates.
        T_gate = torch.sigmoid(t_g)
        S_gate = torch.sigmoid(s_g)
        T_new = T_gate * t_t + (1 - T_gate) * s_t
        S_new = S_gate * s_s + (1 - S_gate) * t_s

        if self.cell_mode == 'residual':
            S_new = S_new + S_t
        return T_new, S_new
|
utilpack/mim_modules.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class MIMBlock(nn.Module):
    """Memory-In-Memory block (MIM, Wang et al., CVPR 2019).

    A spatiotemporal LSTM whose forget-gate path is replaced by an inner
    MIMS recurrence driven by the difference of consecutive hidden
    states, which models non-stationary dynamics.

    Args:
        in_channel: channels of the incoming hidden state ``h``.
        num_hidden: hidden channel count of this block.
        height, width: spatial size (for LayerNorm shapes and the
            per-position gate weights).
        filter_size: square convolution kernel size.
        stride: convolution stride.
        layer_norm: append LayerNorm to each gate convolution if True.
    """

    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
        super(MIMBlock, self).__init__()

        # Cell state of the inner MIMS recurrence; carried across calls.
        self.convlstm_c = None
        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        self._forget_bias = 1.0

        # Per-position weights for the MIMS input/forget gates (two
        # stacked copies) and its output gate.
        self.ct_weight = nn.Parameter(torch.zeros(num_hidden * 2, height, width))
        self.oc_weight = nn.Parameter(torch.zeros(num_hidden, height, width))

        def gate_conv(cin, cout):
            # One gate convolution, optionally followed by LayerNorm.
            layers = [nn.Conv2d(cin, cout, kernel_size=filter_size,
                                stride=stride, padding=self.padding, bias=False)]
            if layer_norm:
                layers.append(nn.LayerNorm([cout, height, width]))
            return nn.Sequential(*layers)

        # Creation order matches the original implementation so that
        # state_dict keys and seeded initialization are unchanged.
        self.conv_t_cc = gate_conv(in_channel, num_hidden * 3)
        self.conv_s_cc = gate_conv(num_hidden, num_hidden * 4)
        self.conv_x_cc = gate_conv(num_hidden, num_hidden * 4)
        self.conv_h_concat = gate_conv(num_hidden, num_hidden * 4)
        self.conv_x_concat = gate_conv(num_hidden, num_hidden * 4)
        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
                                   stride=1, padding=0, bias=False)

    def _init_state(self, inputs):
        """Zero-initialized state shaped like ``inputs``."""
        return torch.zeros_like(inputs)

    def MIMS(self, x, h_t, c_t):
        """One step of the inner non-stationarity LSTM.

        Args:
            x: difference input (may be None; then only h/c contribute).
            h_t: previous hidden state (may be None -> zeros like ``x``).
            c_t: previous cell state (may be None -> zeros like ``x``).

        Returns:
            (h_new, c_new): updated hidden and cell states.
        """
        if h_t is None:
            h_t = self._init_state(x)
        if c_t is None:
            c_t = self._init_state(x)

        h_concat = self.conv_h_concat(h_t)
        i_h, g_h, f_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)

        # Hadamard (peephole) contribution of c_t to input/forget gates.
        ct_activation = torch.mul(c_t.repeat(1, 2, 1, 1), self.ct_weight)
        i_c, f_c = torch.split(ct_activation, self.num_hidden, dim=1)

        i_ = i_h + i_c
        f_ = f_h + f_c
        g_ = g_h
        o_ = o_h

        # FIX: explicit identity check; `x != None` relied on
        # Tensor.__ne__ falling back to identity comparison.
        if x is not None:
            x_concat = self.conv_x_concat(x)
            i_x, g_x, f_x, o_x = torch.split(x_concat, self.num_hidden, dim=1)

            i_ = i_ + i_x
            f_ = f_ + f_x
            g_ = g_ + g_x
            o_ = o_ + o_x

        i_ = torch.sigmoid(i_)
        # Positive forget bias eases gradient flow early in training.
        f_ = torch.sigmoid(f_ + self._forget_bias)
        c_new = f_ * c_t + i_ * torch.tanh(g_)

        o_c = torch.mul(c_new, self.oc_weight)
        h_new = torch.sigmoid(o_ + o_c) * torch.tanh(c_new)

        return h_new, c_new

    def forward(self, x, diff_h, h, c, m):
        """One MIM step.

        Args:
            x: input features, [B, num_hidden, H, W].
            diff_h: difference of consecutive lower-layer hidden states
                (may be None -> zeros).
            h: previous hidden state (None -> zeros like ``x``).
            c: previous cell state (None -> zeros like ``x``).
            m: previous spatiotemporal memory (None -> zeros like ``x``).

        Returns:
            (new_h, new_c, new_m)
        """
        h = self._init_state(x) if h is None else h
        c = self._init_state(x) if c is None else c
        m = self._init_state(x) if m is None else m
        diff_h = self._init_state(x) if diff_h is None else diff_h

        t_cc = self.conv_t_cc(h)
        s_cc = self.conv_s_cc(m)
        x_cc = self.conv_x_cc(x)

        i_s, g_s, f_s, o_s = torch.split(s_cc, self.num_hidden, dim=1)
        i_t, g_t, o_t = torch.split(t_cc, self.num_hidden, dim=1)
        i_x, g_x, f_x, o_x = torch.split(x_cc, self.num_hidden, dim=1)

        i = torch.sigmoid(i_x + i_t)
        i_ = torch.sigmoid(i_x + i_s)
        g = torch.tanh(g_x + g_t)
        g_ = torch.tanh(g_x + g_s)
        f_ = torch.sigmoid(f_x + f_s + self._forget_bias)
        o = torch.sigmoid(o_x + o_t + o_s)
        new_m = f_ * m + i_ * g_

        # Inner MIMS recurrence replaces the temporal forget path; its
        # own cell state is detached to truncate gradients through time.
        c, self.convlstm_c = self.MIMS(diff_h, c, self.convlstm_c \
            if self.convlstm_c is None else self.convlstm_c.detach())

        new_c = c + i * g
        cell = torch.cat((new_c, new_m), 1)
        new_h = o * torch.tanh(self.conv_last(cell))

        return new_h, new_c, new_m
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class MIMN(nn.Module):
    """Non-stationarity module of MIM (stand-alone MIMS cell).

    An LSTM-style cell with per-position peephole weights on the
    input/forget gates and a per-position output-gate modulation.

    Args:
        in_channel: channels of both ``x`` and ``h_t``.
        num_hidden: hidden channel count.
        height, width: spatial size (for LayerNorm shapes and weights).
        filter_size: square convolution kernel size.
        stride: convolution stride.
        layer_norm: append LayerNorm to each gate convolution if True.
    """

    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
        super(MIMN, self).__init__()

        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        self._forget_bias = 1.0

        # Per-position peephole weights (input/forget) and output gate.
        self.ct_weight = nn.Parameter(torch.zeros(num_hidden * 2, height, width))
        self.oc_weight = nn.Parameter(torch.zeros(num_hidden, height, width))

        def gate_conv(cin, cout):
            # One gate convolution, optionally followed by LayerNorm.
            layers = [nn.Conv2d(cin, cout, kernel_size=filter_size,
                                stride=stride, padding=self.padding, bias=False)]
            if layer_norm:
                layers.append(nn.LayerNorm([cout, height, width]))
            return nn.Sequential(*layers)

        # Creation order matches the original implementation so that
        # state_dict keys and seeded initialization are unchanged.
        self.conv_h_concat = gate_conv(in_channel, num_hidden * 4)
        self.conv_x_concat = gate_conv(in_channel, num_hidden * 4)
        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
                                   stride=1, padding=0, bias=False)

    def _init_state(self, inputs):
        """Zero-initialized state shaped like ``inputs``."""
        return torch.zeros_like(inputs)

    def forward(self, x, h_t, c_t):
        """One cell step.

        Args:
            x: input (may be None; then only h/c contribute).
            h_t: previous hidden state (None -> zeros like ``x``).
            c_t: previous cell state (None -> zeros like ``x``).

        Returns:
            (h_new, c_new)
        """
        if h_t is None:
            h_t = self._init_state(x)
        if c_t is None:
            c_t = self._init_state(x)

        h_concat = self.conv_h_concat(h_t)
        i_h, g_h, f_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)

        # Hadamard (peephole) contribution of c_t to input/forget gates.
        ct_activation = torch.mul(c_t.repeat(1, 2, 1, 1), self.ct_weight)
        i_c, f_c = torch.split(ct_activation, self.num_hidden, dim=1)

        i_ = i_h + i_c
        f_ = f_h + f_c
        g_ = g_h
        o_ = o_h

        # FIX: explicit identity check; `x != None` relied on
        # Tensor.__ne__ falling back to identity comparison.
        if x is not None:
            x_concat = self.conv_x_concat(x)
            i_x, g_x, f_x, o_x = torch.split(x_concat, self.num_hidden, dim=1)

            i_ = i_ + i_x
            f_ = f_ + f_x
            g_ = g_ + g_x
            o_ = o_ + o_x

        i_ = torch.sigmoid(i_)
        f_ = torch.sigmoid(f_ + self._forget_bias)
        c_new = f_ * c_t + i_ * torch.tanh(g_)

        o_c = torch.mul(c_new, self.oc_weight)

        h_new = torch.sigmoid(o_ + o_c) * torch.tanh(c_new)

        return h_new, c_new
|
utilpack/mmvp_modules.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class ResidualDenseBlock_4C(nn.Module):
    """Residual dense block with four 3x3 convolutions.

    Each conv sees the concatenation of the block input and all previous
    intermediate features; the final output is residually scaled by 0.2
    and added back to the input.

    Args:
        nf: number of input/output feature channels.
        gc: growth channels, i.e. intermediate channel count.
        bias: add bias terms to the convolutions.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_4C, self).__init__()

        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Densely accumulate features: each conv consumes all earlier ones.
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        out = self.conv4(torch.cat(feats, 1))
        # Residual scaling stabilizes training of deep RRDB stacks.
        return out * 0.2 + x
|
| 25 |
+
|
| 26 |
+
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Three ResidualDenseBlock_4C blocks in sequence, with the final result
    scaled by 0.2 and added back to the block input. The inner growth
    channel count is half of ``nf``.
    """

    def __init__(self, nf):
        super(RRDB, self).__init__()
        growth = nf // 2
        self.RDB1 = ResidualDenseBlock_4C(nf, growth)
        self.RDB2 = ResidualDenseBlock_4C(nf, growth)
        self.RDB3 = ResidualDenseBlock_4C(nf, growth)

    def forward(self, x):
        h = self.RDB3(self.RDB2(self.RDB1(x)))
        # Outer residual scaling, mirroring the inner dense blocks.
        return h * 0.2 + x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Up(nn.Module):
    """Upscaling then double conv."""

    def __init__(self, in_channels, out_channels, bilinear=True, skip=True, scale=2, bn=True, motion=False):
        super().__init__()
        factor = scale
        if bilinear:
            # Bilinear path: parameter-free upsampling, then a conv to
            # reduce the number of channels.
            self.up = nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=True)
            if skip:
                self.conv = ConvLayer(in_channels, out_channels, bn=bn)
            else:
                self.conv = ConvLayer(in_channels, out_channels)
        else:
            # Learned upsampling via transposed convolution.
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=factor, stride=factor)
            if skip:
                # Skip connection doubles the channels fed to the conv.
                self.conv = ConvLayer(out_channels * 2, out_channels, bn=bn, motion=motion)
            else:
                self.conv = ConvLayer(out_channels, out_channels, bn=bn, motion=motion)

    def forward(self, x1, x2=None):
        x1 = self.up(x1)
        if x2 is None:
            return self.conv(x1)
        # input is CHW: pad x1 so it matches x2's spatial size before concat.
        dy = x2.size()[2] - x1.size()[2]
        dx = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [dx // 2, dx - dx // 2,
                        dy // 2, dy - dy // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        return self.conv(torch.cat([x2, x1], dim=1))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class ResBlock(nn.Module):
    """Two-conv residual block with optional down/up-sampling shortcut.

    - ``downsample``: stride-2 conv path with a 1x1 stride-2 conv shortcut
      (plus an extra MaxPool when ``factor == 4``).
    - ``upsample``: transposed-conv path with an Upsample+1x1-conv shortcut
      (BatchNorm appended to the shortcut when ``motion`` is set).
    - otherwise: plain stride-1 convs with an identity shortcut.
    """

    def __init__(self, in_channels, out_channels, downsample=False,
                 upsample=False, skip=False, factor=2, motion=False):
        super().__init__()
        self.upsample = upsample
        self.maxpool = None
        if downsample:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2)
            if factor == 4:
                # Extra pooling so the total spatial reduction is 4x.
                self.maxpool = nn.MaxPool2d(2)
        elif upsample:
            self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=factor, stride=factor)
            shortcut_layers = [
                nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=True),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
            ]
            if motion:
                # Motion variant also normalizes the shortcut branch.
                shortcut_layers.append(nn.BatchNorm2d(out_channels))
            self.shortcut = nn.Sequential(*shortcut_layers)
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.shortcut = nn.Sequential()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, input):
        identity = self.shortcut(input)
        out = F.relu(self.conv1(input))
        out = F.relu(self.conv2(out))
        out = out + identity
        if self.maxpool is not None:
            out = self.maxpool(out)
        return F.leaky_relu(out)
|
| 126 |
+
|
| 127 |
+
class ConvLayer(nn.Module):
    """Single 3x3 convolution followed by ReLU (BatchNorm only in motion mode).

    Note: despite the ``bn`` and ``mid_channels`` arguments, only ``motion``
    changes the layers built here — ``motion=True`` adds BatchNorm and
    ignores ``dilation``; otherwise a (possibly dilated) conv + ReLU is used.
    The extra arguments are accepted for interface compatibility.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, bn=True, motion=False, dilation=1):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        if motion:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation, bias=False, dilation=dilation),
                nn.ReLU(inplace=True),
            )

    def forward(self, x):
        return self.conv(x)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class Conv3D(nn.Module):
    """Conv3d + BatchNorm3d + leaky ReLU over a (batch, seq, c, h, w) tensor.

    The module internally permutes to channels-first 3D layout
    (batch, c, seq, h, w) for the convolution and permutes back afterwards.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
        super(Conv3D, self).__init__()
        self.conv3d = nn.Conv3d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn3d = nn.BatchNorm3d(out_channel)

    def forward(self, x):
        # (batch, seq, c, h, w) -> (batch, c, seq, h, w) as Conv3d expects.
        y = x.permute(0, 2, 1, 3, 4).contiguous()
        y = F.leaky_relu(self.bn3d(self.conv3d(y)))
        # Restore (batch, seq, c, h, w).
        return y.permute(0, 2, 1, 3, 4).contiguous()
|
| 161 |
+
|
| 162 |
+
class MatrixPredictor3DConv(nn.Module):
    """U-Net-style predictor over a temporal sequence of hidden feature maps.

    Two spatial downsampling convs are interleaved with two temporal 3D
    convolutions; the deep features are then upsampled and fused with a
    temporally max-pooled skip branch.

    Input:  (B, T, hidden_len, H, W)
    Output: a (·, hidden_len, H, W) map.
        NOTE(review): the leading dimension equals B only if conv3d_2
        (temporal padding 0) reduces the sequence to length 1, i.e. T == 3
        after the first temporal conv — confirm against the caller.
    """

    def __init__(self, hidden_len=64):
        super(MatrixPredictor3DConv, self).__init__()
        self.unet_base = hidden_len #64
        self.hidden_len = hidden_len #64
        # Two stride-1 "stem" convs at full resolution.
        self.conv_pre_1 = nn.Conv2d(hidden_len,hidden_len, kernel_size=3, stride=1, padding=1)
        self.conv_pre_2 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3, stride=1, padding=1)

        # Temporal convs: conv3d_1 preserves the sequence length (time
        # padding 1); conv3d_2 shrinks it by 2 (time padding 0).
        self.conv3d_1 = Conv3D(self.unet_base, self.unet_base, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1))
        self.conv3d_2 = Conv3D(self.unet_base*2, self.unet_base*2, kernel_size=(3 , 3, 3), stride=1, padding=(0, 1, 1))

        # Spatial down path (stride 2 each).
        self.conv1_1 = nn.Conv2d(hidden_len, self.unet_base, kernel_size=3, stride=2, padding=1)
        self.conv2_1 = nn.Conv2d(self.unet_base, self.unet_base * 2, kernel_size=3, stride=2, padding=1)

        # Up path: conv3_1 fuses upsampled deep features (2*base channels)
        # with the skip branch (base channels) -> 3*base input channels.
        self.conv3_1 = nn.Conv2d(self.unet_base * 3, self.unet_base, kernel_size=3, stride=1, padding=1)
        self.conv4_1 = nn.Conv2d(self.unet_base, self.hidden_len, kernel_size=3, stride=1, padding=1)

        self.bn_pre_1 = nn.BatchNorm2d(hidden_len)
        self.bn_pre_2 = nn.BatchNorm2d(hidden_len)
        self.bn1_1 = nn.BatchNorm2d(self.unet_base)
        self.bn2_1 = nn.BatchNorm2d(self.unet_base * 2)
        self.bn3_1 = nn.BatchNorm2d(self.unet_base)
        self.bn4_1 = nn.BatchNorm2d(self.hidden_len)


    def forward(self,x):
        # x [B,T,C,32,32]
        # out: [B,C,32,32]
        batch, seq, z, h, w = x.size()
        # Fold time into the batch axis for the 2D convs.
        x = x.reshape(-1, x.size(-3), x.size(-2), x.size(-1))
        x = F.leaky_relu(self.bn_pre_1(self.conv_pre_1(x)))
        x = F.leaky_relu(self.bn_pre_2(self.conv_pre_2(x)))
        x_1 = F.leaky_relu(self.bn1_1(self.conv1_1(x)))

        x_1 = x_1.view(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3)).contiguous() # (batch, seq, c, h, w)
        x_1 = self.conv3d_1(x_1) # (batch, seq, c, h, w), 1st temporal conv
        x_1 = x_1.view(-1, x_1.size(2), x_1.size(3), x_1.size(4)).contiguous() # (batch * seq, c, h, w)
        x_2 = F.leaky_relu(self.bn2_1(self.conv2_1(x_1))) # (batch * seq, c, h // 2, w // 2)
        x_2 = x_2.view(batch, -1, x_2.size(1), x_2.size(2), x_2.size(3)).contiguous() # (batch, seq, c, h, w)
        x_2 = self.conv3d_2(x_2) # (batch, seq=1, c, h // 2, w // 2), 2nd temporal conv
        x_2 = x_2.view(-1, x_2.size(2), x_2.size(3), x_2.size(4)).contiguous() # (batch * seq, c, h//2, w//2), seq = 1

        # Collapse the temporal axis of the skip branch by max pooling.
        x_1 = x_1.view(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3)) # (batch, seq, c, h, w)
        x_1 = x_1.permute(0, 2, 1, 3, 4).contiguous() # (batch, c, seq, h, w)
        x_1 = F.adaptive_max_pool3d(x_1, (1, None, None)) # (batch, c, 1, h, w)
        x_1 = x_1.permute(0, 2, 1, 3, 4).contiguous() # (batch, 1, c, h, w)
        x_1 = x_1.view(-1, x_1.size(2), x_1.size(3), x_1.size(4)).contiguous() # (batch*1, c, h, w)
        x_3 = F.leaky_relu(self.bn3_1(self.conv3_1(torch.cat((F.interpolate(x_2, scale_factor=(2, 2)), x_1), dim=1))))
        # NOTE(review): the next reshape result is immediately overwritten by
        # the following assignment and appears to be dead code — confirm
        # before removing.
        x = x.view(batch, -1, x.size(1), x.size(2), x.size(3)) # (batch, seq, 1, h, w)
        x = F.leaky_relu(self.bn4_1(self.conv4_1(F.interpolate(x_3, scale_factor=(2, 2)))))
        return x
|
| 213 |
+
|
| 214 |
+
class SimpleMatrixPredictor3DConv_direct(nn.Module):
    """Direct multi-step variant of MatrixPredictor3DConv.

    Instead of predicting autoregressively, a 1x1 convolution over the
    stacked (seq * channel) axis translates ``input_len`` input frames into
    ``aft_seq_length`` future frames in one shot.

    Input:  (B, T_in, hidden_len, H, W)
    Output: (B * aft_seq_length, hidden_len, H, W) — time folded into batch.
    """

    def __init__(self, T, hidden_len=64, image_pred=False, aft_seq_length=10):
        super(SimpleMatrixPredictor3DConv_direct, self).__init__()
        self.unet_base = hidden_len #64
        self.hidden_len = hidden_len #64
        self.conv_pre_1 = nn.Conv2d(hidden_len,hidden_len, kernel_size=3, stride=1, padding=1)
        self.conv_pre_2 = nn.Conv2d(hidden_len, hidden_len, kernel_size=3, stride=1, padding=1)
        self.fut_len = aft_seq_length

        self.conv3d_1 = Conv3D(self.unet_base, self.unet_base, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1))

        # Temporal mixing of the deep features: a 3D conv when predicting
        # several future steps, a plain 2D conv when predicting one.
        if self.fut_len > 1 :
            self.temporal_layer = Conv3D(self.unet_base*2, self.unet_base*2, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1))
        else:
            self.temporal_layer = nn.Sequential(
                nn.Conv2d(self.unet_base *2, self.unet_base * 2, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU())

        # 1x1 conv mapping the stacked input frames to stacked future frames.
        # NOTE(review): when image_pred is False the predictor expects T-1
        # frames (difference features, presumably) — confirm with the caller.
        input_len = T if image_pred else T - 1
        self.conv_translate = nn.Sequential(
            nn.Conv2d(self.unet_base * input_len , self.unet_base * self.fut_len, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU())

        # Spatial down path (stride 2 each).
        self.conv1_1 = nn.Conv2d(hidden_len, self.unet_base, kernel_size=3, stride=2, padding=1)
        self.conv2_1 = nn.Conv2d(self.unet_base, self.unet_base * 2, kernel_size=3, stride=2, padding=1)

        # Up path: fuse deep features with the skip branch, then project back.
        self.conv3_1 = nn.Conv2d(self.unet_base * 3, self.unet_base, kernel_size=3, stride=1, padding=1)
        self.conv4_1 = nn.Conv2d(self.unet_base, self.hidden_len, kernel_size=3, stride=1, padding=1)

        self.bn_pre_1 = nn.BatchNorm2d(hidden_len)
        self.bn_pre_2 = nn.BatchNorm2d(hidden_len)
        self.bn1_1 = nn.BatchNorm2d(self.unet_base)
        self.bn2_1 = nn.BatchNorm2d(self.unet_base * 2)
        self.bn3_1 = nn.BatchNorm2d(self.unet_base)
        self.bn4_1 = nn.BatchNorm2d(self.hidden_len)
        self.bn_translate = nn.BatchNorm2d(self.unet_base * self.fut_len)


    def forward(self,x):
        # x [B,T,C,32,32]
        # out: [B,C,32,32]
        batch, seq, z, h, w = x.size()
        # Fold time into the batch axis for the 2D convs.
        x = x.reshape(-1, x.size(-3), x.size(-2), x.size(-1))
        x = F.leaky_relu(self.bn_pre_1(self.conv_pre_1(x)))
        x = F.leaky_relu(self.bn_pre_2(self.conv_pre_2(x)))
        x_1 = F.leaky_relu(self.bn1_1(self.conv1_1(x)))

        x_1 = x_1.view(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3)).contiguous() # (batch, seq, c, h, w)

        x_1 = self.conv3d_1(x_1) # (batch, seq, c, h, w), 1st temporal conv
        # Rebinds seq/c/h/w to the downsampled sizes from here on.
        batch, seq, c, h, w = x_1.shape
        # Translate the input sequence length to the future sequence length
        # via the 1x1 conv over the stacked (seq*c) channel axis.
        x_tmp = x_1.reshape(batch,-1,h,w)
        x_tmp = self.bn_translate(self.conv_translate(x_tmp))
        x_1 = x_tmp.reshape(batch,self.fut_len,c,h,w)
        x_1 = x_1.view(-1, x_1.size(2), x_1.size(3), x_1.size(4)).contiguous() # (batch * seq, c, h, w)
        x_2 = F.leaky_relu(self.bn2_1(self.conv2_1(x_1))) # (batch * seq, c, h // 2, w // 2)
        if self.fut_len > 1:
            x_2 = x_2.view(batch, -1, x_2.size(1), x_2.size(2), x_2.size(3)).contiguous() # (batch, seq, c, h, w)
            x_2 = self.temporal_layer(x_2) # (batch, seq=10, c, h // 2, w // 2)

            x_2 = x_2.view(-1, x_2.size(2), x_2.size(3), x_2.size(4)).contiguous() # (batch * seq, c, h//2, w//2), seq = 1
        else:
            x_2 = self.temporal_layer(x_2) # (batch * seq,c, h // 2, w // 2)

        x_1 = x_1.view(batch, -1, x_1.size(1), x_1.size(2), x_1.size(3)) # (batch, seq, c, h, w)

        x_1 = x_1.reshape(-1, x_1.size(2), x_1.size(3), x_1.size(4))


        # Fuse deep features (resized to x_1's spatial size) with the skip.
        x_3 = F.leaky_relu(self.bn3_1(self.conv3_1(torch.cat((F.interpolate(x_2, size=x_1.shape[2:]), x_1), dim=1))))
        x = x.view(batch, -1, x.size(1), x.size(2), x.size(3)) # (batch, seq, 1, h, w)
        # Upsample to the original spatial size taken from x's trailing dims.
        x = F.leaky_relu(self.bn4_1(self.conv4_1(F.interpolate(x_3, size = x.shape[3:]))))

        return x
|
| 288 |
+
|
| 289 |
+
class PredictModel(nn.Module):
    """Predicts future matching matrices from a sequence of past ones.

    Each row of a (hw x window_size) matrix is reinterpreted as a 2D map of
    size (mx_h, mx_w); a shared sequence predictor is run per spatial
    position (hw folded into the batch axis).
    """

    def __init__(self, T, hidden_len=32, aft_seq_length=10, mx_h=32, mx_w=32, use_direct_predictor=True):
        super(PredictModel, self).__init__()
        self.mx_h = mx_h
        self.mx_w = mx_w
        self.hidden_len = hidden_len
        self.fut_len = aft_seq_length
        # Lift the single-channel matrix maps into hidden_len channels.
        self.conv1 = nn.Conv2d( 1, hidden_len, kernel_size=3, padding=1, bias=False)
        # Fuses the prediction with a residual embedding from a coarser scale.
        self.fuse_conv = nn.Conv2d(hidden_len*2, hidden_len, kernel_size=3, padding=1, bias=False)
        if use_direct_predictor:
            self.predictor = SimpleMatrixPredictor3DConv_direct(T=T, hidden_len=hidden_len, aft_seq_length=aft_seq_length)
        else:
            self.predictor = MatrixPredictor3DConv(hidden_len)
        self.out_conv = nn.Conv2d(hidden_len, 1, kernel_size=3, padding=1, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def res_interpolate(self,in_tensor,template_tensor):
        '''
        Resize in_tensor's trailing two dims to match template_tensor's.

        in_tensor: batch,c,h'w',H'W'
        tempolate_tensor: batch,c,hw,HW
        out_tensor: batch,c,hw,HW
        '''
        out_tensor = F.interpolate(in_tensor,template_tensor.shape[-2:]) # (BThw,target_h,target_w)

        return out_tensor

    def forward(self,matrix_seq, softmax=False, res=None):
        # matrix_seq: (B, T, hw, window_size) matching matrices.
        B,T,hw,window_size = matrix_seq.size()

        # Reinterpret every matrix row as an (mx_h, mx_w) 2D map.
        matrix_seq = matrix_seq.reshape(-1,hw,self.mx_h,self.mx_w) # (BT,hw,hw)
        matrix_seq = matrix_seq.reshape(B*T*hw,self.mx_h,self.mx_w).unsqueeze(1) # (BThw,1,h,w)

        x = self.conv1(matrix_seq)
        x = x.reshape(B,T,hw,-1,self.mx_h,self.mx_w)
        # Run the predictor once per spatial position (hw folded into batch).
        x = x.permute(0,2,1,3,4,5).reshape(B*hw,T,-1,self.mx_h,self.mx_w)
        emb = self.predictor(x)

        emb = emb.reshape(B*hw*self.fut_len,-1,self.mx_h,self.mx_w)
        # Returned so the caller can pass it as `res` to a finer scale.
        res_emb = emb.clone()
        if res is not None:
            # Fuse with the residual embedding from the coarser (hw//4) scale,
            # resized to this scale's layout.
            template = emb.clone().reshape(B,hw,emb.shape[1],-1).permute(0,2,1,3)
            in_tensor = res.clone().reshape(B,hw//4,emb.shape[1],-1).permute(0,2,1,3)

            res_tensor = self.res_interpolate(in_tensor,template).permute(0,2,1,3).reshape(emb.shape)

            emb = self.fuse_conv(torch.cat([emb,res_tensor],dim=1))

        out = self.out_conv(emb) #(Bhwt,16,h//4,w//4)

        # Back to matrix layout: (B, fut_len, hw, window_size).
        out = out.reshape(B,hw,-1,self.mx_h,self.mx_w)
        out = out.permute(0,2,1,3,4)
        out = out.reshape(B,-1,hw,window_size)

        if softmax:
            # Normalize each predicted matrix over its flattened hw*window axis.
            out = out.view(B,out.shape[1],-1)
            out = self.softmax(out)
            out = out.reshape(B,-1,hw,window_size)

        return out,res_emb
|
utilpack/phydnet_modules.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from numpy import *
|
| 4 |
+
from numpy.linalg import *
|
| 5 |
+
from scipy.special import factorial
|
| 6 |
+
from functools import reduce
|
| 7 |
+
|
| 8 |
+
__all__ = ['M2K','K2M']
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PhyCell_Cell(nn.Module):
    """Single PhyCell step: physical prediction plus gated correction.

    Computes ``h_tilde = h + F(h)`` (physical prediction by the small conv
    network ``F``) and then ``next_h = h_tilde + K * (x - h_tilde)`` where
    the gate ``K`` is a sigmoid of a 3x3 convolution over ``[x, h]``.

    Args:
        input_dim: channels of the input/hidden maps.
        F_hidden_dim: intermediate channels of the F operator
            (must be divisible by 7 for the GroupNorm below).
        kernel_size: (kh, kw) of F's first convolution.
        bias: whether the gate convolution has a bias term.
    """

    def __init__(self, input_dim, F_hidden_dim, kernel_size, bias=1):
        super(PhyCell_Cell, self).__init__()
        self.input_dim = input_dim
        self.F_hidden_dim = F_hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # F approximates the physical dynamics operator.
        self.F = nn.Sequential()
        self.F.add_module('conv1', nn.Conv2d(in_channels=input_dim, out_channels=F_hidden_dim,
                                             kernel_size=self.kernel_size, stride=(1,1), padding=self.padding))
        self.F.add_module('bn1',nn.GroupNorm(7 ,F_hidden_dim))
        self.F.add_module('conv2', nn.Conv2d(in_channels=F_hidden_dim, out_channels=input_dim,
                                             kernel_size=(1,1), stride=(1,1), padding=(0,0)))

        # Kalman-like gate computed from the concatenated input and hidden.
        self.convgate = nn.Conv2d(in_channels=self.input_dim + self.input_dim,
                                  out_channels=self.input_dim,
                                  kernel_size=(3,3),
                                  padding=(1,1), bias=self.bias)

    def forward(self, x, hidden):  # x: [batch, input_dim, height, width]
        gate = torch.sigmoid(self.convgate(torch.cat([x, hidden], dim=1)))
        hidden_tilde = hidden + self.F(hidden)            # prediction
        # Correction (Hadamard product with the gate).
        return hidden_tilde + gate * (x - hidden_tilde)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class PhyCell(nn.Module):
    """Stack of PhyCell_Cell layers with hidden states held on the module.

    Hidden states live in ``self.H`` (a plain Python list, NOT registered
    buffers) and are re-created whenever ``forward`` is called with
    ``first_timestep=True``, so the module is stateful across the timesteps
    of one sequence.
    """

    def __init__(self, input_shape, input_dim, F_hidden_dims, n_layers, kernel_size, device):
        super(PhyCell, self).__init__()
        self.input_shape = input_shape      # (H, W) of the hidden maps
        self.input_dim = input_dim          # channels per hidden map
        self.F_hidden_dims = F_hidden_dims  # per-layer hidden dim of the F operator
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H = []                         # per-layer hidden states (filled by initHidden)
        self.device = device

        cell_list = []
        for i in range(0, self.n_layers):
            cell_list.append(PhyCell_Cell(input_dim=input_dim,
                                          F_hidden_dim=self.F_hidden_dims[i],
                                          kernel_size=self.kernel_size))
        self.cell_list = nn.ModuleList(cell_list)


    def forward(self, input_, first_timestep=False): # input_ [batch_size, 1, channels, width, height]
        batch_size = input_.data.size()[0]
        if (first_timestep):
            self.initHidden(batch_size) # init Hidden at each forward start
        for j, cell in enumerate(self.cell_list):
            # Keep stored hidden states on the same device as the input.
            self.H[j] = self.H[j].to(input_.device)
            if j==0: # bottom layer consumes the external input
                self.H[j] = cell(input_, self.H[j])
            else:    # upper layers consume the layer below
                self.H[j] = cell(self.H[j-1],self.H[j])
        # Returns the same list twice: (hidden states, outputs).
        return self.H, self.H

    def initHidden(self, batch_size):
        """Create zero hidden states for every layer on ``self.device``."""
        self.H = []
        for i in range(self.n_layers):
            self.H.append(torch.zeros(
                batch_size, self.input_dim, self.input_shape[0], self.input_shape[1]).to(self.device))

    def setHidden(self, H):
        """Replace the hidden states (e.g. when switching to decoding)."""
        self.H = H
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class PhyD_ConvLSTM_Cell(nn.Module):
    """ConvLSTM cell processing a single timestep.

    Args:
        input_shape: (height, width) of the input tensor.
        input_dim: number of channels of the input tensor.
        hidden_dim: number of channels of the hidden state.
        kernel_size: (kh, kw) of the gate convolution.
        bias: whether the gate convolution has a bias term.
    """

    def __init__(self, input_shape, input_dim, hidden_dim, kernel_size, bias=1):
        super(PhyD_ConvLSTM_Cell, self).__init__()
        self.height, self.width = input_shape
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        # One convolution produces all four gate pre-activations at once.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding, bias=self.bias)

    def forward(self, x, hidden):  # x: [batch, input_dim, height, width]
        h_cur, c_cur = hidden

        gates = self.conv(torch.cat([x, h_cur], dim=1))
        cc_i, cc_f, cc_o, cc_g = torch.split(gates, self.hidden_dim, dim=1)
        in_gate = torch.sigmoid(cc_i)
        forget_gate = torch.sigmoid(cc_f)
        out_gate = torch.sigmoid(cc_o)
        cell_input = torch.tanh(cc_g)

        c_next = forget_gate * c_cur + in_gate * cell_input
        h_next = out_gate * torch.tanh(c_next)
        return h_next, c_next
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class PhyD_ConvLSTM(nn.Module):
    """Stack of PhyD_ConvLSTM_Cell layers with module-held (H, C) states.

    Like PhyCell, the per-layer hidden and cell states are stored in the
    plain lists ``self.H`` / ``self.C`` and re-created when ``forward`` is
    called with ``first_timestep=True``.
    """

    def __init__(self, input_shape, input_dim, hidden_dims, n_layers, kernel_size, device):
        super(PhyD_ConvLSTM, self).__init__()
        self.input_shape = input_shape  # (H, W) of the feature maps
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims  # per-layer hidden channel counts
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H, self.C = [], []
        self.device = device

        cell_list = []
        for i in range(0, self.n_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dims[i-1]
            # NOTE(review): debug print fires on every construction.
            print('layer ', i, 'input dim ', cur_input_dim, ' hidden dim ', self.hidden_dims[i])
            cell_list.append(PhyD_ConvLSTM_Cell(input_shape=self.input_shape,
                                                input_dim=cur_input_dim,
                                                hidden_dim=self.hidden_dims[i],
                                                kernel_size=self.kernel_size))
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_, first_timestep=False): # input_ [batch_size, 1, channels, width, height]
        batch_size = input_.data.size()[0]
        if (first_timestep):
            self.initHidden(batch_size) # init Hidden at each forward start
        for j, cell in enumerate(self.cell_list):
            # Keep stored states on the same device as the input.
            self.H[j] = self.H[j].to(input_.device)
            self.C[j] = self.C[j].to(input_.device)
            if j==0: # bottom layer consumes the external input
                self.H[j], self.C[j] = cell(input_, (self.H[j],self.C[j]))
            else:    # upper layers consume the layer below
                self.H[j], self.C[j] = cell(self.H[j-1],(self.H[j],self.C[j]))
        return (self.H,self.C) , self.H # (hidden, output)

    def initHidden(self,batch_size):
        """Create zero hidden and cell states for all layers on ``self.device``."""
        self.H, self.C = [],[]
        for i in range(self.n_layers):
            self.H.append(torch.zeros(
                batch_size, self.hidden_dims[i], self.input_shape[0], self.input_shape[1]).to(self.device))
            self.C.append(torch.zeros(
                batch_size, self.hidden_dims[i], self.input_shape[0], self.input_shape[1]).to(self.device))

    def setHidden(self, hidden):
        """Replace both hidden and cell state lists."""
        H,C = hidden
        self.H, self.C = H,C
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class dcgan_conv(nn.Module):
    """Conv2d(3x3) + GroupNorm(16 groups) + LeakyReLU(0.2) block.

    ``nout`` must be divisible by 16 for the GroupNorm.
    """

    def __init__(self, nin, nout, stride):
        super(dcgan_conv, self).__init__()
        layers = [
            nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=(3,3),
                      stride=stride, padding=1),
            nn.GroupNorm(16, nout),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class dcgan_upconv(nn.Module):
    """ConvTranspose2d(3x3) + GroupNorm(16 groups) + LeakyReLU(0.2) block.

    With stride 2, output_padding=1 makes the spatial size exactly double;
    with stride 1 the size is unchanged. ``nout`` must be divisible by 16.
    """

    def __init__(self, nin, nout, stride):
        super(dcgan_upconv, self).__init__()
        output_padding = 1 if stride == 2 else 0
        self.main = nn.Sequential(
            nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=(3,3),
                               stride=stride, padding=1, output_padding=output_padding),
            nn.GroupNorm(16, nout),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, input):
        return self.main(input)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class encoder_E(nn.Module):
    """General encoder: (nc, 64, 64) -> (2*nf, 16, 16) for patch_size=4.

    Downsamples by 2 in c1 and by ``patch_size // 2`` in c3, so the total
    spatial reduction equals ``patch_size``.
    """

    def __init__(self, nc=1, nf=32, patch_size=4):
        super(encoder_E, self).__init__()
        assert patch_size in [2, 4]
        stride_2 = patch_size // 2
        # input is (nc) x 64 x 64
        self.c1 = dcgan_conv(nc, nf, stride=2)             # (nf) x 32 x 32
        self.c2 = dcgan_conv(nf, nf, stride=1)             # (nf) x 32 x 32
        self.c3 = dcgan_conv(nf, 2*nf, stride=stride_2)    # (2*nf) x 16 x 16

    def forward(self, input):
        return self.c3(self.c2(self.c1(input)))
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class decoder_D(nn.Module):
    """General decoder: (2*nf, 16, 16) -> (nc, 64, 64) for patch_size=4.

    Mirrors encoder_E: upsamples by 2 in upc1 and by ``patch_size // 2`` in
    the final transposed conv, which has no norm/activation so the caller
    can apply its own output nonlinearity.
    """

    def __init__(self, nc=1, nf=32, patch_size=4):
        super(decoder_D, self).__init__()
        assert patch_size in [2, 4]
        stride_2 = patch_size // 2
        output_padding = 1 if stride_2 == 2 else 0
        self.upc1 = dcgan_upconv(2*nf, nf, stride=2)   # (nf) x 32 x 32
        self.upc2 = dcgan_upconv(nf, nf, stride=1)     # (nf) x 32 x 32
        self.upc3 = nn.ConvTranspose2d(in_channels=nf, out_channels=nc, kernel_size=(3,3),
                                       stride=stride_2, padding=1,
                                       output_padding=output_padding)  # (nc) x 64 x 64

    def forward(self, input):
        return self.upc3(self.upc2(self.upc1(input)))
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class encoder_specific(nn.Module):
    """Branch-specific encoder: two stride-1 conv blocks, resolution kept."""

    def __init__(self, nc=64, nf=64):
        super(encoder_specific, self).__init__()
        self.c1 = dcgan_conv(nc, nf, stride=1)  # (nf) x 16 x 16
        self.c2 = dcgan_conv(nf, nf, stride=1)  # (nf) x 16 x 16

    def forward(self, input):
        return self.c2(self.c1(input))
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class decoder_specific(nn.Module):
    """Branch-specific decoder: two stride-1 upconv blocks, resolution kept."""

    def __init__(self, nc=64, nf=64):
        super(decoder_specific, self).__init__()
        self.upc1 = dcgan_upconv(nf, nf, stride=1)  # (nf) x 16 x 16
        self.upc2 = dcgan_upconv(nf, nc, stride=1)  # (nc) x 16 x 16

    def forward(self, input):
        return self.upc2(self.upc1(input))
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class PhyD_EncoderRNN(torch.nn.Module):
    """PhyDNet encoder-decoder wrapper combining a PhyCell and a ConvLSTM.

    A shared general encoder/decoder surrounds two branch-specific pairs:
    a "physical" branch fed to ``phycell`` and a "residual" branch fed to
    ``convcell``. The two decoded branches are summed before the final
    decoding back to image space.
    """

    def __init__(self, phycell, convcell, in_channel=1, patch_size=4):
        super(PhyD_EncoderRNN, self).__init__()
        self.encoder_E = encoder_E(nc=in_channel, patch_size=patch_size) # general encoder 64x64x1 -> 32x32x32
        self.encoder_Ep = encoder_specific() # specific image encoder 32x32x32 -> 16x16x64
        self.encoder_Er = encoder_specific()
        self.decoder_Dp = decoder_specific() # specific image decoder 16x16x64 -> 32x32x32
        self.decoder_Dr = decoder_specific()
        self.decoder_D = decoder_D(nc=in_channel, patch_size=patch_size) # general decoder 32x32x32 -> 64x64x1

        self.phycell = phycell
        self.convcell = convcell

    def forward(self, input, first_timestep=False, decoding=False):
        input = self.encoder_E(input) # general encoder 64x64x1 -> 32x32x32

        if decoding: # input=None in decoding phase
            # NOTE(review): the physical branch then receives None — the
            # recurrent cells are presumably expected to run on their stored
            # state only; confirm against PhyCell's handling.
            input_phys = None
        else:
            input_phys = self.encoder_Ep(input)
        input_conv = self.encoder_Er(input)

        hidden1, output1 = self.phycell(input_phys, first_timestep)
        hidden2, output2 = self.convcell(input_conv, first_timestep)

        # Decode the top-layer output of each recurrent branch.
        decoded_Dp = self.decoder_Dp(output1[-1])
        decoded_Dr = self.decoder_Dr(output2[-1])

        out_phys = torch.sigmoid(self.decoder_D(decoded_Dp)) # partial reconstructions for vizualization
        out_conv = torch.sigmoid(self.decoder_D(decoded_Dr))

        # Branch fusion is a plain sum, then the shared general decoder.
        concat = decoded_Dp + decoded_Dr
        output_image = torch.sigmoid( self.decoder_D(concat ))
        # NOTE(review): out_phys is returned twice (positions 1 and 4) —
        # confirm callers rely on this exact tuple shape before changing it.
        return out_phys, hidden1, output_image, out_phys, out_conv
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def _apply_axis_left_dot(x, mats):
    """Contract each non-batch axis of ``x`` with the matching matrix in ``mats``.

    ``x`` is treated as a batch (leading axis) of ``len(mats)``-dimensional
    tensors; axis ``i`` is multiplied on the left by ``mats[i]``.  The result
    has the same shape as the input.
    """
    assert x.dim() == len(mats)+1
    sizex = x.size()
    k = x.dim()-1
    # Each tensordot contracts the current last axis and moves the matrix's
    # free axis to the front, so after k steps the axes are rotated; the
    # permute below rotates them back before restoring the original shape.
    for i in range(k):
        x = tensordot(mats[k-i-1], x, dim=[1,k])
    x = x.permute([k,]+list(range(k))).contiguous()
    x = x.view(sizex)
    return x
|
| 321 |
+
|
| 322 |
+
def _apply_axis_right_dot(x, mats):
    """Contract each non-batch axis of ``x`` with the matching matrix in
    ``mats``, multiplying on the right.

    Counterpart of ``_apply_axis_left_dot``; the result keeps the input shape.
    """
    assert x.dim() == len(mats)+1
    sizex = x.size()
    k = x.dim()-1
    # Move the batch axis to the back so the axes to contract come first.
    x = x.permute(list(range(1,k+1))+[0,])
    # Contract the leading axis with mats[i]; tensordot appends the matrix's
    # free axis at the end, so after k steps the batch axis is in front again.
    for i in range(k):
        x = tensordot(x, mats[i], dim=[0,0])
    x = x.contiguous()
    x = x.view(sizex)
    return x
|
| 332 |
+
|
| 333 |
+
class _MK(nn.Module):
    """Base module holding per-axis moment/kernel conversion matrices.

    For every kernel side length ``l`` in ``shape`` it builds the matrix ``M``
    with rows ``((arange(l)-(l-1)//2)**i)/i!`` and its inverse, and registers
    both as non-trainable buffers.  Subclasses (M2K / K2M) apply these
    matrices along each kernel axis.
    """
    def __init__(self, shape):
        super(_MK, self).__init__()
        self._size = torch.Size(shape)  # kernel shape handled by this module
        self._dim = len(shape)          # number of kernel axes
        M = []
        invM = []
        assert len(shape) > 0
        j = 0
        for l in shape:
            # NOTE(review): zeros/arange/factorial/inv/newaxis are assumed to
            # come from numpy/scipy imports at the top of this file — confirm.
            M.append(zeros((l,l)))
            for i in range(l):
                M[-1][i] = ((arange(l)-(l-1)//2)**i)/factorial(i)
            invM.append(inv(M[-1]))
            # Buffers follow the module across devices but receive no gradients.
            self.register_buffer('_M'+str(j), torch.from_numpy(M[-1]))
            self.register_buffer('_invM'+str(j), torch.from_numpy(invM[-1]))
            j += 1

    @property
    def M(self):
        # Per-axis moment matrices, in axis order.
        return list(self._buffers['_M'+str(j)] for j in range(self.dim()))
    @property
    def invM(self):
        # Per-axis inverse moment matrices, in axis order.
        return list(self._buffers['_invM'+str(j)] for j in range(self.dim()))

    def size(self):
        # Kernel shape handled by this module.
        return self._size
    def dim(self):
        # Number of kernel axes.
        return self._dim
    def _packdim(self, x):
        """Flatten all leading axes of ``x`` into a single batch axis."""
        assert x.dim() >= self.dim()
        if x.dim() == self.dim():
            x = x[newaxis,:]  # prepend a singleton batch axis
        x = x.contiguous()
        x = x.view([-1,]+list(x.size()[-self.dim():]))
        return x

    def forward(self):
        # The base class performs no computation; subclasses override forward.
        pass
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
class M2K(_MK):
    """
    convert moment matrix to convolution kernel
    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        m2k = M2K([5,5])
        m = torch.randn(5,5,dtype=torch.float64)
        k = m2k(m)
    """

    def __init__(self, shape):
        super(M2K, self).__init__(shape)

    def forward(self, m):
        """
        m (Tensor): torch.size=[...,*self.shape]
        """
        original_shape = m.size()
        packed = self._packdim(m)
        kernel = _apply_axis_left_dot(packed, self.invM)
        return kernel.view(original_shape)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class K2M(_MK):
    """
    convert convolution kernel to moment matrix
    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        k2m = K2M([5,5])
        k = torch.randn(5,5,dtype=torch.float64)
        m = k2m(k)
    """

    def __init__(self, shape):
        super(K2M, self).__init__(shape)

    def forward(self, k):
        """
        k (Tensor): torch.size=[...,*self.shape]
        """
        original_shape = k.size()
        packed = self._packdim(k)
        moments = _apply_axis_left_dot(packed, self.M)
        return moments.view(original_shape)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def tensordot(a, b, dim):
    """
    tensordot in PyTorch, see numpy.tensordot?

    ``dim`` is either an int (contract the last ``dim`` axes of ``a`` with the
    first ``dim`` axes of ``b``) or a pair ``(adims, bdims)`` of axis lists.
    The contraction is realised as one 2-D matmul after reshaping.
    """
    prod = lambda sizes: reduce(lambda u, v: u * v, sizes, 1)
    if isinstance(dim, int):
        a = a.contiguous()
        b = b.contiguous()
        sizea, sizeb = a.size(), b.size()
        sizea0, sizea1 = sizea[:-dim], sizea[-dim:]
        sizeb0, sizeb1 = sizeb[:dim], sizeb[dim:]
        N = prod(sizea1)
        assert prod(sizeb0) == N
    else:
        adims, bdims = dim[0], dim[1]
        if isinstance(adims, int):
            adims = [adims]
        if isinstance(bdims, int):
            bdims = [bdims]
        # Move the contracted axes of `a` to the back and those of `b` to
        # the front, keeping the remaining axes in their original order.
        adims_rest = sorted(set(range(a.dim())).difference(set(adims)))
        bdims_rest = sorted(set(range(b.dim())).difference(set(bdims)))
        a = a.permute(*(adims_rest + adims)).contiguous()
        b = b.permute(*(bdims + bdims_rest)).contiguous()

        sizea, sizeb = a.size(), b.size()
        sizea0, sizea1 = sizea[:-len(adims)], sizea[-len(adims):]
        sizeb0, sizeb1 = sizeb[:len(bdims)], sizeb[len(bdims):]
        N = prod(sizea1)
        assert prod(sizeb0) == N
    # Collapse to a matrix product, then restore the free axes.
    c = a.view([-1, N]) @ b.view([N, -1])
    return c.view(sizea0 + sizeb1)
|
utilpack/predrnn_modules.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class SpatioTemporalLSTMCell(nn.Module):
    """Spatiotemporal LSTM (ST-LSTM) cell used by PredRNN.

    Keeps a temporal cell state ``c`` and a spatiotemporal memory ``m``.
    All gates are produced by convolutions, optionally followed by LayerNorm.
    The duplicated layer_norm/no-layer_norm branches of the original were
    collapsed into one builder; state_dict keys are unchanged (each gate conv
    stays at index 0 of its nn.Sequential).
    """

    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
        """
        Args:
            in_channel: channels of the input ``x_t``.
            num_hidden: channels of the hidden/cell/memory states.
            height, width: spatial size (needed for the LayerNorm shapes).
            filter_size: conv kernel size; padding keeps the spatial size.
            stride: conv stride.
            layer_norm: apply LayerNorm after each gate convolution.
        """
        super(SpatioTemporalLSTMCell, self).__init__()

        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        self._forget_bias = 1.0

        def _gate_conv(c_in, c_out):
            # Conv (+ optional LayerNorm) wrapped in nn.Sequential so the
            # parameter names match the original implementation.
            layers = [nn.Conv2d(c_in, c_out, kernel_size=filter_size,
                                stride=stride, padding=self.padding, bias=False)]
            if layer_norm:
                layers.append(nn.LayerNorm([c_out, height, width]))
            return nn.Sequential(*layers)

        self.conv_x = _gate_conv(in_channel, num_hidden * 7)  # 7 gate maps from x_t
        self.conv_h = _gate_conv(num_hidden, num_hidden * 4)  # 4 gate maps from h_t
        self.conv_m = _gate_conv(num_hidden, num_hidden * 3)  # 3 gate maps from m_t
        self.conv_o = _gate_conv(num_hidden * 2, num_hidden)  # output gate from [c, m]
        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
                                   stride=1, padding=0, bias=False)

    def forward(self, x_t, h_t, c_t, m_t):
        """One cell step; returns ``(h_new, c_new, m_new)``."""
        x_concat = self.conv_x(x_t)
        h_concat = self.conv_h(h_t)
        m_concat = self.conv_m(m_t)
        i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = \
            torch.split(x_concat, self.num_hidden, dim=1)
        i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
        i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)

        # Temporal (standard LSTM) path updating c.
        i_t = torch.sigmoid(i_x + i_h)
        f_t = torch.sigmoid(f_x + f_h + self._forget_bias)  # forget bias eases training
        g_t = torch.tanh(g_x + g_h)

        c_new = f_t * c_t + i_t * g_t

        # Spatiotemporal path updating m (primed gates).
        i_t_prime = torch.sigmoid(i_x_prime + i_m)
        f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)
        g_t_prime = torch.tanh(g_x_prime + g_m)

        m_new = f_t_prime * m_t + i_t_prime * g_t_prime

        # Output gate reads both memories; 1x1 conv fuses them into h.
        mem = torch.cat((c_new, m_new), 1)
        o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))
        h_new = o_t * torch.tanh(self.conv_last(mem))

        return h_new, c_new, m_new
|
utilpack/predrnnpp_modules.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class CausalLSTMCell(nn.Module):
    """Causal LSTM cell from PredRNN++.

    The cell cascades the temporal memory ``c`` into the spatial memory ``m``
    (via ``conv_c2m``), making the state transition deeper than in ST-LSTM.
    The duplicated layer_norm/no-layer_norm construction of the original was
    collapsed into one builder; state_dict keys are unchanged.
    """

    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
        """
        Args:
            in_channel: channels of the input ``x_t``.
            num_hidden: channels of the hidden/cell/memory states.
            height, width: spatial size (needed for the LayerNorm shapes).
            filter_size: conv kernel size; padding keeps the spatial size.
            stride: conv stride.
            layer_norm: apply LayerNorm after each gate convolution.
        """
        super(CausalLSTMCell, self).__init__()

        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        self._forget_bias = 1.0

        def _gate_conv(c_in, c_out):
            # Conv (+ optional LayerNorm) wrapped in nn.Sequential so the
            # parameter names match the original implementation.
            layers = [nn.Conv2d(c_in, c_out, kernel_size=filter_size,
                                stride=stride, padding=self.padding, bias=False)]
            if layer_norm:
                layers.append(nn.LayerNorm([c_out, height, width]))
            return nn.Sequential(*layers)

        self.conv_x = _gate_conv(in_channel, num_hidden * 7)    # 7 gate maps from x_t
        self.conv_h = _gate_conv(num_hidden, num_hidden * 4)    # 4 gate maps from h_t
        self.conv_c = _gate_conv(num_hidden, num_hidden * 3)    # 3 gate maps from c_t
        self.conv_m = _gate_conv(num_hidden, num_hidden * 3)    # 3 gate maps from m_t
        # NOTE: conv_o is defined (as in the original) but not used by
        # forward(); kept for checkpoint compatibility.
        self.conv_o = _gate_conv(num_hidden * 2, num_hidden)
        self.conv_c2m = _gate_conv(num_hidden, num_hidden * 4)  # cascade c -> m gates
        self.conv_om = _gate_conv(num_hidden, num_hidden)       # output contribution of m
        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
                                   stride=1, padding=0, bias=False)

    def forward(self, x_t, h_t, c_t, m_t):
        """One cell step; returns ``(h_new, c_new, m_new)``."""
        x_concat = self.conv_x(x_t)
        h_concat = self.conv_h(h_t)
        c_concat = self.conv_c(c_t)
        m_concat = self.conv_m(m_t)
        i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = \
            torch.split(x_concat, self.num_hidden, dim=1)
        i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
        i_m, f_m, m_m = torch.split(m_concat, self.num_hidden, dim=1)
        i_c, f_c, g_c = torch.split(c_concat, self.num_hidden, dim=1)

        # First stage: update the temporal memory c (gates also read c_t).
        i_t = torch.sigmoid(i_x + i_h + i_c)
        f_t = torch.sigmoid(f_x + f_h + f_c + self._forget_bias)
        g_t = torch.tanh(g_x + g_h + g_c)

        c_new = f_t * c_t + i_t * g_t

        # Second stage: the fresh c_new drives the gates of the spatial
        # memory m (the "causal" cascade). This rebinds i_c/f_c/g_c.
        c2m = self.conv_c2m(c_new)
        i_c, g_c, f_c, o_c = torch.split(c2m, self.num_hidden, dim=1)

        i_t_prime = torch.sigmoid(i_x_prime + i_m + i_c)
        f_t_prime = torch.sigmoid(f_x_prime + f_m + f_c + self._forget_bias)
        g_t_prime = torch.tanh(g_x_prime + g_c)

        m_new = f_t_prime * torch.tanh(m_m) + i_t_prime * g_t_prime
        o_m = self.conv_om(m_new)

        # Output gate uses tanh here (as in the original), not sigmoid.
        o_t = torch.tanh(o_x + o_h + o_c + o_m)
        mem = torch.cat((c_new, m_new), 1)
        h_new = o_t * torch.tanh(self.conv_last(mem))

        return h_new, c_new, m_new
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class GHU(nn.Module):
    """Gradient Highway Unit (GHU) used alongside Causal LSTM in PredRNN++.

    Computes ``z_new = u * p + (1 - u) * z`` where ``p`` (transform) and ``u``
    (switch gate) come from convolutions over the input ``x`` and the highway
    state ``z``.
    """

    def __init__(self, in_channel, num_hidden, height, width, filter_size,
                 stride, layer_norm, initializer=0.001):
        """
        Args:
            in_channel: channels of ``x`` (and of ``z``; both branches use the
                same conv input width, so callers pass in_channel == num_hidden).
            num_hidden: channels of the highway state.
            height, width: spatial size (needed for the LayerNorm shape).
            filter_size: conv kernel size; padding keeps the spatial size.
            stride: conv stride.
            layer_norm: apply LayerNorm after each convolution.
            initializer: half-width of the uniform conv-weight init;
                -1 keeps PyTorch's default initialization.
        """
        super(GHU, self).__init__()

        self.filter_size = filter_size
        self.padding = filter_size // 2
        self.num_hidden = num_hidden
        self.layer_norm = layer_norm

        def _branch():
            layers = [nn.Conv2d(in_channel, num_hidden * 2, kernel_size=filter_size,
                                stride=stride, padding=self.padding, bias=False)]
            if layer_norm:
                # BUGFIX: normalize over the actual conv output, which has
                # num_hidden * 2 channels. The previous shape
                # [num_hidden, height, width] raised at runtime.
                layers.append(nn.LayerNorm([num_hidden * 2, height, width]))
            return nn.Sequential(*layers)

        self.z_concat = _branch()
        self.x_concat = _branch()

        if initializer != -1:
            self.initializer = initializer
            self.apply(self._init_weights)

    def _init_weights(self, m):
        # Uniform init in [-initializer, initializer] for all conv weights.
        if isinstance(m, (nn.Conv2d)):
            nn.init.uniform_(m.weight, -self.initializer, self.initializer)

    def _init_state(self, inputs):
        # The highway state starts as zeros shaped like the input.
        return torch.zeros_like(inputs)

    def forward(self, x, z):
        """One GHU step; returns the updated highway state."""
        if z is None:
            z = self._init_state(x)
        z_concat = self.z_concat(z)
        x_concat = self.x_concat(x)

        gates = x_concat + z_concat
        p, u = torch.split(gates, self.num_hidden, dim=1)
        p = torch.tanh(p)      # candidate transform
        u = torch.sigmoid(u)   # switch gate: blend transform vs. carry-through
        z_new = u * p + (1 - u) * z
        return z_new
|
utilpack/predrnnv2_modules.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class SpatioTemporalLSTMCellv2(nn.Module):
    """ST-LSTM cell variant for PredRNNv2.

    Identical gating to the v1 cell, but additionally returns the memory
    increments ``delta_c`` and ``delta_m`` (used by PredRNNv2's decoupling
    loss). The duplicated layer_norm/no-layer_norm construction of the
    original was collapsed into one builder; state_dict keys are unchanged.
    """

    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
        """
        Args:
            in_channel: channels of the input ``x_t``.
            num_hidden: channels of the hidden/cell/memory states.
            height, width: spatial size (needed for the LayerNorm shapes).
            filter_size: conv kernel size; padding keeps the spatial size.
            stride: conv stride.
            layer_norm: apply LayerNorm after each gate convolution.
        """
        super(SpatioTemporalLSTMCellv2, self).__init__()

        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        self._forget_bias = 1.0

        def _gate_conv(c_in, c_out):
            # Conv (+ optional LayerNorm) wrapped in nn.Sequential so the
            # parameter names match the original implementation.
            layers = [nn.Conv2d(c_in, c_out, kernel_size=filter_size,
                                stride=stride, padding=self.padding, bias=False)]
            if layer_norm:
                layers.append(nn.LayerNorm([c_out, height, width]))
            return nn.Sequential(*layers)

        self.conv_x = _gate_conv(in_channel, num_hidden * 7)  # 7 gate maps from x_t
        self.conv_h = _gate_conv(num_hidden, num_hidden * 4)  # 4 gate maps from h_t
        self.conv_m = _gate_conv(num_hidden, num_hidden * 3)  # 3 gate maps from m_t
        self.conv_o = _gate_conv(num_hidden * 2, num_hidden)  # output gate from [c, m]
        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
                                   stride=1, padding=0, bias=False)

    def forward(self, x_t, h_t, c_t, m_t):
        """One cell step; returns ``(h_new, c_new, m_new, delta_c, delta_m)``."""
        x_concat = self.conv_x(x_t)
        h_concat = self.conv_h(h_t)
        m_concat = self.conv_m(m_t)
        i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = \
            torch.split(x_concat, self.num_hidden, dim=1)
        i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
        i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)

        # Temporal path: the increment delta_c is exposed for the
        # memory-decoupling regularizer.
        i_t = torch.sigmoid(i_x + i_h)
        f_t = torch.sigmoid(f_x + f_h + self._forget_bias)
        g_t = torch.tanh(g_x + g_h)

        delta_c = i_t * g_t
        c_new = f_t * c_t + delta_c

        # Spatiotemporal path: same structure with primed gates.
        i_t_prime = torch.sigmoid(i_x_prime + i_m)
        f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)
        g_t_prime = torch.tanh(g_x_prime + g_m)

        delta_m = i_t_prime * g_t_prime
        m_new = f_t_prime * m_t + delta_m

        mem = torch.cat((c_new, m_new), 1)
        o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))
        h_new = o_t * torch.tanh(self.conv_last(mem))

        return h_new, c_new, m_new, delta_c, delta_m
|
utilpack/simvp_modules.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from timm.layers import DropPath, trunc_normal_
|
| 6 |
+
from timm.models.convnext import ConvNeXtBlock
|
| 7 |
+
from timm.models.mlp_mixer import MixerBlock
|
| 8 |
+
from timm.models.swin_transformer import SwinTransformerBlock, window_partition, window_reverse
|
| 9 |
+
from timm.models.vision_transformer import Block as ViTBlock
|
| 10 |
+
|
| 11 |
+
from .layers import (HorBlock, ChannelAggregationFFN, MultiOrderGatedAggregation,
|
| 12 |
+
PoolFormerBlock, CBlock, SABlock, MixMlp, VANBlock)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BasicConv2d(nn.Module):
    """Conv2d building block for SimVP.

    Either a plain convolution or (when ``upsampling``) a conv producing
    4x channels followed by PixelShuffle(2). GroupNorm + SiLU are applied
    only when ``act_norm`` is True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 upsampling=False,
                 act_norm=False,
                 act_inplace=True):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm

        if upsampling is True:
            # Sub-pixel upsampling: 4x channels, then PixelShuffle doubles H/W.
            conv = nn.Sequential(*[
                nn.Conv2d(in_channels, out_channels*4, kernel_size=kernel_size,
                          stride=1, padding=padding, dilation=dilation),
                nn.PixelShuffle(2)
            ])
        else:
            conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation)
        self.conv = conv

        # Norm/activation are always constructed but only used when act_norm.
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.SiLU(inplace=act_inplace)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal conv weights, zero bias.
        if isinstance(m, (nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.conv(x)
        return self.act(self.norm(out)) if self.act_norm else out
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ConvSC(nn.Module):
    """Thin wrapper choosing stride/padding for BasicConv2d in SimVP.

    ``downsampling`` selects stride 2; ``upsampling`` is forwarded to
    BasicConv2d's PixelShuffle path.
    """

    def __init__(self,
                 C_in,
                 C_out,
                 kernel_size=3,
                 downsampling=False,
                 upsampling=False,
                 act_norm=True,
                 act_inplace=True):
        super(ConvSC, self).__init__()

        stride = 2 if downsampling is True else 1
        # Padding that preserves (stride 1) or exactly halves (stride 2) H/W.
        padding = (kernel_size - stride + 1) // 2

        self.conv = BasicConv2d(C_in, C_out, kernel_size=kernel_size, stride=stride,
                                upsampling=upsampling, padding=padding,
                                act_norm=act_norm, act_inplace=act_inplace)

    def forward(self, x):
        return self.conv(x)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class GroupConv2d(nn.Module):
    """Grouped Conv2d with optional GroupNorm + LeakyReLU(0.2) for SimVP."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 groups=1,
                 act_norm=False,
                 act_inplace=True):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        # Fall back to a dense conv when channels are not divisible by groups.
        if in_channels % groups != 0:
            groups = 1
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, groups=groups)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.activate = nn.LeakyReLU(0.2, inplace=act_inplace)

    def forward(self, x):
        out = self.conv(x)
        return self.activate(self.norm(out)) if self.act_norm else out
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class gInception_ST(nn.Module):
    """A IncepU block for SimVP"""

    def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8):
        super(gInception_ST, self).__init__()
        # 1x1 bottleneck shared by all inception branches.
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)

        self.layers = nn.Sequential(*[
            GroupConv2d(C_hid, C_out, kernel_size=ker, stride=1,
                        padding=ker // 2, groups=groups, act_norm=True)
            for ker in incep_ker
        ])

    def forward(self, x):
        hidden = self.conv1(x)
        # Sum the multi-kernel branch outputs (all share the same shape).
        return sum(branch(hidden) for branch in self.layers)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class AttentionModule(nn.Module):
    """Large Kernel Attention for SimVP.

    Decomposes a large spatial kernel into a small depth-wise conv, a dilated
    depth-wise conv, and a 1x1 conv that produces value/gate pairs.
    Fix: removed the unused ``u = x.clone()`` from forward (dead allocation).
    """

    def __init__(self, dim, kernel_size, dilation=3):
        """
        Args:
            dim: number of channels (input and output).
            kernel_size: effective large-kernel size to approximate.
            dilation: dilation used by the decomposed spatial conv.
        """
        super().__init__()
        d_k = 2 * dilation - 1            # small depth-wise kernel
        d_p = (d_k - 1) // 2
        dd_k = kernel_size // dilation + ((kernel_size // dilation) % 2 - 1)  # odd dilated kernel
        dd_p = (dilation * (dd_k - 1) // 2)

        self.conv0 = nn.Conv2d(dim, dim, d_k, padding=d_p, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, dd_k, stride=1, padding=dd_p, groups=dim, dilation=dilation)
        self.conv1 = nn.Conv2d(dim, 2*dim, 1)

    def forward(self, x):
        """Return gated attention features with the same shape as ``x``."""
        attn = self.conv0(x)            # depth-wise conv
        attn = self.conv_spatial(attn)  # depth-wise dilation convolution

        f_g = self.conv1(attn)          # value/gate pairs along channels
        split_dim = f_g.shape[1] // 2
        f_x, g_x = torch.split(f_g, split_dim, dim=1)
        # Sigmoid gate modulates the value branch.
        return torch.sigmoid(g_x) * f_x
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class SpatialAttention(nn.Module):
    """A Spatial Attention block for SimVP"""

    def __init__(self, d_model, kernel_size=21, attn_shortcut=True):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)  # 1x1 input projection
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(d_model, kernel_size)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)  # 1x1 output projection
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        residual = x.clone() if self.attn_shortcut else None
        out = self.proj_1(x)
        out = self.activation(out)
        out = self.spatial_gating_unit(out)
        out = self.proj_2(out)
        if residual is not None:
            out = out + residual
        return out
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class GASubBlock(nn.Module):
    """A GABlock (gSTA) for SimVP.

    Pre-norm transformer-style block: BatchNorm + SpatialAttention, then
    BatchNorm + MixMlp, each with a learnable per-channel layer scale and
    stochastic depth (DropPath) on the residual branches.
    """

    def __init__(self, dim, kernel_size=21, mlp_ratio=4.,
                 drop=0., drop_path=0.1, init_value=1e-2, act_layer=nn.GELU):
        """
        Args:
            dim: channel dimension.
            kernel_size: large-kernel size passed to SpatialAttention.
            mlp_ratio: hidden expansion of the MixMlp.
            drop: dropout rate inside the MLP.
            drop_path: stochastic-depth rate (0 disables DropPath).
            init_value: initial value of the layer-scale parameters.
            act_layer: activation class used by the MLP.
        """
        super().__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = SpatialAttention(dim, kernel_size)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MixMlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # Per-channel residual scaling (layer scale), learned from init_value.
        self.layer_scale_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Standard init: trunc-normal linear, unit LayerNorm, fan-out conv.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        # Exclude layer-scale parameters from weight decay.
        return {'layer_scale_1', 'layer_scale_2'}

    def forward(self, x):
        # Residual attention branch, scaled per channel.
        x = x + self.drop_path(
            self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
        # Residual MLP branch, scaled per channel.
        x = x + self.drop_path(
            self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        return x
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class ConvMixerSubBlock(nn.Module):
|
| 229 |
+
"""A block of ConvMixer."""
|
| 230 |
+
|
| 231 |
+
def __init__(self, dim, kernel_size=9, activation=nn.GELU):
|
| 232 |
+
super().__init__()
|
| 233 |
+
# spatial mixing
|
| 234 |
+
self.conv_dw = nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same")
|
| 235 |
+
self.act_1 = activation()
|
| 236 |
+
self.norm_1 = nn.BatchNorm2d(dim)
|
| 237 |
+
# channel mixing
|
| 238 |
+
self.conv_pw = nn.Conv2d(dim, dim, kernel_size=1)
|
| 239 |
+
self.act_2 = activation()
|
| 240 |
+
self.norm_2 = nn.BatchNorm2d(dim)
|
| 241 |
+
|
| 242 |
+
self.apply(self._init_weights)
|
| 243 |
+
|
| 244 |
+
def _init_weights(self, m):
|
| 245 |
+
if isinstance(m, nn.BatchNorm2d):
|
| 246 |
+
nn.init.constant_(m.bias, 0)
|
| 247 |
+
nn.init.constant_(m.weight, 1.0)
|
| 248 |
+
elif isinstance(m, nn.Conv2d):
|
| 249 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 250 |
+
fan_out //= m.groups
|
| 251 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 252 |
+
if m.bias is not None:
|
| 253 |
+
m.bias.data.zero_()
|
| 254 |
+
|
| 255 |
+
@torch.jit.ignore
|
| 256 |
+
def no_weight_decay(self):
|
| 257 |
+
return dict()
|
| 258 |
+
|
| 259 |
+
def forward(self, x):
|
| 260 |
+
x = x + self.norm_1(self.act_1(self.conv_dw(x)))
|
| 261 |
+
x = self.norm_2(self.act_2(self.conv_pw(x)))
|
| 262 |
+
return x
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class ConvNeXtSubBlock(ConvNeXtBlock):
    """A block of ConvNeXt."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
        super().__init__(dim, mlp_ratio=mlp_ratio,
                         drop_path=drop_path, ls_init_value=1e-6, conv_mlp=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # He-style init scaled by fan-out per group
            receptive = m.kernel_size[0] * m.kernel_size[1]
            fan_out = (receptive * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        # layer-scale parameter is excluded from weight decay
        return {'gamma'}

    def forward(self, x):
        # dw conv -> norm -> MLP, scaled by gamma, as a single residual branch
        scale = self.gamma.reshape(1, -1, 1, 1)
        return x + self.drop_path(scale * self.mlp(self.norm(self.conv_dw(x))))
class HorNetSubBlock(HorBlock):
    """A block of HorNet."""

    def __init__(self, dim, mlp_ratio=4., drop_path=0.1, init_value=1e-6):
        super().__init__(dim, mlp_ratio=mlp_ratio, drop_path=drop_path, init_value=init_value)
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        # layer-scale parameters are excluded from weight decay
        return {'gamma1', 'gamma2'}

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # He-style init scaled by fan-out per group
            receptive = m.kernel_size[0] * m.kernel_size[1]
            fan_out = (receptive * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
class MLPMixerSubBlock(MixerBlock):
    """A block of MLP-Mixer."""

    def __init__(self, dim, input_resolution=None, mlp_ratio=4., drop=0., drop_path=0.1):
        # the mixer operates on a flattened H*W token sequence
        seq_len = input_resolution[0] * input_resolution[1]
        super().__init__(dim, seq_len=seq_len,
                         mlp_ratio=(0.5, mlp_ratio), drop_path=drop_path, drop=drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return dict()

    def forward(self, x):
        # (B, C, H, W) -> (B, H*W, C) token layout expected by MixerBlock
        B, C, H, W = x.shape
        tokens = x.flatten(2).transpose(1, 2)
        tokens = tokens + self.drop_path(
            self.mlp_tokens(self.norm1(tokens).transpose(1, 2)).transpose(1, 2))
        tokens = tokens + self.drop_path(self.mlp_channels(self.norm2(tokens)))
        # back to (B, C, H, W)
        return tokens.reshape(B, H, W, C).permute(0, 3, 1, 2)
class MogaSubBlock(nn.Module):
    """A block of MogaNet."""

    def __init__(self, embed_dims, mlp_ratio=4., drop_rate=0., drop_path_rate=0., init_value=1e-5,
                 attn_dw_dilation=[1, 2, 3], attn_channel_split=[1, 3, 4]):
        super(MogaSubBlock, self).__init__()
        self.out_channels = embed_dims
        # spatial aggregation branch
        self.norm1 = nn.BatchNorm2d(embed_dims)
        self.attn = MultiOrderGatedAggregation(
            embed_dims, attn_dw_dilation=attn_dw_dilation, attn_channel_split=attn_channel_split)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        # channel aggregation branch
        self.norm2 = nn.BatchNorm2d(embed_dims)
        self.mlp = ChannelAggregationFFN(
            embed_dims=embed_dims, mlp_hidden_dims=int(embed_dims * mlp_ratio), ffn_drop=drop_rate)
        # learnable per-channel residual scaling (layer scale)
        self.layer_scale_1 = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # He-style init scaled by fan-out per group
            receptive = m.kernel_size[0] * m.kernel_size[1]
            fan_out = (receptive * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'layer_scale_1', 'layer_scale_2', 'sigma'}

    def forward(self, x):
        x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
        return x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
class PoolFormerSubBlock(PoolFormerBlock):
    """A block of PoolFormer."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
        super().__init__(dim, pool_size=3, mlp_ratio=mlp_ratio, drop_path=drop_path,
                         drop=drop, init_value=1e-5)
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        # layer-scale parameters are excluded from weight decay
        return {'layer_scale_1', 'layer_scale_2'}

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
class SwinSubBlock(SwinTransformerBlock):
    """A block of Swin Transformer."""

    def __init__(self, dim, input_resolution=None, layer_i=0, mlp_ratio=4., drop=0., drop_path=0.1):
        # pick a window size compatible with the input resolution, capped at 8
        window_size = 7 if input_resolution[0] % 7 == 0 else max(4, input_resolution[0] // 16)
        window_size = min(8, window_size)
        # alternate W-MSA (even layers) / SW-MSA (odd layers)
        shift_size = 0 if (layer_i % 2 == 0) else window_size // 2
        super().__init__(dim, input_resolution, num_heads=8, window_size=window_size,
                         shift_size=shift_size, mlp_ratio=mlp_ratio,
                         drop_path=drop_path, attn_drop=drop, proj_drop=drop, qkv_bias=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal for linear layers, unit-affine for norm layers
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}

    def forward(self, x):
        # adapt (B, C, H, W) feature maps to the layout the timm Swin block expects
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        # NOTE(review): norm1 is applied here and the pre-norm parent forward
        # presumably applies it again — looks like double normalization; confirm intended
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        x = super().forward(x)

        # back to (B, C, H, W)
        return x.reshape(B, H, W, C).permute(0, 3, 1, 2)
def UniformerSubBlock(embed_dims, mlp_ratio=4., drop=0., drop_path=0.,
                      init_value=1e-6, block_type='Conv'):
    """Build a block of Uniformer.

    Returns a convolutional block ('Conv') or a self-attention block ('MHSA').
    """
    assert block_type in ['Conv', 'MHSA']
    if block_type == 'MHSA':
        return SABlock(dim=embed_dims, num_heads=8, mlp_ratio=mlp_ratio, qkv_bias=True,
                       drop=drop, drop_path=drop_path, init_value=init_value)
    return CBlock(dim=embed_dims, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
class VANSubBlock(VANBlock):
    """A block of VAN."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., init_value=1e-2, act_layer=nn.GELU):
        super().__init__(dim=dim, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path,
                         init_value=init_value, act_layer=act_layer)
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        # layer-scale parameters are excluded from weight decay
        return {'layer_scale_1', 'layer_scale_2'}

    def _init_weights(self, m):
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # He-style init scaled by fan-out per group
            receptive = m.kernel_size[0] * m.kernel_size[1]
            fan_out = (receptive * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
class ViTSubBlock(ViTBlock):
    """A block of Vision Transformer."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
        super().__init__(dim=dim, num_heads=8, mlp_ratio=mlp_ratio, qkv_bias=True,
                         attn_drop=drop, proj_drop=0, drop_path=drop_path, act_layer=nn.GELU, norm_layer=nn.LayerNorm)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}

    def forward(self, x):
        # (B, C, H, W) -> token sequence -> attention + MLP -> (B, C, H, W)
        B, C, H, W = x.shape
        tokens = x.flatten(2).transpose(1, 2)
        tokens = tokens + self.drop_path(self.attn(self.norm1(tokens)))
        tokens = tokens + self.drop_path(self.mlp(self.norm2(tokens)))
        return tokens.reshape(B, H, W, C).permute(0, 3, 1, 2)
class TemporalAttention(nn.Module):
    """A Temporal Attention block for Temporal Attention Unit"""

    def __init__(self, d_model, kernel_size=21, attn_shortcut=True):
        super().__init__()
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)  # point-wise projection in
        self.activation = nn.GELU()
        self.spatial_gating_unit = TemporalAttentionModule(d_model, kernel_size)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)  # point-wise projection out
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        # proj_1 -> GELU -> temporal gating -> proj_2, with optional residual.
        shortcut = x.clone() if self.attn_shortcut else None
        out = self.proj_1(x)
        out = self.activation(out)
        out = self.spatial_gating_unit(out)
        out = self.proj_2(out)
        if self.attn_shortcut:
            out = out + shortcut
        return out
class TemporalAttentionModule(nn.Module):
|
| 543 |
+
"""Large Kernel Attention for SimVP"""
|
| 544 |
+
|
| 545 |
+
def __init__(self, dim, kernel_size, dilation=3, reduction=16):
|
| 546 |
+
super().__init__()
|
| 547 |
+
d_k = 2 * dilation - 1
|
| 548 |
+
d_p = (d_k - 1) // 2
|
| 549 |
+
dd_k = kernel_size // dilation + ((kernel_size // dilation) % 2 - 1)
|
| 550 |
+
dd_p = (dilation * (dd_k - 1) // 2)
|
| 551 |
+
|
| 552 |
+
self.conv0 = nn.Conv2d(dim, dim, d_k, padding=d_p, groups=dim)
|
| 553 |
+
self.conv_spatial = nn.Conv2d(
|
| 554 |
+
dim, dim, dd_k, stride=1, padding=dd_p, groups=dim, dilation=dilation)
|
| 555 |
+
self.conv1 = nn.Conv2d(dim, dim, 1)
|
| 556 |
+
|
| 557 |
+
self.reduction = max(dim // reduction, 4)
|
| 558 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
| 559 |
+
self.fc = nn.Sequential(
|
| 560 |
+
nn.Linear(dim, dim // self.reduction, bias=False), # reduction
|
| 561 |
+
nn.ReLU(True),
|
| 562 |
+
nn.Linear(dim // self.reduction, dim, bias=False), # expansion
|
| 563 |
+
nn.Sigmoid()
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
def forward(self, x):
|
| 567 |
+
u = x.clone()
|
| 568 |
+
attn = self.conv0(x) # depth-wise conv
|
| 569 |
+
attn = self.conv_spatial(attn) # depth-wise dilation convolution
|
| 570 |
+
f_x = self.conv1(attn) # 1x1 conv
|
| 571 |
+
# append a se operation
|
| 572 |
+
b, c, _, _ = x.size()
|
| 573 |
+
se_atten = self.avg_pool(x).view(b, c)
|
| 574 |
+
se_atten = self.fc(se_atten).view(b, c, 1, 1)
|
| 575 |
+
return se_atten * f_x * u
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
class TAUSubBlock(GASubBlock):
    """A TAUBlock (tau) for Temporal Attention Unit"""

    def __init__(self, dim, kernel_size=21, mlp_ratio=4.,
                 drop=0., drop_path=0.1, init_value=1e-2, act_layer=nn.GELU):
        super().__init__(dim=dim, kernel_size=kernel_size, mlp_ratio=mlp_ratio,
                         drop=drop, drop_path=drop_path, init_value=init_value, act_layer=act_layer)
        # swap GASubBlock's spatial attention for TAU's temporal attention
        self.attn = TemporalAttention(dim, kernel_size)
utilpack/swinlstm_modules.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from timm.models.swin_transformer import SwinTransformerBlock, window_reverse, PatchEmbed, PatchMerging, window_partition
|
| 4 |
+
from timm.layers import to_2tuple
|
| 5 |
+
|
| 6 |
+
class SwinLSTMCell(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size, depth,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, flag=None):
        """
        Args:
            flag: 0 UpSample 1 DownSample 2 STconvert
        """
        super(SwinLSTMCell, self).__init__()
        # a stack of Swin Transformer blocks shared across time steps
        self.STBs = nn.ModuleList(
            [STB(i, dim=dim, input_resolution=input_resolution, depth=depth,
                 num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
                 drop_path=drop_path, norm_layer=norm_layer, flag=flag)
             for i in range(depth)])

    def forward(self, xt, hidden_states):
        """
        Args:
            xt: input for t period
            hidden_states: [hx, cx] hidden_states for t-1 period
        """
        # lazily initialise the hidden/cell states on the first step
        if hidden_states is None:
            B, L, C = xt.shape
            hx = torch.zeros(B, L, C).to(xt.device)
            cx = torch.zeros(B, L, C).to(xt.device)
        else:
            hx, cx = hidden_states

        # first block fuses xt with the previous hidden state; subsequent
        # blocks alternate between re-injecting xt (even) and running plain (odd)
        outputs = []
        for index, layer in enumerate(self.STBs):
            if index == 0:
                outputs.append(layer(xt, hx))
                continue
            second_arg = xt if index % 2 == 0 else None
            outputs.append(layer(outputs[-1], second_arg))

        # LSTM-style gating on the final block output
        o_t = outputs[-1]
        Ft = torch.sigmoid(o_t)
        cell = torch.tanh(o_t)

        Ct = Ft * (cx + cell)
        Ht = Ft * torch.tanh(Ct)

        return Ht, (Ht, Ct)
class STB(SwinTransformerBlock):
    # Swin Transformer Block extended with a hidden-state fusion input (hx).

    def __init__(self, index, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, flag=None):
        # select this block's stochastic-depth rate:
        # flag 0 (UpSample): per-layer list indexed in reverse; flag 1 (DownSample):
        # per-layer list indexed forward; flag 2 (STconvert): a single scalar
        if flag == 0:
            drop_path = drop_path[depth - index - 1]
        elif flag == 1:
            drop_path = drop_path[index]
        elif flag == 2:
            drop_path = drop_path
        super(STB, self).__init__(dim=dim, input_resolution=input_resolution,
                                  num_heads=num_heads, window_size=window_size,
                                  shift_size=0 if (index % 2 == 0) else window_size // 2,
                                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                                  drop=drop, attn_drop=attn_drop,
                                  drop_path=drop_path,
                                  norm_layer=norm_layer)
        # fuses [x ; hx] (2*dim channels) back down to dim
        self.red = nn.Linear(2 * dim, dim)

    def forward(self, x, hx=None):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        # when a hidden state is provided, concatenate and project back to C
        if hx is not None:
            hx = self.norm1(hx)
            x = torch.cat((x, hx), -1)
            x = self.red(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # num_win*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # num_win*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_win*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
class PatchInflated(nn.Module):
    r""" Tensor to Patch Inflating

    Args:
        in_chans (int): Number of input image channels.
        embed_dim (int): Number of linear projection output channels.
        input_resolution (tuple[int]): Input resulotion.
    """

    def __init__(self, in_chans, embed_dim, input_resolution, stride=2, padding=1, output_padding=1):
        super(PatchInflated, self).__init__()

        self.input_resolution = input_resolution
        # transposed conv maps token embeddings back to pixel space,
        # doubling the spatial resolution with the default stride of 2
        self.Conv = nn.ConvTranspose2d(in_channels=embed_dim, out_channels=in_chans, kernel_size=(3, 3),
                                       stride=to_2tuple(stride), padding=to_2tuple(padding),
                                       output_padding=to_2tuple(output_padding))

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        # (B, H*W, C) -> (B, C, H, W) for the transposed convolution
        return self.Conv(x.view(B, H, W, C).permute(0, 3, 1, 2))
class PatchExpanding(nn.Module):
    r""" Patch Expanding Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super(PatchExpanding, self).__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # double the channels first so the 2x2 pixel-shuffle below leaves
        # dim // dim_scale channels per output position
        if dim_scale == 2:
            self.expand = nn.Linear(dim, 2 * dim, bias=False)
        else:
            self.expand = nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        H, W = self.input_resolution
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        # rearrange (B, H*W, C) -> (B, 2H * 2W, C // 4): each token spreads
        # its channels over a 2x2 spatial neighbourhood
        x = x.view(B, H, W, 2, 2, C // 4)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H * 2, W * 2, C // 4)
        return self.norm(x.view(B, -1, C // 4))
class UpSample(nn.Module):
    """Decoder half of SwinLSTM: stacked SwinLSTMCell layers, each followed by
    a PatchExpanding upsample, finishing with PatchInflated back to image space."""

    def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_upsample, num_heads, window_size, mlp_ratio=4.,
                 qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, flag=0):
        super(UpSample, self).__init__()

        self.img_size = img_size
        self.num_layers = len(depths_upsample)
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio
        # patch_embed is constructed here to obtain grid_size (patch resolution)
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=nn.LayerNorm)
        patches_resolution = self.patch_embed.grid_size
        self.Unembed = PatchInflated(in_chans=in_chans, embed_dim=embed_dim, input_resolution=patches_resolution)

        # stochastic-depth rates, increasing linearly across all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_upsample))]

        self.layers = nn.ModuleList()
        self.upsample = nn.ModuleList()

        for i_layer in range(self.num_layers):
            # decoder layer i works at 1/2^(num_layers - i) of the patch resolution
            resolution1 = (patches_resolution[0] // (2 ** (self.num_layers - i_layer)))
            resolution2 = (patches_resolution[1] // (2 ** (self.num_layers - i_layer)))

            dimension = int(embed_dim * 2 ** (self.num_layers - i_layer))
            upsample = PatchExpanding(input_resolution=(resolution1, resolution2), dim=dimension)

            # depths/num_heads are indexed in reverse: the deepest encoder stage
            # corresponds to the first decoder layer
            layer = SwinLSTMCell(dim=dimension, input_resolution=(resolution1, resolution2),
                                 depth=depths_upsample[(self.num_layers - 1 - i_layer)],
                                 num_heads=num_heads[(self.num_layers - 1 - i_layer)],
                                 window_size=window_size,
                                 mlp_ratio=self.mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop_rate, attn_drop=attn_drop_rate,
                                 drop_path=dpr[sum(depths_upsample[:(self.num_layers - 1 - i_layer)]):
                                               sum(depths_upsample[:(self.num_layers - 1 - i_layer) + 1])],
                                 norm_layer=norm_layer, flag=flag)

            self.layers.append(layer)
            self.upsample.append(upsample)

    def forward(self, x, y):
        # y: per-layer hidden states (hx, cx) carried over from the previous timestep
        hidden_states_up = []

        for index, layer in enumerate(self.layers):
            x, hidden_state = layer(x, y[index])
            x = self.upsample[index](x)
            hidden_states_up.append(hidden_state)

        # project tokens back to image space; sigmoid bounds pixel values to (0, 1)
        x = torch.sigmoid(self.Unembed(x))

        return hidden_states_up, x
class DownSample(nn.Module):
    """Encoder half of SwinLSTM: patch embedding, then stacked SwinLSTMCell
    layers each followed by a PatchMerging downsample."""

    def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_downsample, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, flag=1):
        super(DownSample, self).__init__()

        self.num_layers = len(depths_downsample)
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=nn.LayerNorm)
        patches_resolution = self.patch_embed.grid_size

        # stochastic-depth rates, increasing linearly across all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_downsample))]

        self.layers = nn.ModuleList()
        self.downsample = nn.ModuleList()

        for i_layer in range(self.num_layers):
            # encoder layer i works at 1/2^i of the patch resolution; channels
            # double with each downsampling stage
            downsample = PatchMerging(input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                        patches_resolution[1] // (2 ** i_layer)),
                                      dim=int(embed_dim * 2 ** i_layer))

            layer = SwinLSTMCell(dim=int(embed_dim * 2 ** i_layer),
                                 input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                   patches_resolution[1] // (2 ** i_layer)),
                                 depth=depths_downsample[i_layer],
                                 num_heads=num_heads[i_layer],
                                 window_size=window_size,
                                 mlp_ratio=self.mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop_rate, attn_drop=attn_drop_rate,
                                 drop_path=dpr[sum(depths_downsample[:i_layer]):sum(depths_downsample[:i_layer + 1])],
                                 norm_layer=norm_layer, flag=flag)

            self.layers.append(layer)
            self.downsample.append(downsample)

    def forward(self, x, y):
        # y: per-layer hidden states (hx, cx) from the previous timestep

        x = self.patch_embed(x)

        hidden_states_down = []

        for index, layer in enumerate(self.layers):
            x, hidden_state = layer(x, y[index])
            x = self.downsample[index](x)
            hidden_states_down.append(hidden_state)

        return hidden_states_down, x
class STconvert(nn.Module):
    """Single-scale SwinLSTM: embed patches, run one SwinLSTMCell, then
    inflate tokens back to an image with a sigmoid output."""

    def __init__(self, img_size, patch_size, in_chans, embed_dim, depths, num_heads,
                 window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, flag=2):
        super(STconvert, self).__init__()

        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_chans=in_chans, embed_dim=embed_dim,
                                      norm_layer=norm_layer)
        grid = self.patch_embed.grid_size

        self.patch_inflated = PatchInflated(in_chans=in_chans, embed_dim=embed_dim,
                                            input_resolution=grid)

        self.layer = SwinLSTMCell(dim=embed_dim,
                                  input_resolution=(grid[0], grid[1]),
                                  depth=depths, num_heads=num_heads,
                                  window_size=window_size, mlp_ratio=mlp_ratio,
                                  qkv_bias=qkv_bias, qk_scale=qk_scale,
                                  drop=drop_rate, attn_drop=attn_drop_rate,
                                  drop_path=drop_path_rate, norm_layer=norm_layer,
                                  flag=flag)

    def forward(self, x, h=None):
        tokens = self.patch_embed(x)
        tokens, hidden_state = self.layer(tokens, h)
        out = torch.sigmoid(self.patch_inflated(tokens))
        return out, hidden_state
utilpack/wast_modules.py
ADDED
|
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, pywt
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
from functools import partial
|
| 5 |
+
from itertools import accumulate
|
| 6 |
+
from timm.layers import DropPath, activations
|
| 7 |
+
from timm.models._efficientnet_blocks import SqueezeExcite, InvertedResidual
|
| 8 |
+
|
| 9 |
+
# version adaptation for PyTorch > 1.7.1
|
| 10 |
+
IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.'))) > (1, 7, 1)
|
| 11 |
+
if IS_HIGH_VERSION:
|
| 12 |
+
import torch.fft
|
| 13 |
+
|
| 14 |
+
class HighFocalFrequencyLoss(nn.Module):
    """Focal frequency loss restricted to high-frequency spectrum regions.

    Inputs are first decomposed ``level`` times with a 2D wavelet transform
    (keeping the LL band each time), then compared in the Fourier domain with
    a dynamically computed focal weight. Coefficients inside a corner radius
    of ``tau * max(H, W)`` (the DC/low-frequency corners of the unshifted
    spectrum) are masked out.

    NOTE(review): the module hard-codes ``.cuda()`` for the DWT and the
    frequency mask, so it cannot run CPU-only as written — confirm intended.

    Example:
        fake = torch.randn(4, 3, 128, 64)
        real = torch.randn(4, 3, 128, 64)
        hffl = HighFocalFrequencyLoss()

        loss = hffl(fake, real)
        print(loss)
    """

    def __init__(self, loss_weight=0.001, level=1, tau=0.1, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=True, batch_matrix=False):
        super(HighFocalFrequencyLoss, self).__init__()
        self.loss_weight = loss_weight    # scalar multiplier on the final loss
        self.alpha = alpha                # exponent sharpening the focal weight
        self.patch_factor = patch_factor  # split each image into patch_factor^2 crops
        self.ave_spectrum = ave_spectrum  # average spectra over the minibatch first
        self.log_matrix = log_matrix      # log-compress the online weight matrix
        self.batch_matrix = batch_matrix  # normalize weights with batch statistics
        self.level = level                # number of DWT levels applied to inputs
        self.tau = tau                    # relative radius of suppressed low-freq corners
        self.DWT = WaveletTransform2D().cuda()

    def tensor2freq(self, x):
        """Crop ``x`` into patches and return their 2D DFT as (..., 2) re/im pairs."""
        patch_factor = self.patch_factor
        _, _, h, w = x.shape
        assert h % patch_factor == 0 and w % patch_factor == 0, (
            'Patch factor should be divisible by image height and width')
        patch_list = []
        patch_h = h // patch_factor
        patch_w = w // patch_factor
        for i in range(patch_factor):
            for j in range(patch_factor):
                patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])

        # stack to patch tensor: (N, patch_factor^2, C, patch_h, patch_w)
        y = torch.stack(patch_list, 1)

        # perform 2D DFT (real-to-complex, orthonormalization)
        if IS_HIGH_VERSION:
            freq = torch.fft.fft2(y, norm='ortho')
            freq = torch.stack([freq.real, freq.imag], -1)
        else:
            freq = torch.rfft(y, 2, onesided=False, normalized=True)
        return freq

    def build_freq_mask(self, shape):
        """Return an (H, W) mask that zeros the four low-frequency corners.

        In an unshifted FFT layout the lowest frequencies sit at the four
        corners of the spectrum, hence the four circular exclusion zones.
        """
        H, W = shape[-2:]
        radius = self.tau * max(H, W)
        Y, X = torch.meshgrid(torch.arange(H), torch.arange(W))

        mask = torch.ones_like(X, dtype=torch.float32).cuda()

        centers = [(0, 0), (0, W - 1), (H - 1, 0), (H - 1, W - 1)]

        for center in centers:
            distance = torch.sqrt((X - center[1]) ** 2 + (Y - center[0]) ** 2)
            mask[distance <= radius] = 0
        return mask

    def loss_formulation(self, recon_freq, real_freq, matrix=None):
        """Weighted, masked squared spectral distance between the two spectra."""
        # spectrum weight matrix
        if matrix is not None:
            # if the matrix is predefined
            weight_matrix = matrix.detach()
        else:
            # if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance
            matrix_tmp = (recon_freq - real_freq) ** 2
            matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha

            # whether to adjust the spectrum weight matrix by logarithm
            if self.log_matrix:
                matrix_tmp = torch.log(matrix_tmp + 1.0)

            # whether to calculate the spectrum weight matrix using batch-based statistics
            if self.batch_matrix:
                matrix_tmp = matrix_tmp / matrix_tmp.max()
            else:
                matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]

            matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
            matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
            weight_matrix = matrix_tmp.clone().detach()

        assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
            'The values of spectrum weight matrix should be in the range [0, 1], '
            'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))

        # frequency distance using (squared) Euclidean distance
        tmp = (recon_freq - real_freq) ** 2
        freq_distance = tmp[..., 0] + tmp[..., 1]

        # dynamic spectrum weighting (Hadamard product), low frequencies masked out
        mask = self.build_freq_mask(weight_matrix.shape)
        loss = weight_matrix * freq_distance * mask
        return torch.mean(loss)

    def frequency_loss(self, pred, target, matrix=None):
        """Forward function to calculate focal frequency loss.

        Args:
            pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor.
            target (torch.Tensor): of shape (N, C, H, W). Target tensor.
            matrix (torch.Tensor, optional): Element-wise spectrum weight matrix.
                Default: None (If set to None: calculated online, dynamic).
        """
        pred_freq = self.tensor2freq(pred)
        target_freq = self.tensor2freq(target)

        # whether to use minibatch average spectrum
        if self.ave_spectrum:
            pred_freq = torch.mean(pred_freq, 0, keepdim=True)
            target_freq = torch.mean(target_freq, 0, keepdim=True)

        return self.loss_formulation(pred_freq, target_freq, matrix)

    def forward(self, pred, target, matrix=None, **kwargs):
        """Accumulate the frequency loss over ``self.level`` DWT levels.

        Pass ``reshape=True`` to fold a (B, T, C, H, W) video batch into
        (B*T, C, H, W) first.
        """
        # FIX: previously kwargs["reshape"] was read unconditionally, raising
        # KeyError whenever the caller omitted reshape= (the class docstring
        # example hffl(fake, real) crashed). Default to no reshape.
        reshape = kwargs.get("reshape", False)
        pred = rearrange(pred, 'b t c h w -> (b t) c h w') if reshape is True else pred
        target = rearrange(target, 'b t c h w -> (b t) c h w') if reshape is True else target

        loss = 0
        for _level in range(self.level):
            # keep only the LL band at each level and compare its spectrum
            pred, _, _, _ = self.DWT(pred)
            target, _, _, _ = self.DWT(target)
            loss += self.frequency_loss(pred, target, matrix)
        return loss * self.loss_weight
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class WaveletTransform2D(nn.Module):
    """Single-level two-dimensional (inverse) discrete wavelet transform.

    Example:
        loss = nn.MSELoss()
        data = torch.rand(1, 3, 128, 256)
        DWT = WaveletTransform2D()
        IDWT = WaveletTransform2D(inverse=True)

        LL, LH, HL, HH = DWT(data)
        recdata = IDWT([LL, LH, HL, HH])
        print(loss(data, recdata))
    """

    def __init__(self, inverse=False, wavelet="haar", mode="constant"):
        super(WaveletTransform2D, self).__init__()
        self.mode = mode
        self.inverse = inverse
        wavelet = pywt.Wavelet(wavelet)

        if isinstance(wavelet, tuple):
            dec_lo, dec_hi, rec_lo, rec_hi = wavelet
        else:
            dec_lo, dec_hi, rec_lo, rec_hi = wavelet.filter_bank

        if inverse is False:
            # decomposition filters are applied time-reversed (convolution
            # vs. correlation convention)
            lo = torch.tensor(dec_lo).flip(-1).unsqueeze(0)
            hi = torch.tensor(dec_hi).flip(-1).unsqueeze(0)
        else:
            lo = torch.tensor(rec_lo).unsqueeze(0)
            hi = torch.tensor(rec_hi).unsqueeze(0)
        self.build_filters(lo, hi)

    def build_filters(self, lo, hi):
        """Build the four separable 2D kernels from the 1D filter pair."""
        self.dim_size = lo.shape[-1]
        band_kernels = [
            self.outer(lo, lo),  # LL
            self.outer(hi, lo),  # LH
            self.outer(lo, hi),  # HL
            self.outer(hi, hi),  # HH
        ]
        filters = torch.stack(band_kernels, dim=0).unsqueeze(1)
        self.register_buffer('filters', filters)  # [4, 1, height, width]

    def outer(self, a: torch.Tensor, b: torch.Tensor):
        """Torch implementation of numpy's outer for 1d vectors."""
        col = torch.unsqueeze(torch.reshape(a, [-1]), dim=-1)
        row = torch.unsqueeze(torch.reshape(b, [-1]), dim=0)
        return col * row

    def get_pad(self, data_len: int, filter_len: int):
        """Left/right pad amounts; the right pad absorbs odd lengths."""
        half = (2 * filter_len - 3) // 2
        padr, padl = half, half
        # pad to even signal length
        if data_len % 2 != 0:
            padr += 1
        return padr, padl

    def adaptive_pad(self, data):
        """Pad H and W so that stride-2 filtering tiles the signal exactly."""
        padb, padt = self.get_pad(data.shape[-2], self.dim_size)
        padr, padl = self.get_pad(data.shape[-1], self.dim_size)
        return torch.nn.functional.pad(data, [padl, padr, padt, padb], mode=self.mode)

    def forward(self, data):
        if self.inverse is False:
            # analysis: (B, C, H, W) -> [LL, LH, HL, HH], each half-resolution
            _, channels, _, _ = data.shape
            padded = self.adaptive_pad(data)
            return [
                torch.nn.functional.conv2d(
                    padded, kernel.repeat(channels, 1, 1, 1), stride=2, groups=channels)
                for kernel in self.filters
            ]
        else:
            # synthesis: interleave the four bands per channel, then a single
            # grouped transposed convolution reconstructs the signal
            batch, channels, height, width = data[0].shape
            interleaved = torch.stack(data, dim=2).reshape(batch, -1, height, width)
            return torch.nn.functional.conv_transpose2d(
                interleaved, self.filters.repeat(channels, 1, 1, 1), stride=2, groups=channels)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class WaveletTransform3D(nn.Module):
    """Single-level three-dimensional (inverse) discrete wavelet transform.

    Example:
        loss = nn.MSELoss()
        data = torch.rand(1, 3, 10, 128, 256)
        DWT = WaveletTransform3D()
        IDWT = WaveletTransform3D(inverse=True)

        LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = DWT(data)
        recdata = IDWT([LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH])
        print(loss(data, recdata))

        LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = DWT_3D(data)
        recdata = IDWT_3D(LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH)
        print(loss(data, recdata))
    """

    def __init__(self, inverse=False, wavelet="haar", mode="constant"):
        super(WaveletTransform3D, self).__init__()
        self.mode = mode
        self.inverse = inverse
        wavelet = pywt.Wavelet(wavelet)

        if isinstance(wavelet, tuple):
            dec_lo, dec_hi, rec_lo, rec_hi = wavelet
        else:
            dec_lo, dec_hi, rec_lo, rec_hi = wavelet.filter_bank

        if inverse is False:
            # decomposition filters are applied time-reversed
            lo = torch.tensor(dec_lo).flip(-1).unsqueeze(0)
            hi = torch.tensor(dec_hi).flip(-1).unsqueeze(0)
        else:
            lo = torch.tensor(rec_lo).unsqueeze(0)
            hi = torch.tensor(rec_hi).unsqueeze(0)
        self.build_filters(lo, hi)

    def build_filters(self, lo, hi):
        """Build the eight separable 3D kernels from the 1D filter pair."""
        self.dim_size = lo.shape[-1]
        size = [self.dim_size] * 3
        # band order is fixed: LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH
        band_kernels = [
            self.outer(lo, self.outer(lo, lo)).reshape(size),
            self.outer(lo, self.outer(lo, hi)).reshape(size),
            self.outer(lo, self.outer(hi, lo)).reshape(size),
            self.outer(lo, self.outer(hi, hi)).reshape(size),
            self.outer(hi, self.outer(lo, lo)).reshape(size),
            self.outer(hi, self.outer(lo, hi)).reshape(size),
            self.outer(hi, self.outer(hi, lo)).reshape(size),
            self.outer(hi, self.outer(hi, hi)).reshape(size),
        ]
        filters = torch.stack(band_kernels, dim=0).unsqueeze(1)
        self.register_buffer('filters', filters)  # [8, 1, length, height, width]

    def outer(self, a: torch.Tensor, b: torch.Tensor):
        """Torch implementation of numpy's outer for 1d vectors."""
        col = torch.unsqueeze(torch.reshape(a, [-1]), dim=-1)
        row = torch.unsqueeze(torch.reshape(b, [-1]), dim=0)
        return col * row

    def get_pad(self, data_len: int, filter_len: int):
        """Left/right pad amounts; the right pad absorbs odd lengths."""
        half = (2 * filter_len - 3) // 2
        padr, padl = half, half
        # pad to even signal length
        if data_len % 2 != 0:
            padr += 1
        return padr, padl

    def adaptive_pad(self, data):
        """Pad T, H and W so that stride-2 filtering tiles the volume exactly."""
        pad_back, pad_front = self.get_pad(data.shape[-3], self.dim_size)
        pad_bottom, pad_top = self.get_pad(data.shape[-2], self.dim_size)
        pad_right, pad_left = self.get_pad(data.shape[-1], self.dim_size)
        return torch.nn.functional.pad(
            data, [pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back], mode=self.mode)

    def forward(self, data):
        if self.inverse is False:
            # analysis: (B, C, T, H, W) -> eight half-resolution bands
            _, channels, _, _, _ = data.shape
            padded = self.adaptive_pad(data)
            return [
                torch.nn.functional.conv3d(
                    padded, kernel.repeat(channels, 1, 1, 1, 1), stride=2, groups=channels)
                for kernel in self.filters
            ]
        else:
            # synthesis: interleave the eight bands per channel, then a single
            # grouped transposed convolution reconstructs the volume
            batch, channels, frames, height, width = data[0].shape
            interleaved = torch.stack(data, dim=2).reshape(batch, -1, frames, height, width)
            return torch.nn.functional.conv_transpose3d(
                interleaved, self.filters.repeat(channels, 1, 1, 1, 1), stride=2, groups=channels)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class FrequencyAttention(nn.Module):
    """Coordinate-attention-style gate: pools along H and W separately,
    mixes the two profiles through a shared bottleneck, and rescales the
    input with per-row and per-column sigmoid gates.
    """

    def __init__(self, in_dim, out_dim, reduction=32):
        super(FrequencyAttention, self).__init__()
        # directional pooling: one keeps H, the other keeps W
        self.avgpool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.avgpool_w = nn.AdaptiveAvgPool2d((1, None))

        hidden_dim = max(8, in_dim // reduction)  # bottleneck width, floor of 8

        self.conv1 = nn.Conv2d(in_dim, hidden_dim, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        self.act = activations.HardSwish(inplace=True)

        self.conv_h = nn.Conv2d(hidden_dim, out_dim, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(hidden_dim, out_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        _, _, h, w = x.size()
        profile_h = self.avgpool_h(x)                      # (b, c, h, 1)
        profile_w = self.avgpool_w(x).permute(0, 1, 3, 2)  # (b, c, w, 1)

        # process both axis profiles jointly through the shared bottleneck
        joint = torch.cat([profile_h, profile_w], dim=2)   # (b, c, h+w, 1)
        joint = self.act(self.bn1(self.conv1(joint)))

        feat_h, feat_w = torch.split(joint, [h, w], dim=2)
        feat_w = feat_w.permute(0, 1, 3, 2)

        gate_h = self.conv_h(feat_h).sigmoid()
        gate_w = self.conv_w(feat_w).sigmoid()

        # rescale the identity path with both directional gates
        return x * gate_w * gate_h
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class TF_AwareBlock(nn.Module):
    """Time-frequency aware block: a spatial mixing branch (two orthogonal
    large-kernel depthwise convs for low frequencies + one small-kernel conv
    for high frequencies, each gated by a FrequencyAttention) followed by an
    inverted-residual temporal mixer. Both branches are residual with
    layer-scale and stochastic depth.
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., ls_init_value=1e-2, drop_path=0.1, large_kernel=51, small_kernel=5):
        super().__init__()

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm1 = nn.BatchNorm2d(dim)
        self.norm2 = nn.BatchNorm2d(dim)

        # orthogonal large-kernel depthwise convs capture low-frequency structure
        self.lk1 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=(large_kernel, 5), groups=dim, padding="same"),
            nn.BatchNorm2d(dim)
        )

        self.lk2 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=(5, large_kernel), groups=dim, padding="same"),
            nn.BatchNorm2d(dim)
        )

        # small-kernel depthwise conv captures high-frequency detail
        self.sk = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=(small_kernel, small_kernel), groups=dim, padding="same"),
            nn.BatchNorm2d(dim)
        )

        self.low_frequency_attn = FrequencyAttention(in_dim=dim, out_dim=dim, reduction=4)
        self.high_frequency_attn = FrequencyAttention(in_dim=dim, out_dim=dim, reduction=4)

        self.temporal_mixer = InvertedResidual(in_chs=dim, out_chs=dim, dw_kernel_size=7, exp_ratio=mlp_ratio,
                                               se_layer=partial(SqueezeExcite, rd_ratio=0.25), noskip=True)

        # one learnable per-channel scale per residual branch
        self.layer_scale_1 = nn.Parameter(ls_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(ls_init_value * torch.ones((dim)), requires_grad=True)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'layer_scale_1', 'layer_scale_2'}

    def forward(self, x):
        attn = self.norm1(x)
        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * (self.low_frequency_attn(self.lk1(attn) + self.lk2(attn)) + self.high_frequency_attn(self.sk(attn))))
        # FIX: the temporal branch previously reused layer_scale_1; layer_scale_2
        # was defined (and exempted from weight decay) but never applied.
        x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.temporal_mixer(self.norm2(x)))
        return x
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class TF_AwareBlocks(nn.Module):
    """A sequential stack of TF_AwareBlock modules over (b, c, t, h, w) input.

    Three operating modes, selected at construction:
      * use_hid=True:   fuse an external skip via a 3x3 conv before the stack;
                        forward returns only the mixed tensor.
      * default:        run the stack and also hand back the flattened features
                        as a skip for a later stage; forward returns (x, skip).
      * use_bottleneck: bottleneck mode; when it equals "decompose" the input
                        is additionally split by a 3D DWT, only the LLL band is
                        processed, and the result is re-synthesized.
    """

    def __init__(self, dim, num_blocks, drop_path, use_bottleneck=None, use_hid=False, mlp_ratio=4., drop=0., ls_init_value=1e-2, large_kernel=51, small_kernel=5):
        super().__init__()
        assert len(drop_path) == num_blocks, "drop_path list doesn't match num_blocks"
        self.use_hid = use_hid
        self.use_bottleneck = use_bottleneck

        self.blocks = nn.Sequential(*[
            TF_AwareBlock(dim, mlp_ratio, drop, ls_init_value, drop_path[i], large_kernel, small_kernel)
            for i in range(num_blocks)
        ])
        self.concat_block = nn.Conv2d(dim * 2, dim, 3, 1, 1) if use_hid == True else None

        wavelet_bottleneck = use_bottleneck == "decompose"
        self.DWT = WaveletTransform3D(inverse=False) if wavelet_bottleneck else None
        self.IDWT = WaveletTransform3D(inverse=True) if wavelet_bottleneck else None

    def forward(self, x, skip=None):  # b, c ,t, h, w
        if self.use_bottleneck is not None:
            decompose = self.use_bottleneck == "decompose"
            coeffs = self.DWT(x) if decompose else [x, None, None, None, None, None, None, None]
            lll = coeffs[0]
            _, _, frames, _, _ = lll.shape
            lll = rearrange(lll, 'b c t h w -> b (c t) h w')
            lll = self.blocks(lll)
            lll = rearrange(lll, 'b (c t) h w -> b c t h w', t=frames)
            # re-synthesize with the untouched high-frequency bands
            return self.IDWT([lll] + coeffs[1:]) if decompose else lll

        _, _, frames, _, _ = x.shape
        x = rearrange(x, 'b c t h w -> b (c t) h w')
        if self.concat_block is not None:
            # fuse the external skip, then mix
            x = self.concat_block(torch.cat([x, skip], dim=1))
            x = self.blocks(x)
            return rearrange(x, 'b (c t) h w -> b c t h w', t=frames)

        # plain mode: expose the flattened features as a skip for later reuse
        skip = self.blocks(x)
        x = rearrange(skip, 'b (c t) h w -> b c t h w', t=frames)
        return x, skip
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
class Wavelet_3D_Embedding(nn.Module):
    """Wavelet-based downsampling embedding with a raw-frame side branch.

    The main branch convolves the features, decomposes them with a 3D DWT,
    keeps the spatially-low bands (LLL/LLH concatenated along time) and
    projects them to ``out_dim``. The side branch decomposes the raw frames
    and projects all eight of their bands. The spatially-high bands of the
    main branch are returned as a skip for later reconstruction.
    """

    def __init__(self, in_dim, out_dim, emb_dim=None):
        super().__init__()
        emb_dim = in_dim if emb_dim == None else emb_dim

        self.conv_0 = nn.Sequential(
            nn.Conv3d(in_dim, in_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),),
            nn.BatchNorm3d(in_dim),
            nn.GELU(),
        )
        self.conv_1 = nn.Sequential(
            nn.Conv3d(in_dim, out_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),),
            nn.BatchNorm3d(out_dim),
            nn.GELU(),
        )

        # projects the 4x-channel stack of raw-frame wavelet bands
        self.conv_emb = nn.Conv3d(emb_dim * 4, out_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),)

        self.DWT = WaveletTransform3D(inverse=False)

    def forward(self, x, x_emb=None):
        # embedding branch: decompose raw frames, group temporally-low vs
        # temporally-high bands along channels, then stack along time
        LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = self.DWT(x_emb)
        temporal_lo = torch.cat([LLL, LHL, HLL, HHL], dim=1)
        temporal_hi = torch.cat([LLH, LHH, HLH, HHH], dim=1)
        emb = self.conv_emb(torch.cat([temporal_lo, temporal_hi], dim=2))

        # downsampling branch: convolve then decompose the feature volume
        feats = self.conv_0(x)
        LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = self.DWT(feats)
        spatio_lo_coeffs = torch.cat([LLL, LLH], dim=2)
        spatio_hi_coeffs = torch.cat([LHL, LHH, HLL, HLH, HHL, HHH], dim=1)
        feats = self.conv_1(spatio_lo_coeffs)

        return (feats + emb), spatio_hi_coeffs
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class Wavelet_3D_Reconstruction(nn.Module):
    """Inverse of Wavelet_3D_Embedding: upsamples by wavelet synthesis.

    The low-frequency path is projected and split back into LLL/LLH along
    time; the skip carries the six spatially-high bands, refined by a grouped
    conv (one group per band) before the inverse 3D DWT fuses all eight.
    """

    def __init__(self, in_dim, out_dim, hi_dim):
        super().__init__()
        self.conv_0 = nn.Sequential(
            nn.Conv3d(in_dim, out_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),),
            nn.BatchNorm3d(out_dim),
            nn.GELU(),
        )

        # groups=6 keeps the six high-frequency bands independent
        self.conv_hi = nn.Sequential(
            nn.Conv3d(int(hi_dim * 6), int(out_dim * 6), kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=6),
            nn.BatchNorm3d(out_dim * 6),
            nn.GELU(),
        )

        self.IDWT = WaveletTransform3D(inverse=True)

    def forward(self, x, skip_hi=None):
        lo_bands = torch.chunk(self.conv_0(x), chunks=2, dim=2)
        hi_bands = torch.chunk(self.conv_hi(skip_hi), chunks=6, dim=1)
        LLL, LLH = lo_bands
        LHL, LHH, HLL, HLH, HHL, HHH = hi_bands
        return self.IDWT([LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH])
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class WaST_level1(nn.Module):
    """One-level WaST video prediction network.

    Pipeline: conv stem -> translator1 (keeps a skip) -> wavelet embedding
    (keeps high-frequency skip) -> bottleneck translator -> wavelet
    reconstruction -> translator2 (consumes translator1's skip) -> conv head.

    Args:
        in_shape: (frames, channels, H, W) of one input clip.
        encoder_dim: base channel width of the stem.
        block_list: number of TF_AwareBlock per translator stage.
        drop_path_rate: max stochastic-depth rate (linear schedule).
        mlp_ratio: expansion ratio inside each block's temporal mixer.
    """

    def __init__(self, in_shape, encoder_dim, block_list=[2, 2, 2], drop_path_rate=0.1, mlp_ratio=4., **kwargs):
        super().__init__()
        frame, in_dim, H, W = in_shape
        self.block_list = block_list
        dp_lists = self._drop_path_schedule(drop_path_rate)

        self.conv_in = nn.Sequential(
            nn.Conv3d(
                in_dim,
                encoder_dim,
                kernel_size=(3, 3, 3),
                stride=(1, 1, 1),
                padding=(1, 1, 1),
            ),
            nn.BatchNorm3d(encoder_dim),
            nn.GELU()
        )
        self.translator1 = TF_AwareBlocks(dim=encoder_dim * frame, num_blocks=block_list[0], drop_path=dp_lists[0], mlp_ratio=mlp_ratio, large_kernel=51, small_kernel=5)

        self.wavelet_embed1 = Wavelet_3D_Embedding(in_dim=encoder_dim, out_dim=encoder_dim * 2, emb_dim=in_dim)  # wavelet_recon2: hi_dim = in_dim

        self.bottleneck_translator = TF_AwareBlocks(dim=encoder_dim * 2 * frame, num_blocks=block_list[1], drop_path=dp_lists[1], use_bottleneck=True, mlp_ratio=mlp_ratio, large_kernel=21, small_kernel=5)

        self.wavelet_recon1 = Wavelet_3D_Reconstruction(in_dim=encoder_dim * 2, out_dim=encoder_dim, hi_dim=encoder_dim)
        self.translator2 = TF_AwareBlocks(dim=encoder_dim * frame, num_blocks=block_list[2], drop_path=dp_lists[2], use_hid=True, mlp_ratio=mlp_ratio, large_kernel=51, small_kernel=5)

        self.conv_out = nn.Sequential(
            nn.BatchNorm3d(encoder_dim),
            nn.GELU(),
            nn.Conv3d(
                encoder_dim,
                in_dim,
                kernel_size=(3, 3, 3),
                stride=(1, 1, 1),
                padding=(1, 1, 1))
        )

    def _drop_path_schedule(self, drop_path_rate):
        # Linear drop-path ramp over all blocks, sliced per translator stage.
        flat = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.block_list))]
        ends = list(accumulate(self.block_list))
        return [flat[start:end] for start, end in zip([0] + ends, ends)]

    def update_drop_path(self, drop_path_rate):
        """Re-assign per-block drop-path probabilities in place (e.g. for a
        drop-path warm-up schedule during training)."""
        dp_lists = self._drop_path_schedule(drop_path_rate)
        stacks = [self.translator1.blocks, self.bottleneck_translator.blocks, self.translator2.blocks]
        for stack, stage_probs in zip(stacks, dp_lists):
            for block, prob in zip(stack, stage_probs):
                block.drop_path.drop_prob = prob

    def forward(self, x):
        x = rearrange(x, 'b t c h w -> b c t h w')

        raw_frames = x
        x = self.conv_in(x)

        x, translator_skip = self.translator1(x)
        x, hi_freq_skip = self.wavelet_embed1(x, x_emb=raw_frames)

        x = self.bottleneck_translator(x)

        x = self.wavelet_recon1(x, hi_freq_skip)
        x = self.translator2(x, translator_skip)

        x = self.conv_out(x)

        return rearrange(x, 'b c t h w -> b t c h w')
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
if __name__ == "__main__":
    # Smoke test + FLOP count for a small WaST_level1 configuration.
    from fvcore.nn import FlopCountAnalysis, flop_count_table
    # import os
    # os.environ["CUDA_VISIBLE_DEVICES"] = "3"

    net = WaST_level1(in_shape=(4, 2, 32, 32), encoder_dim=20, block_list=[2, 8, 2]).cuda()
    print(net)

    sample = torch.rand(1, 4, 2, 32, 32).cuda()
    prediction = net(sample)
    print(f"input shape is {sample.shape}, output shape is {prediction.shape}...")

    flops = FlopCountAnalysis(net, sample)
    print(flop_count_table(flops))
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
|