jimmy60504 commited on
Commit
71a525a
·
1 Parent(s): 4cc2090

add requirements and initial setup for TTSAM intensity prediction system

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. .gitignore +0 -0
  3. app.py +556 -4
  4. requirements.txt +13 -0
  5. station/eew_target.csv +48 -0
  6. station/site_info.txt +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mseed filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
File without changes
app.py CHANGED
@@ -1,7 +1,559 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from obspy import read
5
+ from datasets import load_dataset
6
+ import xarray as xr
7
+ import torch
8
+ import torch.nn as nn
9
+ from scipy.signal import detrend, iirfilter, sosfilt, zpk2sos
10
+ from scipy.spatial import cKDTree
11
+ import pandas as pd
12
+ from loguru import logger
13
 
14
+ # GPU/CPU 設定
15
+ if torch.cuda.is_available():
16
+ device = torch.device("cuda")
17
+ logger.info("使用 GPU")
18
+ else:
19
+ device = torch.device("cpu")
20
+ logger.info("使用 CPU")
21
 
22
+ # 載入 Vs30 資料集(從 Hugging Face 下載)
23
+ from huggingface_hub import hf_hub_download
24
+
25
+ try:
26
+ logger.info("從 Hugging Face 載入 Vs30 資料...")
27
+ vs30_file = hf_hub_download(
28
+ repo_id="SeisBlue/TaiwanVs30",
29
+ filename="Vs30ofTaiwan.nc"
30
+ )
31
+ ds = xr.open_dataset(vs30_file)
32
+ lat_flat = ds['lat'].values.flatten()
33
+ lon_flat = ds['lon'].values.flatten()
34
+ vs30_flat = ds['vs30'].values.flatten()
35
+
36
+ vs30_table = pd.DataFrame({'lat': lat_flat, 'lon': lon_flat, 'Vs30': vs30_flat})
37
+ vs30_table = vs30_table.replace([np.inf, -np.inf], np.nan).dropna()
38
+ tree = cKDTree(vs30_table[["lat", "lon"]])
39
+ logger.info("Vs30 資料載入完成")
40
+ except Exception as e:
41
+ logger.error(f"Vs30 資料載入失敗: {e}")
42
+
43
+ # 載入目標測站
44
+ target_file = "station/eew_target.csv"
45
+ try:
46
+ logger.info(f"載入 {target_file}...")
47
+ target_df = pd.read_csv(target_file)
48
+ target_dict = target_df.to_dict(orient="records")
49
+ logger.info(f"{target_file} 載入完成")
50
+ except FileNotFoundError:
51
+ logger.error(f"{target_file} 找不到")
52
+
53
+ # 載入測站資訊
54
+ site_info_file = "station/site_info.txt"
55
+ try:
56
+ logger.info(f"載入 {site_info_file}...")
57
+ site_info = pd.read_csv(site_info_file)
58
+ logger.info(f"{site_info_file} 載入完成")
59
+ except FileNotFoundError:
60
+ logger.warning(f"{site_info_file} 找不到")
61
+
62
+ # 預設地震事件
63
+ EARTHQUAKE_EVENTS = {
64
+ "0403花蓮地震 (2024)": "waveform/20240403.mseed",
65
+ }
66
+
67
+
68
+ # ============ 模型定義(從 ttsam_realtime.py 複製) ============
69
+
70
+ class LambdaLayer(nn.Module):
71
+ def __init__(self, lambd, eps=1e-4):
72
+ super(LambdaLayer, self).__init__()
73
+ self.lambd = lambd
74
+ self.eps = eps
75
+
76
+ def forward(self, x):
77
+ return self.lambd(x) + self.eps
78
+
79
+
80
+ class MLP(nn.Module):
81
+ def __init__(self, input_shape, dims=(500, 300, 200, 150), activation=nn.ReLU(),
82
+ last_activation=None):
83
+ super(MLP, self).__init__()
84
+ if last_activation is None:
85
+ last_activation = activation
86
+ self.dims = dims
87
+ self.first_fc = nn.Linear(input_shape[0], dims[0])
88
+ self.first_activation = activation
89
+
90
+ more_hidden = []
91
+ if len(self.dims) > 2:
92
+ for i in range(1, len(self.dims) - 1):
93
+ more_hidden.append(nn.Linear(self.dims[i - 1], self.dims[i]))
94
+ more_hidden.append(nn.ReLU())
95
+
96
+ self.more_hidden = nn.ModuleList(more_hidden)
97
+ self.last_fc = nn.Linear(dims[-2], dims[-1])
98
+ self.last_activation = last_activation
99
+
100
+ def forward(self, x):
101
+ output = self.first_fc(x)
102
+ output = self.first_activation(output)
103
+ if self.more_hidden:
104
+ for layer in self.more_hidden:
105
+ output = layer(output)
106
+ output = self.last_fc(output)
107
+ output = self.last_activation(output)
108
+ return output
109
+
110
+
111
+ class CNN(nn.Module):
112
+ def __init__(self, input_shape=(-1, 6000, 3), activation=nn.ReLU(), downsample=1,
113
+ mlp_input=11665, mlp_dims=(500, 300, 200, 150), eps=1e-8):
114
+ super(CNN, self).__init__()
115
+ self.input_shape = input_shape
116
+ self.activation = activation
117
+ self.downsample = downsample
118
+ self.mlp_input = mlp_input
119
+ self.mlp_dims = mlp_dims
120
+ self.eps = eps
121
+
122
+ self.lambda_layer_1 = LambdaLayer(
123
+ lambda t: t / (
124
+ torch.max(torch.max(torch.abs(t), dim=1, keepdim=True).values,
125
+ dim=2, keepdim=True).values + self.eps)
126
+ )
127
+ self.unsqueeze_layer1 = LambdaLayer(lambda t: torch.unsqueeze(t, dim=1))
128
+ self.lambda_layer_2 = LambdaLayer(
129
+ lambda t: torch.log(torch.max(torch.max(torch.abs(t), dim=1).values,
130
+ dim=1).values + self.eps) / 100
131
+ )
132
+ self.unsqueeze_layer2 = LambdaLayer(lambda t: torch.unsqueeze(t, dim=1))
133
+ self.conv2d1 = nn.Sequential(
134
+ nn.Conv2d(1, 8, kernel_size=(1, downsample), stride=(1, downsample)),
135
+ nn.ReLU())
136
+ self.conv2d2 = nn.Sequential(
137
+ nn.Conv2d(8, 32, kernel_size=(16, 3), stride=(1, 3)), nn.ReLU())
138
+ self.conv1d1 = nn.Sequential(nn.Conv1d(32, 64, kernel_size=16), nn.ReLU())
139
+ self.maxpooling = nn.MaxPool1d(2)
140
+ self.conv1d2 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=16), nn.ReLU())
141
+ self.conv1d3 = nn.Sequential(nn.Conv1d(128, 32, kernel_size=8), nn.ReLU())
142
+ self.conv1d4 = nn.Sequential(nn.Conv1d(32, 32, kernel_size=8), nn.ReLU())
143
+ self.conv1d5 = nn.Sequential(nn.Conv1d(32, 16, kernel_size=4), nn.ReLU())
144
+ self.mlp = MLP((self.mlp_input,), dims=self.mlp_dims)
145
+
146
+ def forward(self, x):
147
+ output = self.lambda_layer_1(x)
148
+ output = self.unsqueeze_layer1(output)
149
+ scale = self.lambda_layer_2(x)
150
+ scale = self.unsqueeze_layer2(scale)
151
+ output = self.conv2d1(output)
152
+ output = self.conv2d2(output)
153
+ output = torch.squeeze(output, dim=-1)
154
+ output = self.conv1d1(output)
155
+ output = self.maxpooling(output)
156
+ output = self.conv1d2(output)
157
+ output = self.maxpooling(output)
158
+ output = self.conv1d3(output)
159
+ output = self.maxpooling(output)
160
+ output = self.conv1d4(output)
161
+ output = self.conv1d5(output)
162
+ output = torch.flatten(output, start_dim=1)
163
+ output = torch.cat((output, scale), dim=1)
164
+ output = self.mlp(output)
165
+ return output
166
+
167
+
168
+ class PositionEmbeddingVs30(nn.Module):
169
+ def __init__(self, wavelengths=((5, 30), (110, 123), (0.01, 5000), (100, 1600)),
170
+ emb_dim=500):
171
+ super(PositionEmbeddingVs30, self).__init__()
172
+ self.wavelengths = wavelengths
173
+ self.emb_dim = emb_dim
174
+
175
+ min_lat, max_lat = wavelengths[0]
176
+ min_lon, max_lon = wavelengths[1]
177
+ min_depth, max_depth = wavelengths[2]
178
+ min_vs30, max_vs30 = wavelengths[3]
179
+
180
+ assert emb_dim % 10 == 0
181
+ lat_dim = emb_dim // 5
182
+ lon_dim = emb_dim // 5
183
+ depth_dim = emb_dim // 10
184
+ vs30_dim = emb_dim // 10
185
+
186
+ self.lat_coeff = 2 * np.pi * 1.0 / min_lat * (
187
+ (min_lat / max_lat) ** (np.arange(lat_dim) / lat_dim))
188
+ self.lon_coeff = 2 * np.pi * 1.0 / min_lon * (
189
+ (min_lon / max_lon) ** (np.arange(lon_dim) / lon_dim))
190
+ self.depth_coeff = 2 * np.pi * 1.0 / min_depth * (
191
+ (min_depth / max_depth) ** (np.arange(depth_dim) / depth_dim))
192
+ self.vs30_coeff = 2 * np.pi * 1.0 / min_vs30 * (
193
+ (min_vs30 / max_vs30) ** (np.arange(vs30_dim) / vs30_dim))
194
+
195
+ lat_sin_mask = np.arange(emb_dim) % 5 == 0
196
+ lat_cos_mask = np.arange(emb_dim) % 5 == 1
197
+ lon_sin_mask = np.arange(emb_dim) % 5 == 2
198
+ lon_cos_mask = np.arange(emb_dim) % 5 == 3
199
+ depth_sin_mask = np.arange(emb_dim) % 10 == 4
200
+ depth_cos_mask = np.arange(emb_dim) % 10 == 9
201
+ vs30_sin_mask = np.arange(emb_dim) % 10 == 5
202
+ vs30_cos_mask = np.arange(emb_dim) % 10 == 8
203
+
204
+ self.mask = np.zeros(emb_dim)
205
+ self.mask[lat_sin_mask] = np.arange(lat_dim)
206
+ self.mask[lat_cos_mask] = lat_dim + np.arange(lat_dim)
207
+ self.mask[lon_sin_mask] = 2 * lat_dim + np.arange(lon_dim)
208
+ self.mask[lon_cos_mask] = 2 * lat_dim + lon_dim + np.arange(lon_dim)
209
+ self.mask[depth_sin_mask] = 2 * lat_dim + 2 * lon_dim + np.arange(depth_dim)
210
+ self.mask[depth_cos_mask] = 2 * lat_dim + 2 * lon_dim + depth_dim + np.arange(
211
+ depth_dim)
212
+ self.mask[
213
+ vs30_sin_mask] = 2 * lat_dim + 2 * lon_dim + 2 * depth_dim + np.arange(
214
+ vs30_dim)
215
+ self.mask[
216
+ vs30_cos_mask] = 2 * lat_dim + 2 * lon_dim + 2 * depth_dim + vs30_dim + np.arange(
217
+ vs30_dim)
218
+ self.mask = self.mask.astype("int32")
219
+
220
+ def forward(self, x):
221
+ lat_base = x[:, :, 0:1].to(device) * torch.Tensor(self.lat_coeff).to(device)
222
+ lon_base = x[:, :, 1:2].to(device) * torch.Tensor(self.lon_coeff).to(device)
223
+ depth_base = x[:, :, 2:3].to(device) * torch.Tensor(self.depth_coeff).to(device)
224
+ vs30_base = x[:, :, 3:4] * torch.Tensor(self.vs30_coeff).to(device)
225
+
226
+ output = torch.cat([
227
+ torch.sin(lat_base), torch.cos(lat_base),
228
+ torch.sin(lon_base), torch.cos(lon_base),
229
+ torch.sin(depth_base), torch.cos(depth_base),
230
+ torch.sin(vs30_base), torch.cos(vs30_base),
231
+ ], dim=-1)
232
+
233
+ maskk = torch.from_numpy(np.array(self.mask)).long()
234
+ index = (maskk.unsqueeze(0).unsqueeze(0)).expand(x.shape[0], 1,
235
+ self.emb_dim).to(device)
236
+ output = torch.gather(output, -1, index).to(device)
237
+ return output
238
+
239
+
240
+ class TransformerEncoder(nn.Module):
241
+ def __init__(self, d_model=150, nhead=10, batch_first=True, activation="gelu",
242
+ dropout=0.0, dim_feedforward=1000):
243
+ super(TransformerEncoder, self).__init__()
244
+ self.encoder_layer = nn.TransformerEncoderLayer(
245
+ d_model=d_model, nhead=nhead, batch_first=batch_first,
246
+ activation=activation, dropout=dropout, dim_feedforward=dim_feedforward
247
+ ).to(device)
248
+ self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, 6).to(
249
+ device)
250
+
251
+ def forward(self, x, src_key_padding_mask=None):
252
+ return self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)
253
+
254
+
255
+ class MDN(nn.Module):
256
+ def __init__(self, input_shape=(150,), n_hidden=20, n_gaussians=5):
257
+ super(MDN, self).__init__()
258
+ self.z_h = nn.Sequential(nn.Linear(input_shape[0], n_hidden), nn.Tanh())
259
+ self.z_weight = nn.Linear(n_hidden, n_gaussians)
260
+ self.z_sigma = nn.Linear(n_hidden, n_gaussians)
261
+ self.z_mu = nn.Linear(n_hidden, n_gaussians)
262
+
263
+ def forward(self, x):
264
+ z_h = self.z_h(x)
265
+ weight = nn.functional.softmax(self.z_weight(z_h), -1)
266
+ sigma = torch.exp(self.z_sigma(z_h))
267
+ mu = self.z_mu(z_h)
268
+ return weight, sigma, mu
269
+
270
+
271
+ class FullModel(nn.Module):
272
+ def __init__(self, model_cnn, model_position, model_transformer, model_mlp,
273
+ model_mdn,
274
+ max_station=25, pga_targets=15, emb_dim=150, data_length=6000):
275
+ super(FullModel, self).__init__()
276
+ self.data_length = data_length
277
+ self.model_CNN = model_cnn
278
+ self.model_Position = model_position
279
+ self.model_Transformer = model_transformer
280
+ self.model_mlp = model_mlp
281
+ self.model_MDN = model_mdn
282
+ self.max_station = max_station
283
+ self.pga_targets = pga_targets
284
+ self.emb_dim = emb_dim
285
+
286
+ def forward(self, data):
287
+ cnn_output = self.model_CNN(
288
+ torch.DoubleTensor(
289
+ data["waveform"].reshape(-1, self.data_length, 3)).float().to(device)
290
+ )
291
+ cnn_output_reshape = torch.reshape(cnn_output,
292
+ (-1, self.max_station, self.emb_dim))
293
+
294
+ emb_output = self.model_Position(
295
+ torch.DoubleTensor(
296
+ data["station"].reshape(-1, 1, data["station"].shape[2])).float().to(
297
+ device)
298
+ )
299
+ emb_output = emb_output.reshape(-1, self.max_station, self.emb_dim)
300
+
301
+ station_pad_mask = data["station"] == 0
302
+ station_pad_mask = torch.all(station_pad_mask, 2)
303
+
304
+ pga_pos_emb_output = self.model_Position(
305
+ torch.DoubleTensor(
306
+ data["target"].reshape(-1, 1, data["target"].shape[2])).float().to(
307
+ device)
308
+ )
309
+ pga_pos_emb_output = pga_pos_emb_output.reshape(-1, self.pga_targets,
310
+ self.emb_dim)
311
+
312
+ target_pad_mask = torch.ones_like(data["target"], dtype=torch.bool)
313
+ target_pad_mask = torch.all(target_pad_mask, 2)
314
+ pad_mask = torch.cat((station_pad_mask, target_pad_mask), dim=1).to(device)
315
+
316
+ add_pe_cnn_output = torch.add(cnn_output_reshape, emb_output)
317
+ transformer_input = torch.cat((add_pe_cnn_output, pga_pos_emb_output), dim=1)
318
+ transformer_output = self.model_Transformer(transformer_input, pad_mask)
319
+
320
+ mlp_input = transformer_output[:, -self.pga_targets:, :].to(device)
321
+ mlp_output = self.model_mlp(mlp_input)
322
+ weight, sigma, mu = self.model_MDN(mlp_output)
323
+
324
+ return weight, sigma, mu
325
+
326
+
327
+ def get_full_model(model_path):
328
+ emb_dim = 150
329
+ mlp_dims = (150, 100, 50, 30, 10)
330
+ cnn_model = CNN(mlp_input=5665).to(device)
331
+ pos_emb_model = PositionEmbeddingVs30(emb_dim=emb_dim).to(device)
332
+ transformer_model = TransformerEncoder()
333
+ mlp_model = MLP(input_shape=(emb_dim,), dims=mlp_dims).to(device)
334
+ mdn_model = MDN(input_shape=(mlp_dims[-1],)).to(device)
335
+ full_model = FullModel(
336
+ cnn_model, pos_emb_model, transformer_model, mlp_model, mdn_model,
337
+ pga_targets=25, data_length=3000
338
+ ).to(device)
339
+ full_model.load_state_dict(
340
+ torch.load(model_path, weights_only=True, map_location=device))
341
+ return full_model
342
+
343
+
344
+ # 載入模型
345
+ model_path = hf_hub_download(
346
+ repo_id="SeisBlue/TTSAM",
347
+ filename="ttsam_trained_model_11.pt"
348
+ )
349
+ model = get_full_model(model_path)
350
+
351
+
352
+ # ============ 輔助函數 ============
353
+
354
+ def lowpass(data, freq=10, df=100, corners=4):
355
+ fe = 0.5 * df
356
+ f = freq / fe
357
+ if f > 1:
358
+ f = 1.0
359
+ z, p, k = iirfilter(corners, f, btype="lowpass", ftype="butter", output="zpk")
360
+ sos = zpk2sos(z, p, k)
361
+ return sosfilt(sos, data)
362
+
363
+
364
+ def signal_processing(waveform):
365
+ data = detrend(waveform, type="constant")
366
+ data = lowpass(data, freq=10)
367
+ return data
368
+
369
+
370
+ def get_vs30(lat, lon):
371
+ distance, i = tree.query([float(lat), float(lon)])
372
+ vs30 = vs30_table.iloc[i]["Vs30"]
373
+ return float(vs30)
374
+
375
+
376
+ def get_station_position(station):
377
+ latitude, longitude, elevation = site_info.loc[
378
+ (site_info["Station"] == station), ["Latitude", "Longitude", "Elevation"]
379
+ ].values[0]
380
+ return latitude, longitude, elevation
381
+
382
+
383
+ def calculate_intensity(pga, label=False):
384
+ intensity_label = ["0", "1", "2", "3", "4", "5-", "5+", "6-", "6+", "7"]
385
+ pga_level = np.log10([1e-5, 0.008, 0.025, 0.080, 0.250, 0.80, 1.4, 2.5, 4.4, 8.0])
386
+
387
+ pga_intensity = np.searchsorted(pga_level, pga) - 1
388
+ intensity = pga_intensity
389
+
390
+ if label:
391
+ return intensity_label[intensity]
392
+ else:
393
+ return intensity
394
+
395
+
396
+ # ============ Gradio 介面函數 ============
397
+
398
+ def load_waveform(event_name):
399
+ file_path = EARTHQUAKE_EVENTS[event_name]
400
+ st = read(file_path)
401
+ tr = st[0]
402
+ times = tr.times()
403
+ data = tr.data
404
+ return times, data, tr.stats.sampling_rate
405
+
406
+
407
+ def plot_waveform(times, data, start_time, end_time, sampling_rate):
408
+ fig, ax = plt.subplots(figsize=(12, 3))
409
+ ax.plot(times, data, 'gray', linewidth=0.5, alpha=0.6)
410
+
411
+ mask = (times >= start_time) & (times <= end_time)
412
+ ax.plot(times[mask], data[mask], 'blue', linewidth=1)
413
+
414
+ ax.axvline(start_time, color='red', linestyle='--', linewidth=1)
415
+ ax.axvline(end_time, color='red', linestyle='--', linewidth=1)
416
+
417
+ ax.set_xlabel('Time (s)')
418
+ ax.set_ylabel('Amplitude')
419
+ ax.set_title('Seismic Waveform')
420
+ ax.grid(True, alpha=0.3)
421
+
422
+ return fig
423
+
424
+
425
+ def plot_intensity_map(pga_list, target_names):
426
+ fig, ax = plt.subplots(figsize=(6, 8))
427
+
428
+ # 繪製台灣地圖底圖
429
+ taiwan_lon = [120, 122]
430
+ taiwan_lat = [22, 25]
431
+
432
+ # 根據 target_names 取得座標
433
+ lats, lons, intensities = [], [], []
434
+ for i, target_name in enumerate(target_names):
435
+ target = next((t for t in target_dict if t["station"] == target_name), None)
436
+ if target:
437
+ lats.append(target["latitude"])
438
+ lons.append(target["longitude"])
439
+ intensities.append(calculate_intensity(pga_list[i]))
440
+
441
+ # 繪製散點圖
442
+ scatter = ax.scatter(lons, lats, c=intensities, cmap='YlOrRd', s=100,
443
+ vmin=0, vmax=7, edgecolors='black', linewidth=0.5)
444
+
445
+ ax.set_xlabel('Longitude')
446
+ ax.set_ylabel('Latitude')
447
+ ax.set_title('Predicted Intensity Distribution')
448
+ ax.set_xlim(taiwan_lon)
449
+ ax.set_ylim(taiwan_lat)
450
+
451
+ cbar = plt.colorbar(scatter, ax=ax)
452
+ cbar.set_label('Intensity')
453
+
454
+ return fig
455
+
456
+
457
+ def predict_intensity(event_name, start_time, end_time, lon, lat):
458
+ # 1. 載入波形
459
+ times, data, sampling_rate = load_waveform(event_name)
460
+
461
+ # 2. 切片波形
462
+ start_idx = int(start_time * sampling_rate)
463
+ end_idx = int(end_time * sampling_rate)
464
+ waveform_slice = data[start_idx:end_idx]
465
+
466
+ # 3. 訊號處理
467
+ waveform_processed = signal_processing(waveform_slice)
468
+
469
+ # 4. 準備模型輸入
470
+ # 假設單測站三軸資料(這裡簡化為重複使用Z軸)
471
+ waveform_3c = np.array(
472
+ [[waveform_processed, waveform_processed, waveform_processed]])
473
+ waveform_3c = waveform_3c.transpose(0, 2, 1) # (1, 3000, 3)
474
+
475
+ # 準備測站資訊
476
+ vs30 = get_vs30(lat, lon)
477
+ station_info_input = np.array([[lat, lon, 100, vs30]]) # elevation 假設 100m
478
+
479
+ # 準備目標測站資訊
480
+ target_list = []
481
+ target_names = []
482
+ for target in target_dict[:25]: # 限制25個目標
483
+ target_list.append([target["latitude"], target["longitude"],
484
+ target["elevation"],
485
+ get_vs30(target["latitude"], target["longitude"])])
486
+ target_names.append(target["station"])
487
+
488
+ # 組合成 tensor
489
+ tensor_data = {
490
+ "waveform": torch.tensor(waveform_3c).unsqueeze(0).double(),
491
+ "station": torch.tensor(station_info_input).unsqueeze(0).double(),
492
+ "target": torch.tensor(target_list).unsqueeze(0).double(),
493
+ }
494
+
495
+ # 5. 執行預測
496
+ with torch.no_grad():
497
+ weight, sigma, mu = model(tensor_data)
498
+ pga_list = torch.sum(weight * mu,
499
+ dim=2).cpu().detach().numpy().flatten().tolist()
500
+
501
+ # 6. 繪製結果
502
+ waveform_plot = plot_waveform(times, data, start_time, end_time, sampling_rate)
503
+ intensity_plot = plot_intensity_map(pga_list, target_names)
504
+
505
+ # 統計資訊
506
+ max_intensity = max([calculate_intensity(pga, label=True) for pga in pga_list])
507
+ stats = f"選取時間範圍: {start_time:.1f} - {end_time:.1f} 秒\n"
508
+ stats += f"測站位置: ({lon:.4f}, {lat:.4f})\n"
509
+ stats += f"預測最大震度: {max_intensity}"
510
+
511
+ return waveform_plot, intensity_plot, stats
512
+
513
+
514
+ # ============ Gradio 介面 ============
515
+
516
+ with gr.Blocks(title="TTSAM 震度預測系統") as demo:
517
+ gr.Markdown("# 🌏 TTSAM 震度預測系統")
518
+
519
+ with gr.Row():
520
+ # 左側:輸入控制區
521
+ with gr.Column(scale=1):
522
+ gr.Markdown("## 輸入設定")
523
+ event_dropdown = gr.Dropdown(
524
+ choices=list(EARTHQUAKE_EVENTS.keys()),
525
+ value=list(EARTHQUAKE_EVENTS.keys())[0],
526
+ label="選擇地震事件"
527
+ )
528
+
529
+ with gr.Row():
530
+ start_slider = gr.Slider(0, 300, value=0, step=1, label="起始時間 (秒)")
531
+ end_slider = gr.Slider(0, 300, value=30, step=1, label="結束時間 (秒)")
532
+
533
+ gr.Markdown("### 測站位置")
534
+ with gr.Row():
535
+ lon_input = gr.Number(value=121.5, label="經度")
536
+ lat_input = gr.Number(value=24.0, label="緯度")
537
+
538
+ predict_btn = gr.Button("🔮 執行預測", variant="primary")
539
+
540
+ # 右側:震度分布圖
541
+ with gr.Column(scale=1):
542
+ gr.Markdown("## 預測震度分布")
543
+ intensity_plot = gr.Plot(label="震度分布圖")
544
+ stats_output = gr.Textbox(label="預測統計", lines=3)
545
+
546
+ # 下方:波形圖
547
+ with gr.Row():
548
+ gr.Markdown("## 輸入波形")
549
+ with gr.Row():
550
+ waveform_plot = gr.Plot(label="地震波形")
551
+
552
+ # 綁定事件
553
+ predict_btn.click(
554
+ fn=predict_intensity,
555
+ inputs=[event_dropdown, start_slider, end_slider, lon_input, lat_input],
556
+ outputs=[waveform_plot, intensity_plot, stats_output]
557
+ )
558
+
559
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ datasets
4
+ torch
5
+ obspy
6
+ numpy
7
+ matplotlib
8
+ xarray
9
+ netCDF4
10
+ scipy
11
+ pandas
12
+ loguru
13
+ huggingface_hub
station/eew_target.csv ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ network,county,station,station_zh,longitude,latitude,elevation
2
+ CWB_SMT,臺北市,TAP,臺北地震站,121.514,25.038,16
3
+ TSMIP,新北市,A024,板橋地震站,121.475,25.019,14
4
+ CWASN,新北市,NTS,淡水地震站,121.449,25.164,15
5
+ CWASN,新北市,TIPB,雙溪地震站,121.826,24.972,399
6
+ CWASN,基隆市,NOU,基隆地震站,121.773,25.149,16
7
+ CWB_SMT,桃園市,NTY,桃園地震站,121.298,25.000,93
8
+ CWASN,桃園市,NCU,中壢地震站,121.187,24.967,131
9
+ TSMIP,桃園市,B011,大溪地震站,121.286,24.884,117
10
+ CWASN,新竹市,HSN1,新竹地震站,121.018,24.779,91
11
+ CWASN,新竹縣,HSN,竹北地震站,121.014,24.828,31
12
+ CWASN,新竹縣,NJD,竹東地震站,121.088,24.736,131
13
+ TSMIP,苗栗縣,B131,苗栗地震站,120.826,24.565,50
14
+ CWB_SMT,苗栗縣,TWQ1,鯉魚潭地震站,120.781,24.346,286
15
+ TSMIP,苗栗縣,B045,(沒有泰安)獅潭地震站,120.9206,24.5399,201
16
+ CWASN,臺中市,TCU,臺中地震站,120.684,24.146,89
17
+ CWASN,臺中市,WDJ,大甲地震站,120.640,24.348,99
18
+ CWA,臺中市,WHP,烏石坑地震站,120.946,24.278,934
19
+ CWB_SMT,南投縣,WNT1,南投地震站,120.680,23.907,118
20
+ CWASN,南投縣,WPL,埔里地震站,120.957,24.012,465
21
+ CWASN,南投縣,WHY,信義地震站,120.853,23.696,495
22
+ CWASN,彰化縣,WCHH,彰化地震站,120.558,24.079,25
23
+ CWASN,彰化縣,WYL,員林地震站,120.580,23.960,33
24
+ CWASN,雲林縣,WDL,斗六地震站,120.539,23.715,52
25
+ CWASN,雲林縣,WSL,水林地震站,120.228,23.523,3
26
+ CWASN,嘉義市,CHY1,嘉義地震站,120.433,23.496,31
27
+ TSMIP,嘉義縣,C095,太保地震站,23.46,120.29,10
28
+ CWASN,嘉義縣,WCKO,番路地震站,120.605,23.439,233
29
+ CWASN,臺南市,TAI,臺南地震站,120.205,22.993,19
30
+ TSMIP,臺南市,C015,白河地震站,120.414,23.353,39
31
+ CWB_SMT,臺南市,CHN1,楠西地震站,120.529,23.185,216
32
+ CWASN,高雄市,KAU,前鎮地震站,22.5662,120.3157,1
33
+ CWASN,高雄市,SCS,旗山地震站,120.494,22.885,70
34
+ CWASN,屏東縣,SPT,屏東地震站,120.496,22.677,29
35
+ CWASN,屏東縣,HEN,恆春地震站,120.746,22.004,26
36
+ CWASN,屏東縣,SSD,三地門地震站,120.640,22.744,148
37
+ CWASN,宜蘭縣,ILA,宜蘭地震站,121.756,24.764,11
38
+ CWB_SMT,宜蘭縣,TWC,蘇澳地震站,121.860,24.608,33
39
+ CWB_SMT,宜蘭縣,ENT,牛鬥地震站,121.574,24.638,252
40
+ CWASN,花蓮縣,HWA,花蓮地震站,121.613,23.975,18
41
+ CWASN,花蓮縣,EGFH,光復地震站,121.427,23.669,126
42
+ CWASN,花蓮縣,EYUL,玉里地震站,121.319,23.347,138
43
+ CWASN,臺東縣,TTN,臺東地震站,121.155,22.752,12
44
+ CWASN,臺東縣,ECS,池上地震站,121.219,23.095,286
45
+ CWASN,臺東縣,TAWH,大武地震站,120.888,22.340,24
46
+ CWASN,澎湖縣,PNG,馬公地震站,119.564,23.565,10
47
+ CWASN,金門縣,KNM,金門地震站,118.289,24.407,31
48
+ CWASN,連江縣,MSU,馬祖地震站,119.923,26.169,85
station/site_info.txt ADDED
The diff for this file is too large to render. See raw diff