Marek Bukowicki committed on
Commit
0120aad
·
1 Parent(s): a86e7e6

add new models from feature/models

Browse files
Files changed (1) hide show
  1. shimnet/models.py +273 -32
shimnet/models.py CHANGED
@@ -1,45 +1,138 @@
1
  import torch
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  class ConvEncoder(torch.nn.Module):
4
- def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
5
  super().__init__()
6
  if output_dim is None:
7
  output_dim = hidden_dim
8
- self.conv4 = torch.nn.Conv1d(1, hidden_dim, kernel_size)
9
- self.conv3 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
10
- self.conv2 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
11
- self.conv1 = torch.nn.Conv1d(hidden_dim, output_dim, kernel_size)
12
- self.dropout = torch.nn.Dropout(dropout)
13
-
14
- def forward(self, feature): #(samples, 1, 2048)
15
- feature = self.dropout(self.conv4(feature)) #(samples, 64, 2042)
16
- feature = feature.relu()
17
- feature = self.dropout(self.conv3(feature)) #(samples, 64, 2036)
18
- feature = feature.relu()
19
- feature = self.dropout(self.conv2(feature)) #(samples, 64, 2030)
20
- feature = feature.relu()
21
- feature = self.dropout(self.conv1(feature)) #(samples, 64, 2024)
22
- return feature
 
 
 
23
 
24
  class ConvDecoder(torch.nn.Module):
25
- def __init__(self, input_dim=None, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
26
  super().__init__()
27
- if output_dim is None:
28
- output_dim = hidden_dim
29
- self.convTranspose1 = torch.nn.ConvTranspose1d(input_dim, hidden_dim, kernel_size)
30
- self.convTranspose2 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
31
- self.convTranspose3 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
32
- self.convTranspose4 = torch.nn.ConvTranspose1d(hidden_dim, 1, kernel_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- def forward(self, feature): #(samples, 1, 2048)
35
- feature = self.convTranspose1(feature) #(samples, 64, 2030)
36
- feature = feature.relu()
37
- feature = self.convTranspose2(feature) #(samples, 64, 2036)
38
- feature = feature.relu()
39
- feature = self.convTranspose3(feature) #(samples, 64, 2042)
40
- feature = feature.relu()
41
- feature = self.convTranspose4(feature)
42
- return feature
43
 
44
  class ResponseHead(torch.nn.Module):
45
  def __init__(self, input_dim, output_length, hidden_dims=[128]):
@@ -93,6 +186,89 @@ class ShimNetWithSCRF(torch.nn.Module):
93
  'attention': weight.squeeze(1)
94
  }
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  class Predictor:
97
  def __init__(self, model=None, weights_file=None):
98
  self.model = model
@@ -103,3 +279,68 @@ class Predictor:
103
  with torch.no_grad():
104
  msf_frq = self.model(nsf_frq[None, None])["denoised"]
105
  return msf_frq[0, 0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
 
3
+ # class ConvEncoder(torch.nn.Module):
4
+ # def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
5
+ # super().__init__()
6
+ # if output_dim is None:
7
+ # output_dim = hidden_dim
8
+ # self.conv4 = torch.nn.Conv1d(1, hidden_dim, kernel_size)
9
+ # self.conv3 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
10
+ # self.conv2 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
11
+ # self.conv1 = torch.nn.Conv1d(hidden_dim, output_dim, kernel_size)
12
+ # self.dropout = torch.nn.Dropout(dropout)
13
+
14
+ # def forward(self, feature): #(samples, 1, 2048)
15
+ # feature = self.dropout(self.conv4(feature)) #(samples, 64, 2042)
16
+ # feature = feature.relu()
17
+ # feature = self.dropout(self.conv3(feature)) #(samples, 64, 2036)
18
+ # feature = feature.relu()
19
+ # feature = self.dropout(self.conv2(feature)) #(samples, 64, 2030)
20
+ # feature = feature.relu()
21
+ # feature = self.dropout(self.conv1(feature)) #(samples, 64, 2024)
22
+ # return feature
23
+
24
+ # class ConvDecoder(torch.nn.Module):
25
+ # def __init__(self, input_dim=None, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
26
+ # super().__init__()
27
+ # if output_dim is None:
28
+ # output_dim = hidden_dim
29
+ # self.convTranspose1 = torch.nn.ConvTranspose1d(input_dim, hidden_dim, kernel_size)
30
+ # self.convTranspose2 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
31
+ # self.convTranspose3 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
32
+ # self.convTranspose4 = torch.nn.ConvTranspose1d(hidden_dim, 1, kernel_size)
33
+
34
+ # def forward(self, feature): #(samples, 1, 2048)
35
+ # feature = self.convTranspose1(feature) #(samples, 64, 2030)
36
+ # feature = feature.relu()
37
+ # feature = self.convTranspose2(feature) #(samples, 64, 2036)
38
+ # feature = feature.relu()
39
+ # feature = self.convTranspose3(feature) #(samples, 64, 2042)
40
+ # feature = feature.relu()
41
+ # feature = self.convTranspose4(feature)
42
+ # return feature
43
def get_activation(activation_name: str) -> torch.nn.Module:
    """Instantiate the activation module named by *activation_name*.

    Supported names: "relu", "gelu", "leaky_relu", "tanh", "sigmoid".

    Raises:
        ValueError: if the name is not one of the supported activations.
    """
    factories = {
        "relu": torch.nn.ReLU,
        "gelu": torch.nn.GELU,
        "leaky_relu": torch.nn.LeakyReLU,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
    try:
        return factories[activation_name]()
    except KeyError:
        raise ValueError(f"Unsupported activation function: {activation_name}") from None
56
+
57
+
58
class ConvEncoder(torch.nn.Module):
    """Four-layer 1D convolutional encoder.

    Each convolution is followed by an activation and dropout. With no
    padding, every layer shortens the sequence by kernel_size - 1 samples,
    so the output length is seq_len - 4 * (kernel_size - 1).
    """

    def __init__(self, hidden_dim=64, output_dim=None, input_dim=1, dropout=0, kernel_size=7, activation="relu"):
        super().__init__()
        if output_dim is None:
            output_dim = hidden_dim
        # Channel sizes of the four successive convolutions.
        channels = [input_dim, hidden_dim, hidden_dim, hidden_dim, output_dim]
        stages = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stages += [
                torch.nn.Conv1d(c_in, c_out, kernel_size),
                get_activation(activation),
                torch.nn.Dropout(dropout),
            ]
        self.net = torch.nn.Sequential(*stages)

    def forward(self, feature):
        """Encode (samples, input_dim, seq_len) -> (samples, output_dim, shorter seq_len)."""
        return self.net(feature)
81
 
82
class ConvDecoder(torch.nn.Module):
    """Four-layer 1D transposed-convolution decoder (mirror of ConvEncoder).

    With no padding, each transposed convolution lengthens the sequence by
    kernel_size - 1 samples. The final layer maps to *output_dim* channels;
    its bias and trailing activation/dropout are optional so the decoder can
    emit unconstrained values when used as the last stage of a model.
    """

    def __init__(self, input_dim=None, hidden_dim=64, output_dim=1, dropout=0, kernel_size=7, activation="relu", last_bias=True, last_activation=True):
        super().__init__()
        if input_dim is None:
            input_dim = hidden_dim
        stack = []
        # Three hidden transposed convolutions, each with activation + dropout.
        for c_in in (input_dim, hidden_dim, hidden_dim):
            stack += [
                torch.nn.ConvTranspose1d(c_in, hidden_dim, kernel_size),
                get_activation(activation),
                torch.nn.Dropout(dropout),
            ]
        # Output projection; bias / trailing nonlinearity are configurable.
        stack.append(torch.nn.ConvTranspose1d(hidden_dim, output_dim, kernel_size, bias=last_bias))
        if last_activation:
            stack.append(get_activation(activation))
            stack.append(torch.nn.Dropout(dropout))
        self.net = torch.nn.Sequential(*stack)

    def forward(self, feature):
        """Decode (samples, input_dim, seq_len) -> (samples, output_dim, longer seq_len)."""
        return self.net(feature)
106
+
107
class ConvMLP(torch.nn.Module):
    """Position-wise MLP implemented with kernel_size=1 convolutions.

    Applies the same small MLP independently at every sequence position of a
    (samples, input_dim, seq_len) tensor. Activations are inserted between
    layers only; the final layer has no activation.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        hidden_dims: sizes of the hidden layers; defaults to [128, 64].
            (Default changed from a shared mutable list to None to avoid the
            mutable-default-argument pitfall; behavior is unchanged.)
        activation: name understood by get_activation().
    """
    def __init__(self, input_dim, output_dim, hidden_dims=None, activation="relu"):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [128, 64]
        mlp_dims = [input_dim] + list(hidden_dims) + [output_dim]
        mlp_layers = [torch.nn.Conv1d(mlp_dims[0], mlp_dims[1], kernel_size=1)]
        for dims_in, dims_out in zip(mlp_dims[1:-1], mlp_dims[2:]):
            mlp_layers.extend([
                get_activation(activation),
                torch.nn.Conv1d(dims_in, dims_out, kernel_size=1)
            ])
        self.mlp = torch.nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Map (samples, input_dim, seq_len) -> (samples, output_dim, seq_len)."""
        return self.mlp(x)
121
+
122
class MLP(torch.nn.Module):
    """Fully-connected MLP operating on the last dimension of its input.

    Activations are inserted between layers only; the final Linear has no
    activation.

    Args:
        input_dim: size of the last input dimension.
        output_dim: size of the last output dimension.
        hidden_dims: sizes of the hidden layers; defaults to [128, 64].
            (Default changed from a shared mutable list to None to avoid the
            mutable-default-argument pitfall; behavior is unchanged.)
        activation: name understood by get_activation().
    """
    def __init__(self, input_dim, output_dim, hidden_dims=None, activation="relu"):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [128, 64]
        mlp_dims = [input_dim] + list(hidden_dims) + [output_dim]
        mlp_layers = [torch.nn.Linear(mlp_dims[0], mlp_dims[1])]
        for dims_in, dims_out in zip(mlp_dims[1:-1], mlp_dims[2:]):
            mlp_layers.extend([
                get_activation(activation),
                torch.nn.Linear(dims_in, dims_out)
            ])
        self.mlp = torch.nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Map (..., input_dim) -> (..., output_dim)."""
        return self.mlp(x)
 
 
 
 
 
 
 
136
 
137
  class ResponseHead(torch.nn.Module):
138
  def __init__(self, input_dim, output_length, hidden_dims=[128]):
 
186
  'attention': weight.squeeze(1)
187
  }
188
 
189
class KVAttention(torch.nn.Module):
    """Multi-head attention pooling with a learnable query.

    Pools a (samples, input_dim, seq_len) feature map down to a single
    (samples, input_dim) vector per sample, using a learned per-head query
    against keys produced by *k_processor*.

    Fix: removed the dead assignment ``values = feature`` — the variable was
    never used (the raw ``feature`` tensor is pooled directly below).

    Args:
        kv_dim: per-head key dimension.
        num_heads: number of attention heads.
        k_processor: maps the input to (samples, num_heads * kv_dim, seq_len);
            with the Identity default the input must already have
            num_heads * kv_dim channels.
        v_processor: applied to the pooled (samples, input_dim) vector.

    NOTE(review): the pooling ``view`` below assumes input_dim is divisible
    by num_heads — confirm at call sites.
    """
    def __init__(self,
                 kv_dim=64,
                 num_heads=4,
                 k_processor=None,
                 v_processor=None,
                 ):
        super().__init__()
        # Default processors are pass-throughs.
        if k_processor is None:
            k_processor = torch.nn.Identity()
        if v_processor is None:
            v_processor = torch.nn.Identity()
        self.k_processor = k_processor
        self.v_processor = v_processor

        self.kv_dim = kv_dim
        self.num_heads = num_heads
        # One learnable query vector per head, shared across the batch.
        self.query = torch.nn.Parameter(torch.empty(1, num_heads, kv_dim))
        torch.nn.init.xavier_normal_(self.query)

    def forward(self, feature):  # (samples, input_dim, seq_len)
        """Return (pooled (samples, input_dim), attention weights (samples, num_heads, seq_len))."""
        batch_size = feature.shape[0]
        seq_len = feature.shape[-1]
        keys = self.k_processor(feature)  # (samples, num_heads * kv_dim, seq_len)

        # Split key channels into heads.
        keys = keys.view(batch_size, self.num_heads, self.kv_dim, seq_len)  # (samples, num_heads, kv_dim, seq_len)

        # Score every position against the learned query, per head.
        queries = self.query.expand(batch_size, -1, -1)  # (samples, num_heads, kv_dim)
        energy = torch.einsum('bhd,bhdl->bhl', queries, keys)  # (samples, num_heads, seq_len)
        weight = torch.nn.functional.softmax(energy, dim=2)  # (samples, num_heads, seq_len)

        # Pool the raw input features per head with the attention weights.
        global_features = torch.einsum('bhl,bhdl->bhd', weight, feature.view(batch_size, self.num_heads, -1, seq_len))
        global_features = global_features.reshape(batch_size, -1)  # (samples, input_dim)

        # Post-process the pooled vector if a v_processor was supplied.
        global_features = self.v_processor(global_features)

        return global_features, weight
233
+
234
+
235
class ShimnetModular(torch.nn.Module):
    """Denoising network assembled from injected sub-modules.

    Pipeline: encode the input, derive per-position (local) features and an
    attention-pooled global vector, predict a response from the global
    vector, then decode the concatenation of local features and the
    broadcast global vector back to a denoised spectrum.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 response_head,
                 attention_module,
                 local_feature_processor,
                 global_feature_processor
                 ):
        super().__init__()
        self.encoder = encoder
        self.attention_module = attention_module
        self.decoder = decoder
        self.response_head = response_head
        self.local_feature_processor = local_feature_processor
        self.global_feature_processor = global_feature_processor

    def forward(self, feature):  # (samples, 1, seq_len_in)
        """Return dict with 'denoised', 'response' and per-position 'attention'."""
        encoded = self.encoder(feature)  # (samples, encoder_dim, seq_len); seq_len != seq_len_in
        per_position = self.local_feature_processor(encoded)  # (samples, local_dim, seq_len)

        # Attention pooling: global summary vector + per-head weights.
        pooled, attn_weight = self.attention_module(encoded)

        response = self.response_head(pooled.squeeze(-1))  # (samples, response_length)

        # Broadcast the processed global vector along the sequence axis and
        # decode it jointly with the local features.
        pooled_for_decoding = self.global_feature_processor(pooled).unsqueeze(-1)  # (samples, global_dim, 1)
        per_position, pooled_for_decoding = torch.broadcast_tensors(per_position, pooled_for_decoding)
        merged = torch.cat([per_position, pooled_for_decoding], 1)  # (samples, local_dim + global_dim, seq_len)
        denoised = self.decoder(merged)  # (samples, 1, seq_len_in)

        return {
            'denoised': denoised,
            'response': response,
            'attention': attn_weight.sum(1)  # (samples, seq_len)
        }
271
+
272
  class Predictor:
273
  def __init__(self, model=None, weights_file=None):
274
  self.model = model
 
279
  with torch.no_grad():
280
  msf_frq = self.model(nsf_frq[None, None])["denoised"]
281
  return msf_frq[0, 0]
282
+
283
if __name__ == "__main__":
    # Smoke test: wire up the modular model and print tensor shapes end to end.
    encoder_hidden_dims = 64
    encoder_dropout = 0
    encoder_features_dim = 128

    local_features_dim = 64

    attention_kv_dim = 32
    attention_num_heads = 8
    global_features_hidden_dim = 256

    global_features_dim = 64
    response_length = 81


    encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=encoder_features_dim, dropout=encoder_dropout)
    local_feature_processor = ConvMLP(encoder_features_dim, local_features_dim, hidden_dims=[256, 128])
    attention = KVAttention(
        kv_dim=attention_kv_dim, num_heads=attention_num_heads,
        k_processor=ConvMLP(encoder_features_dim, attention_kv_dim * attention_num_heads, hidden_dims=[512, 256]),
        v_processor=MLP(encoder_features_dim, global_features_hidden_dim, hidden_dims=[512, 256]),
    )
    global_feature_processor = MLP(global_features_hidden_dim, global_features_dim, hidden_dims=[512, 256])
    response_head = MLP(global_features_hidden_dim, response_length, hidden_dims=[512, 256])

    decoder = ConvDecoder(input_dim=local_features_dim + global_features_dim, hidden_dim=64)


    ### step by step
    inputs = torch.randn(2, 1, 2048)

    feature = encoder(inputs)  # (samples, encoder_features_dim, seq_len); seq_len != seq_len_in
    print(f"Encoder output shape: {feature.shape}")

    local_features = local_feature_processor(feature)  # (samples, local_features_dim, seq_len)
    print(f"Local features shape: {local_features.shape}")

    global_features, weight = attention(feature)  # (samples, global_features_hidden_dim), (samples, num_heads, seq_len)
    print(f"Global features shape: {global_features.shape}")
    print(f"Attention weights shape: {weight.shape}")

    response = response_head(global_features)  # (samples, response_length)
    print(f"Response shape: {response.shape}")

    global_features_for_decoding = global_feature_processor(global_features).unsqueeze(-1)  # (samples, global_features_dim, 1)

    local_features, global_features_for_decoding = torch.broadcast_tensors(local_features, global_features_for_decoding)
    feature = torch.cat([local_features, global_features_for_decoding], 1)  # (samples, local + global dims, seq_len)
    denoised_spectrum = decoder(feature)
    # Fix: denoised_spectrum was computed but never reported, unlike every
    # other step above.
    print(f"Denoised spectrum shape: {denoised_spectrum.shape}")

    print("=" * 80)
    ### assemble model

    model = ShimnetModular(
        encoder=encoder,
        decoder=decoder,
        response_head=response_head,
        attention_module=attention,
        local_feature_processor=local_feature_processor,
        global_feature_processor=global_feature_processor
    )

    for k, v in model(inputs).items():
        print(f"{k}: {v.shape}")