Marek Bukowicki committed on
Commit
5f02d3e
·
1 Parent(s): 73942d1

fix loading of old models

Browse files
Files changed (1) hide show
  1. shimnet/models.py +47 -8
shimnet/models.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
 
3
  # class ConvEncoder(torch.nn.Module):
4
  # def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
@@ -56,7 +57,7 @@ def get_activation(activation_name: str) -> torch.nn.Module:
56
 
57
 
58
  class ConvEncoder(torch.nn.Module):
59
- def __init__(self, hidden_dim=64, output_dim=None, input_dim=1, dropout=0, kernel_size=7, activation="relu"):
60
  super().__init__()
61
  if output_dim is None:
62
  output_dim = hidden_dim
@@ -70,17 +71,20 @@ class ConvEncoder(torch.nn.Module):
70
  torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size),
71
  get_activation(activation),
72
  torch.nn.Dropout(dropout),
73
- torch.nn.Conv1d(hidden_dim, output_dim, kernel_size),
74
- get_activation(activation),
75
- torch.nn.Dropout(dropout),
76
  ]
77
- self.net = torch.nn.Sequential(*layers)
 
 
 
 
 
78
 
79
  def forward(self, feature):
80
  return self.net(feature)
81
 
82
  class ConvDecoder(torch.nn.Module):
83
- def __init__(self, input_dim=None, hidden_dim=64, output_dim=1, dropout=0, kernel_size=7, activation="relu", last_bias=True, last_activation=True):
84
  super().__init__()
85
  if input_dim is None:
86
  input_dim = hidden_dim
@@ -159,15 +163,50 @@ class ShimNetWithSCRF(torch.nn.Module):
159
  decoder_hidden_dims=64
160
  ):
161
  super().__init__()
162
- self.encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=bottleneck_dim, dropout=encoder_dropout)
163
  self.query = torch.nn.Parameter(torch.empty(1, 1, bottleneck_dim))
164
  torch.nn.init.xavier_normal_(self.query)
165
 
166
- self.decoder = ConvDecoder(input_dim=2*bottleneck_dim, hidden_dim=decoder_hidden_dims)
167
 
168
  self.rensponse_length = rensponse_length
169
  self.response_head = ResponseHead(bottleneck_dim, rensponse_length, resnponse_head_dims)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def forward(self, feature): #(samples, 1, 2048)
172
  feature = self.encoder(feature) #(samples, 64, 2042)
173
  energy = self.query @ feature #(samples, 1, 2024)
 
1
  import torch
2
+ from collections import OrderedDict
3
 
4
  # class ConvEncoder(torch.nn.Module):
5
  # def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
 
57
 
58
 
59
  class ConvEncoder(torch.nn.Module):
60
+ def __init__(self, hidden_dim=64, output_dim=None, input_dim=1, dropout=0, kernel_size=7, activation="relu", last_activation=True):
61
  super().__init__()
62
  if output_dim is None:
63
  output_dim = hidden_dim
 
71
  torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size),
72
  get_activation(activation),
73
  torch.nn.Dropout(dropout),
74
+ torch.nn.Conv1d(hidden_dim, output_dim, kernel_size)
 
 
75
  ]
76
+ if last_activation:
77
+ layers.append(
78
+ get_activation(activation)
79
+ )
80
+ layers.append(torch.nn.Dropout(dropout))
81
+ self.net = torch.nn.Sequential(*layers)
82
 
83
  def forward(self, feature):
84
  return self.net(feature)
85
 
86
  class ConvDecoder(torch.nn.Module):
87
+ def __init__(self, input_dim=None, hidden_dim=64, output_dim=1, dropout=0, kernel_size=7, activation="relu", last_bias=True, last_activation=False):
88
  super().__init__()
89
  if input_dim is None:
90
  input_dim = hidden_dim
 
163
  decoder_hidden_dims=64
164
  ):
165
  super().__init__()
166
+ self.encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=bottleneck_dim, dropout=encoder_dropout, last_activation=False)
167
  self.query = torch.nn.Parameter(torch.empty(1, 1, bottleneck_dim))
168
  torch.nn.init.xavier_normal_(self.query)
169
 
170
+ self.decoder = ConvDecoder(input_dim=2*bottleneck_dim, hidden_dim=decoder_hidden_dims, last_activation=False)
171
 
172
  self.rensponse_length = rensponse_length
173
  self.response_head = ResponseHead(bottleneck_dim, rensponse_length, resnponse_head_dims)
174
+
175
+ self.EncoderLegacyNameMapping = {
176
+ "conv4": "net.0",
177
+ "conv3": "net.3",
178
+ "conv2": "net.6",
179
+ "conv1": "net.9",
180
+ }
181
+ self.DecoderLegacyNameMapping = {
182
+ "convTranspose1": "net.0",
183
+ "convTranspose2": "net.3",
184
+ "convTranspose3": "net.6",
185
+ "convTranspose4": "net.9",
186
+ }
187
+
188
 
189
+ def load_state_dict(self, state_dict, strict=True):
190
+ new_state_dict = OrderedDict()
191
+ for k, v in state_dict.items():
192
+ k_splitted = k.split(".")
193
+ if k_splitted[0] == "encoder":
194
+ if k_splitted[1] in self.EncoderLegacyNameMapping:
195
+ k_splitted[1] = self.EncoderLegacyNameMapping[k_splitted[1]]
196
+ new_key = ".".join(k_splitted)
197
+ else:
198
+ new_key = k
199
+ elif k_splitted[0] == "decoder":
200
+ if k_splitted[1] in self.DecoderLegacyNameMapping:
201
+ k_splitted[1] = self.DecoderLegacyNameMapping[k_splitted[1]]
202
+ new_key = ".".join(k_splitted)
203
+ else:
204
+ new_key = k
205
+ else:
206
+ new_key = k
207
+ new_state_dict[new_key] = v
208
+ super().load_state_dict(new_state_dict, strict=strict)
209
+
210
  def forward(self, feature): #(samples, 1, 2048)
211
  feature = self.encoder(feature) #(samples, 64, 2042)
212
  energy = self.query @ feature #(samples, 1, 2024)