Spaces:
Sleeping
Sleeping
Marek Bukowicki committed on
Commit ·
5f02d3e
1
Parent(s): 73942d1
fix loading of old models
Browse files- shimnet/models.py +47 -8
shimnet/models.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import torch
|
|
|
|
| 2 |
|
| 3 |
# class ConvEncoder(torch.nn.Module):
|
| 4 |
# def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
|
|
@@ -56,7 +57,7 @@ def get_activation(activation_name: str) -> torch.nn.Module:
|
|
| 56 |
|
| 57 |
|
| 58 |
class ConvEncoder(torch.nn.Module):
|
| 59 |
-
def __init__(self, hidden_dim=64, output_dim=None, input_dim=1, dropout=0, kernel_size=7, activation="relu"):
|
| 60 |
super().__init__()
|
| 61 |
if output_dim is None:
|
| 62 |
output_dim = hidden_dim
|
|
@@ -70,17 +71,20 @@ class ConvEncoder(torch.nn.Module):
|
|
| 70 |
torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size),
|
| 71 |
get_activation(activation),
|
| 72 |
torch.nn.Dropout(dropout),
|
| 73 |
-
torch.nn.Conv1d(hidden_dim, output_dim, kernel_size)
|
| 74 |
-
get_activation(activation),
|
| 75 |
-
torch.nn.Dropout(dropout),
|
| 76 |
]
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
def forward(self, feature):
|
| 80 |
return self.net(feature)
|
| 81 |
|
| 82 |
class ConvDecoder(torch.nn.Module):
|
| 83 |
-
def __init__(self, input_dim=None, hidden_dim=64, output_dim=1, dropout=0, kernel_size=7, activation="relu", last_bias=True, last_activation=
|
| 84 |
super().__init__()
|
| 85 |
if input_dim is None:
|
| 86 |
input_dim = hidden_dim
|
|
@@ -159,15 +163,50 @@ class ShimNetWithSCRF(torch.nn.Module):
|
|
| 159 |
decoder_hidden_dims=64
|
| 160 |
):
|
| 161 |
super().__init__()
|
| 162 |
-
self.encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=bottleneck_dim, dropout=encoder_dropout)
|
| 163 |
self.query = torch.nn.Parameter(torch.empty(1, 1, bottleneck_dim))
|
| 164 |
torch.nn.init.xavier_normal_(self.query)
|
| 165 |
|
| 166 |
-
self.decoder = ConvDecoder(input_dim=2*bottleneck_dim, hidden_dim=decoder_hidden_dims)
|
| 167 |
|
| 168 |
self.rensponse_length = rensponse_length
|
| 169 |
self.response_head = ResponseHead(bottleneck_dim, rensponse_length, resnponse_head_dims)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
def forward(self, feature): #(samples, 1, 2048)
|
| 172 |
feature = self.encoder(feature) #(samples, 64, 2042)
|
| 173 |
energy = self.query @ feature #(samples, 1, 2024)
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
|
| 4 |
# class ConvEncoder(torch.nn.Module):
|
| 5 |
# def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
class ConvEncoder(torch.nn.Module):
|
| 60 |
+
def __init__(self, hidden_dim=64, output_dim=None, input_dim=1, dropout=0, kernel_size=7, activation="relu", last_activation=True):
|
| 61 |
super().__init__()
|
| 62 |
if output_dim is None:
|
| 63 |
output_dim = hidden_dim
|
|
|
|
| 71 |
torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size),
|
| 72 |
get_activation(activation),
|
| 73 |
torch.nn.Dropout(dropout),
|
| 74 |
+
torch.nn.Conv1d(hidden_dim, output_dim, kernel_size)
|
|
|
|
|
|
|
| 75 |
]
|
| 76 |
+
if last_activation:
|
| 77 |
+
layers.append(
|
| 78 |
+
get_activation(activation)
|
| 79 |
+
)
|
| 80 |
+
layers.append(torch.nn.Dropout(dropout))
|
| 81 |
+
self.net = torch.nn.Sequential(*layers)
|
| 82 |
|
| 83 |
def forward(self, feature):
|
| 84 |
return self.net(feature)
|
| 85 |
|
| 86 |
class ConvDecoder(torch.nn.Module):
|
| 87 |
+
def __init__(self, input_dim=None, hidden_dim=64, output_dim=1, dropout=0, kernel_size=7, activation="relu", last_bias=True, last_activation=False):
|
| 88 |
super().__init__()
|
| 89 |
if input_dim is None:
|
| 90 |
input_dim = hidden_dim
|
|
|
|
| 163 |
decoder_hidden_dims=64
|
| 164 |
):
|
| 165 |
super().__init__()
|
| 166 |
+
self.encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=bottleneck_dim, dropout=encoder_dropout, last_activation=False)
|
| 167 |
self.query = torch.nn.Parameter(torch.empty(1, 1, bottleneck_dim))
|
| 168 |
torch.nn.init.xavier_normal_(self.query)
|
| 169 |
|
| 170 |
+
self.decoder = ConvDecoder(input_dim=2*bottleneck_dim, hidden_dim=decoder_hidden_dims, last_activation=False)
|
| 171 |
|
| 172 |
self.rensponse_length = rensponse_length
|
| 173 |
self.response_head = ResponseHead(bottleneck_dim, rensponse_length, resnponse_head_dims)
|
| 174 |
+
|
| 175 |
+
self.EncoderLegacyNameMapping = {
|
| 176 |
+
"conv4": "net.0",
|
| 177 |
+
"conv3": "net.3",
|
| 178 |
+
"conv2": "net.6",
|
| 179 |
+
"conv1": "net.9",
|
| 180 |
+
}
|
| 181 |
+
self.DecoderLegacyNameMapping = {
|
| 182 |
+
"convTranspose1": "net.0",
|
| 183 |
+
"convTranspose2": "net.3",
|
| 184 |
+
"convTranspose3": "net.6",
|
| 185 |
+
"convTranspose4": "net.9",
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
|
| 189 |
+
def load_state_dict(self, state_dict, strict=True):
    """Load a checkpoint, translating legacy layer names on the fly.

    Older checkpoints named the encoder/decoder convolutions ``conv*`` /
    ``convTranspose*``; the current modules expose them as ``net.<idx>``.
    Keys are rewritten through ``self.EncoderLegacyNameMapping`` /
    ``self.DecoderLegacyNameMapping`` before delegating to
    ``torch.nn.Module.load_state_dict``.

    Parameters
    ----------
    state_dict : Mapping[str, torch.Tensor]
        Checkpoint in either the current or the legacy layout.
    strict : bool, default True
        Forwarded unchanged to ``torch.nn.Module.load_state_dict``.

    Returns
    -------
    The ``_IncompatibleKeys`` report from the base implementation
    (previously discarded — now returned so callers can inspect
    missing/unexpected keys).
    """
    # One rename table per top-level submodule whose layout changed.
    legacy_maps = {
        "encoder": self.EncoderLegacyNameMapping,
        "decoder": self.DecoderLegacyNameMapping,
    }
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        parts = key.split(".")
        table = legacy_maps.get(parts[0])
        # Rewrite only keys like "encoder.conv1.weight" whose second
        # component is a known legacy layer name; everything else passes
        # through untouched.  The len() guard fixes a latent IndexError
        # on a degenerate single-component key such as "encoder".
        if table is not None and len(parts) > 1 and parts[1] in table:
            parts[1] = table[parts[1]]
            key = ".".join(parts)
        new_state_dict[key] = value
    # Bug fix: propagate the result instead of dropping it — nn.Module's
    # load_state_dict returns an _IncompatibleKeys namedtuple.
    return super().load_state_dict(new_state_dict, strict=strict)
|
| 209 |
+
|
| 210 |
def forward(self, feature): #(samples, 1, 2048)
|
| 211 |
feature = self.encoder(feature) #(samples, 64, 2042)
|
| 212 |
energy = self.query @ feature #(samples, 1, 2024)
|