jimmy60504 commited on
Commit
7781d84
·
1 Parent(s): 1e575c1

docs: add full model implementation with CNN, MLP, and Transformer components

Browse files
Files changed (1) hide show
  1. model.py +375 -0
model.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from loguru import logger
6
+
7
+ # GPU/CPU 設定
8
+ if torch.cuda.is_available():
9
+ device = torch.device("cuda")
10
+ logger.info("使用 GPU")
11
+ elif torch.mps.is_available():
12
+ device = torch.device("mps")
13
+ logger.info("使用 Apple MPS")
14
+ else:
15
+ device = torch.device("cpu")
16
+ logger.info("使用 CPU")
17
+
18
+
19
+ class LambdaLayer(nn.Module):
20
+ def __init__(self, lambd, eps=1e-4):
21
+ super(LambdaLayer, self).__init__()
22
+ self.lambd = lambd
23
+ self.eps = eps
24
+
25
+ def forward(self, x):
26
+ return self.lambd(x) + self.eps
27
+
28
+
29
+ class MLP(nn.Module):
30
+ def __init__(
31
+ self,
32
+ input_shape,
33
+ dims=(500, 300, 200, 150),
34
+ activation=nn.ReLU(),
35
+ last_activation=None,
36
+ ):
37
+ super(MLP, self).__init__()
38
+ if last_activation is None:
39
+ last_activation = activation
40
+ self.dims = dims
41
+ self.first_fc = nn.Linear(input_shape[0], dims[0])
42
+ self.first_activation = activation
43
+
44
+ more_hidden = []
45
+ if len(self.dims) > 2:
46
+ for i in range(1, len(self.dims) - 1):
47
+ more_hidden.append(nn.Linear(self.dims[i - 1], self.dims[i]))
48
+ more_hidden.append(nn.ReLU())
49
+
50
+ self.more_hidden = nn.ModuleList(more_hidden)
51
+ self.last_fc = nn.Linear(dims[-2], dims[-1])
52
+ self.last_activation = last_activation
53
+
54
+ def forward(self, x):
55
+ output = self.first_fc(x)
56
+ output = self.first_activation(output)
57
+ if self.more_hidden:
58
+ for layer in self.more_hidden:
59
+ output = layer(output)
60
+ output = self.last_fc(output)
61
+ output = self.last_activation(output)
62
+ return output
63
+
64
+
65
+ class CNN(nn.Module):
66
+ def __init__(
67
+ self,
68
+ input_shape=(-1, 6000, 3),
69
+ activation=nn.ReLU(),
70
+ downsample=1,
71
+ mlp_input=11665,
72
+ mlp_dims=(500, 300, 200, 150),
73
+ eps=1e-8,
74
+ ):
75
+ super(CNN, self).__init__()
76
+ self.input_shape = input_shape
77
+ self.activation = activation
78
+ self.downsample = downsample
79
+ self.mlp_input = mlp_input
80
+ self.mlp_dims = mlp_dims
81
+ self.eps = eps
82
+
83
+ self.lambda_layer_1 = LambdaLayer(
84
+ lambda t: t
85
+ / (
86
+ torch.max(
87
+ torch.max(torch.abs(t), dim=1, keepdim=True).values,
88
+ dim=2,
89
+ keepdim=True,
90
+ ).values
91
+ + self.eps
92
+ )
93
+ )
94
+ self.unsqueeze_layer1 = LambdaLayer(lambda t: torch.unsqueeze(t, dim=1))
95
+ self.lambda_layer_2 = LambdaLayer(
96
+ lambda t: torch.log(
97
+ torch.max(torch.max(torch.abs(t), dim=1).values, dim=1).values
98
+ + self.eps
99
+ )
100
+ / 100
101
+ )
102
+ self.unsqueeze_layer2 = LambdaLayer(lambda t: torch.unsqueeze(t, dim=1))
103
+ self.conv2d1 = nn.Sequential(
104
+ nn.Conv2d(1, 8, kernel_size=(1, downsample), stride=(1, downsample)),
105
+ nn.ReLU(),
106
+ )
107
+ self.conv2d2 = nn.Sequential(
108
+ nn.Conv2d(8, 32, kernel_size=(16, 3), stride=(1, 3)), nn.ReLU()
109
+ )
110
+ self.conv1d1 = nn.Sequential(nn.Conv1d(32, 64, kernel_size=16), nn.ReLU())
111
+ self.maxpooling = nn.MaxPool1d(2)
112
+ self.conv1d2 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=16), nn.ReLU())
113
+ self.conv1d3 = nn.Sequential(nn.Conv1d(128, 32, kernel_size=8), nn.ReLU())
114
+ self.conv1d4 = nn.Sequential(nn.Conv1d(32, 32, kernel_size=8), nn.ReLU())
115
+ self.conv1d5 = nn.Sequential(nn.Conv1d(32, 16, kernel_size=4), nn.ReLU())
116
+ self.mlp = MLP((self.mlp_input,), dims=self.mlp_dims)
117
+
118
+ def forward(self, x):
119
+ output = self.lambda_layer_1(x)
120
+ output = self.unsqueeze_layer1(output)
121
+ scale = self.lambda_layer_2(x)
122
+ scale = self.unsqueeze_layer2(scale)
123
+ output = self.conv2d1(output)
124
+ output = self.conv2d2(output)
125
+ output = torch.squeeze(output, dim=-1)
126
+ output = self.conv1d1(output)
127
+ output = self.maxpooling(output)
128
+ output = self.conv1d2(output)
129
+ output = self.maxpooling(output)
130
+ output = self.conv1d3(output)
131
+ output = self.maxpooling(output)
132
+ output = self.conv1d4(output)
133
+ output = self.conv1d5(output)
134
+ output = torch.flatten(output, start_dim=1)
135
+ output = torch.cat((output, scale), dim=1)
136
+ output = self.mlp(output)
137
+ return output
138
+
139
+
140
+ class PositionEmbeddingVs30(nn.Module):
141
+ def __init__(
142
+ self, wavelengths=((5, 30), (110, 123), (0.01, 5000), (100, 1600)), emb_dim=500
143
+ ):
144
+ super(PositionEmbeddingVs30, self).__init__()
145
+ self.wavelengths = wavelengths
146
+ self.emb_dim = emb_dim
147
+
148
+ min_lat, max_lat = wavelengths[0]
149
+ min_lon, max_lon = wavelengths[1]
150
+ min_depth, max_depth = wavelengths[2]
151
+ min_vs30, max_vs30 = wavelengths[3]
152
+
153
+ assert emb_dim % 10 == 0
154
+ lat_dim = emb_dim // 5
155
+ lon_dim = emb_dim // 5
156
+ depth_dim = emb_dim // 10
157
+ vs30_dim = emb_dim // 10
158
+
159
+ self.lat_coeff = (
160
+ 2
161
+ * np.pi
162
+ * 1.0
163
+ / min_lat
164
+ * ((min_lat / max_lat) ** (np.arange(lat_dim) / lat_dim))
165
+ )
166
+ self.lon_coeff = (
167
+ 2
168
+ * np.pi
169
+ * 1.0
170
+ / min_lon
171
+ * ((min_lon / max_lon) ** (np.arange(lon_dim) / lon_dim))
172
+ )
173
+ self.depth_coeff = (
174
+ 2
175
+ * np.pi
176
+ * 1.0
177
+ / min_depth
178
+ * ((min_depth / max_depth) ** (np.arange(depth_dim) / depth_dim))
179
+ )
180
+ self.vs30_coeff = (
181
+ 2
182
+ * np.pi
183
+ * 1.0
184
+ / min_vs30
185
+ * ((min_vs30 / max_vs30) ** (np.arange(vs30_dim) / vs30_dim))
186
+ )
187
+
188
+ lat_sin_mask = np.arange(emb_dim) % 5 == 0
189
+ lat_cos_mask = np.arange(emb_dim) % 5 == 1
190
+ lon_sin_mask = np.arange(emb_dim) % 5 == 2
191
+ lon_cos_mask = np.arange(emb_dim) % 5 == 3
192
+ depth_sin_mask = np.arange(emb_dim) % 10 == 4
193
+ depth_cos_mask = np.arange(emb_dim) % 10 == 9
194
+ vs30_sin_mask = np.arange(emb_dim) % 10 == 5
195
+ vs30_cos_mask = np.arange(emb_dim) % 10 == 8
196
+
197
+ self.mask = np.zeros(emb_dim)
198
+ self.mask[lat_sin_mask] = np.arange(lat_dim)
199
+ self.mask[lat_cos_mask] = lat_dim + np.arange(lat_dim)
200
+ self.mask[lon_sin_mask] = 2 * lat_dim + np.arange(lon_dim)
201
+ self.mask[lon_cos_mask] = 2 * lat_dim + lon_dim + np.arange(lon_dim)
202
+ self.mask[depth_sin_mask] = 2 * lat_dim + 2 * lon_dim + np.arange(depth_dim)
203
+ self.mask[depth_cos_mask] = (
204
+ 2 * lat_dim + 2 * lon_dim + depth_dim + np.arange(depth_dim)
205
+ )
206
+ self.mask[vs30_sin_mask] = (
207
+ 2 * lat_dim + 2 * lon_dim + 2 * depth_dim + np.arange(vs30_dim)
208
+ )
209
+ self.mask[vs30_cos_mask] = (
210
+ 2 * lat_dim + 2 * lon_dim + 2 * depth_dim + vs30_dim + np.arange(vs30_dim)
211
+ )
212
+ self.mask = self.mask.astype("int32")
213
+
214
+ def forward(self, x):
215
+ lat_base = x[:, :, 0:1].to(device) * torch.Tensor(self.lat_coeff).to(device)
216
+ lon_base = x[:, :, 1:2].to(device) * torch.Tensor(self.lon_coeff).to(device)
217
+ depth_base = x[:, :, 2:3].to(device) * torch.Tensor(self.depth_coeff).to(device)
218
+ vs30_base = x[:, :, 3:4] * torch.Tensor(self.vs30_coeff).to(device)
219
+
220
+ output = torch.cat(
221
+ [
222
+ torch.sin(lat_base),
223
+ torch.cos(lat_base),
224
+ torch.sin(lon_base),
225
+ torch.cos(lon_base),
226
+ torch.sin(depth_base),
227
+ torch.cos(depth_base),
228
+ torch.sin(vs30_base),
229
+ torch.cos(vs30_base),
230
+ ],
231
+ dim=-1,
232
+ )
233
+
234
+ maskk = torch.from_numpy(np.array(self.mask)).long()
235
+ index = (
236
+ (maskk.unsqueeze(0).unsqueeze(0))
237
+ .expand(x.shape[0], 1, self.emb_dim)
238
+ .to(device)
239
+ )
240
+ output = torch.gather(output, -1, index).to(device)
241
+ return output
242
+
243
+
244
+ class TransformerEncoder(nn.Module):
245
+ def __init__(
246
+ self,
247
+ d_model=150,
248
+ nhead=10,
249
+ batch_first=True,
250
+ activation="gelu",
251
+ dropout=0.0,
252
+ dim_feedforward=1000,
253
+ ):
254
+ super(TransformerEncoder, self).__init__()
255
+ self.encoder_layer = nn.TransformerEncoderLayer(
256
+ d_model=d_model,
257
+ nhead=nhead,
258
+ batch_first=batch_first,
259
+ activation=activation,
260
+ dropout=dropout,
261
+ dim_feedforward=dim_feedforward,
262
+ ).to(device)
263
+ self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, 6).to(
264
+ device
265
+ )
266
+
267
+ def forward(self, x, src_key_padding_mask=None):
268
+ return self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)
269
+
270
+
271
+ class MDN(nn.Module):
272
+ def __init__(self, input_shape=(150,), n_hidden=20, n_gaussians=5):
273
+ super(MDN, self).__init__()
274
+ self.z_h = nn.Sequential(nn.Linear(input_shape[0], n_hidden), nn.Tanh())
275
+ self.z_weight = nn.Linear(n_hidden, n_gaussians)
276
+ self.z_sigma = nn.Linear(n_hidden, n_gaussians)
277
+ self.z_mu = nn.Linear(n_hidden, n_gaussians)
278
+
279
+ def forward(self, x):
280
+ z_h = self.z_h(x)
281
+ weight = nn.functional.softmax(self.z_weight(z_h), -1)
282
+ sigma = torch.exp(self.z_sigma(z_h))
283
+ mu = self.z_mu(z_h)
284
+ return weight, sigma, mu
285
+
286
+
287
+ class FullModel(nn.Module):
288
+ def __init__(
289
+ self,
290
+ model_cnn,
291
+ model_position,
292
+ model_transformer,
293
+ model_mlp,
294
+ model_mdn,
295
+ max_station=25,
296
+ pga_targets=15,
297
+ emb_dim=150,
298
+ data_length=6000,
299
+ ):
300
+ super(FullModel, self).__init__()
301
+ self.data_length = data_length
302
+ self.model_CNN = model_cnn
303
+ self.model_Position = model_position
304
+ self.model_Transformer = model_transformer
305
+ self.model_mlp = model_mlp
306
+ self.model_MDN = model_mdn
307
+ self.max_station = max_station
308
+ self.pga_targets = pga_targets
309
+ self.emb_dim = emb_dim
310
+
311
+ def forward(self, data):
312
+ cnn_output = self.model_CNN(
313
+ torch.DoubleTensor(data["waveform"].reshape(-1, self.data_length, 3))
314
+ .float()
315
+ .to(device)
316
+ )
317
+ cnn_output_reshape = torch.reshape(
318
+ cnn_output, (-1, self.max_station, self.emb_dim)
319
+ )
320
+
321
+ emb_output = self.model_Position(
322
+ torch.DoubleTensor(data["station"].reshape(-1, 1, data["station"].shape[2]))
323
+ .float()
324
+ .to(device)
325
+ )
326
+ emb_output = emb_output.reshape(-1, self.max_station, self.emb_dim)
327
+
328
+ station_pad_mask = data["station"] == 0
329
+ station_pad_mask = torch.all(station_pad_mask, 2)
330
+
331
+ pga_pos_emb_output = self.model_Position(
332
+ torch.DoubleTensor(data["target"].reshape(-1, 1, data["target"].shape[2]))
333
+ .float()
334
+ .to(device)
335
+ )
336
+ pga_pos_emb_output = pga_pos_emb_output.reshape(
337
+ -1, self.pga_targets, self.emb_dim
338
+ )
339
+
340
+ target_pad_mask = torch.ones_like(data["target"], dtype=torch.bool)
341
+ target_pad_mask = torch.all(target_pad_mask, 2)
342
+ pad_mask = torch.cat((station_pad_mask, target_pad_mask), dim=1).to(device)
343
+
344
+ add_pe_cnn_output = torch.add(cnn_output_reshape, emb_output)
345
+ transformer_input = torch.cat((add_pe_cnn_output, pga_pos_emb_output), dim=1)
346
+ transformer_output = self.model_Transformer(transformer_input, pad_mask)
347
+
348
+ mlp_input = transformer_output[:, -self.pga_targets :, :].to(device)
349
+ mlp_output = self.model_mlp(mlp_input)
350
+ weight, sigma, mu = self.model_MDN(mlp_output)
351
+
352
+ return weight, sigma, mu
353
+
354
+
355
+ def get_full_model(model_path):
356
+ emb_dim = 150
357
+ mlp_dims = (150, 100, 50, 30, 10)
358
+ cnn_model = CNN(mlp_input=5665).to(device)
359
+ pos_emb_model = PositionEmbeddingVs30(emb_dim=emb_dim).to(device)
360
+ transformer_model = TransformerEncoder()
361
+ mlp_model = MLP(input_shape=(emb_dim,), dims=mlp_dims).to(device)
362
+ mdn_model = MDN(input_shape=(mlp_dims[-1],)).to(device)
363
+ full_model = FullModel(
364
+ cnn_model,
365
+ pos_emb_model,
366
+ transformer_model,
367
+ mlp_model,
368
+ mdn_model,
369
+ pga_targets=25,
370
+ data_length=3000,
371
+ ).to(device)
372
+ full_model.load_state_dict(
373
+ torch.load(model_path, weights_only=True, map_location=device)
374
+ )
375
+ return full_model