Spaces:
Running
Running
Commit
·
12f3211
1
Parent(s):
cbfa1d2
refactor intensity prediction functions and update Gradio interface for epicenter input
Browse files
app.py
CHANGED
|
@@ -405,28 +405,166 @@ def calculate_intensity(pga, label=False):
|
|
| 405 |
# ============ Gradio 介面函數 ============
|
| 406 |
|
| 407 |
def load_waveform(event_name):
|
|
|
|
| 408 |
file_path = EARTHQUAKE_EVENTS[event_name]
|
| 409 |
st = read(file_path)
|
| 410 |
-
|
| 411 |
-
times = tr.times()
|
| 412 |
-
data = tr.data
|
| 413 |
-
return times, data, tr.stats.sampling_rate
|
| 414 |
|
| 415 |
|
| 416 |
-
def
|
| 417 |
-
|
| 418 |
-
|
| 419 |
|
| 420 |
-
mask = (times >= start_time) & (times <= end_time)
|
| 421 |
-
ax.plot(times[mask], data[mask], 'blue', linewidth=1)
|
| 422 |
|
| 423 |
-
|
| 424 |
-
|
|
|
|
| 425 |
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
|
| 431 |
return fig
|
| 432 |
|
|
@@ -463,61 +601,83 @@ def plot_intensity_map(pga_list, target_names):
|
|
| 463 |
return fig
|
| 464 |
|
| 465 |
|
| 466 |
-
def predict_intensity(event_name, start_time, end_time,
|
| 467 |
-
|
| 468 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
|
| 522 |
|
| 523 |
# ============ Gradio 介面 ============
|
|
@@ -539,11 +699,12 @@ with gr.Blocks(title="TTSAM 震度預測系統") as demo:
|
|
| 539 |
start_slider = gr.Slider(0, 300, value=0, step=1, label="起始時間 (秒)")
|
| 540 |
end_slider = gr.Slider(0, 300, value=30, step=1, label="結束時間 (秒)")
|
| 541 |
|
| 542 |
-
gr.Markdown("###
|
| 543 |
with gr.Row():
|
| 544 |
-
|
| 545 |
-
|
| 546 |
|
|
|
|
| 547 |
with gr.Row():
|
| 548 |
vs30_input = gr.Number(
|
| 549 |
value=600,
|
|
@@ -553,6 +714,13 @@ with gr.Blocks(title="TTSAM 震度預測系統") as demo:
|
|
| 553 |
|
| 554 |
predict_btn = gr.Button("🔮 執行預測", variant="primary")
|
| 555 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
# 右側:震度分布圖
|
| 557 |
with gr.Column(scale=1):
|
| 558 |
gr.Markdown("## 預測震度分布")
|
|
@@ -568,7 +736,7 @@ with gr.Blocks(title="TTSAM 震度預測系統") as demo:
|
|
| 568 |
# 綁定事件
|
| 569 |
predict_btn.click(
|
| 570 |
fn=predict_intensity,
|
| 571 |
-
inputs=[event_dropdown, start_slider, end_slider,
|
| 572 |
outputs=[waveform_plot, intensity_plot, stats_output]
|
| 573 |
)
|
| 574 |
|
|
|
|
| 405 |
# ============ Gradio 介面函數 ============
|
| 406 |
|
| 407 |
def load_waveform(event_name):
|
| 408 |
+
"""載入完整的 mseed 檔案(包含所有測站)"""
|
| 409 |
file_path = EARTHQUAKE_EVENTS[event_name]
|
| 410 |
st = read(file_path)
|
| 411 |
+
return st
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
|
| 414 |
+
def calculate_distance(lat1, lon1, lat2, lon2):
|
| 415 |
+
"""計算兩點間的距離(簡化的平面距離,單位:度)"""
|
| 416 |
+
return np.sqrt((lat1 - lat2)**2 + (lon1 - lon2)**2)
|
| 417 |
|
|
|
|
|
|
|
| 418 |
|
| 419 |
+
def select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25):
|
| 420 |
+
"""選擇距離震央最近的 n 個測站"""
|
| 421 |
+
station_distances = []
|
| 422 |
|
| 423 |
+
# 計算每個測站到震央的距離
|
| 424 |
+
for tr in st:
|
| 425 |
+
station_code = tr.stats.station
|
| 426 |
+
|
| 427 |
+
# 從 site_info 中查詢測站位置
|
| 428 |
+
try:
|
| 429 |
+
station_data = site_info[site_info["Station"] == station_code]
|
| 430 |
+
if len(station_data) == 0:
|
| 431 |
+
continue
|
| 432 |
+
|
| 433 |
+
lat = station_data["Latitude"].values[0]
|
| 434 |
+
lon = station_data["Longitude"].values[0]
|
| 435 |
+
elev = station_data["Elevation"].values[0]
|
| 436 |
+
|
| 437 |
+
distance = calculate_distance(epicenter_lat, epicenter_lon, lat, lon)
|
| 438 |
+
station_distances.append({
|
| 439 |
+
"station": station_code,
|
| 440 |
+
"distance": distance,
|
| 441 |
+
"latitude": lat,
|
| 442 |
+
"longitude": lon,
|
| 443 |
+
"elevation": elev
|
| 444 |
+
})
|
| 445 |
+
except Exception as e:
|
| 446 |
+
logger.warning(f"測站 {station_code} 資訊查詢失敗: {e}")
|
| 447 |
+
continue
|
| 448 |
+
|
| 449 |
+
# 按距離排序並選擇最近的 n 個
|
| 450 |
+
station_distances.sort(key=lambda x: x["distance"])
|
| 451 |
+
selected_stations = station_distances[:n_stations]
|
| 452 |
+
|
| 453 |
+
logger.info(f"選擇了 {len(selected_stations)} 個最近的測站")
|
| 454 |
+
return selected_stations
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def extract_waveforms_from_stream(st, selected_stations, start_time, end_time, vs30_input):
|
| 458 |
+
"""從 Stream 中提取選定測站的波形資料"""
|
| 459 |
+
waveforms = []
|
| 460 |
+
station_info_list = []
|
| 461 |
+
valid_stations = []
|
| 462 |
+
|
| 463 |
+
sampling_rate = 100 # 假設 100 Hz
|
| 464 |
+
start_idx = int(start_time * sampling_rate)
|
| 465 |
+
end_idx = int(end_time * sampling_rate)
|
| 466 |
+
target_length = 3000
|
| 467 |
+
|
| 468 |
+
for station_data in selected_stations:
|
| 469 |
+
station_code = station_data["station"]
|
| 470 |
+
|
| 471 |
+
try:
|
| 472 |
+
# 選擇該測站的所有分量
|
| 473 |
+
st_station = st.select(station=station_code)
|
| 474 |
+
|
| 475 |
+
if len(st_station) == 0:
|
| 476 |
+
continue
|
| 477 |
+
|
| 478 |
+
# 嘗試取得 Z, N, E 分量
|
| 479 |
+
z_trace = st_station.select(component="Z")
|
| 480 |
+
n_trace = st_station.select(component="N") or st_station.select(component="1")
|
| 481 |
+
e_trace = st_station.select(component="E") or st_station.select(component="2")
|
| 482 |
+
|
| 483 |
+
# 如果沒有三分量,使用 Z 分量重複
|
| 484 |
+
if len(z_trace) > 0:
|
| 485 |
+
z_data = z_trace[0].data[start_idx:end_idx]
|
| 486 |
+
else:
|
| 487 |
+
continue
|
| 488 |
+
|
| 489 |
+
if len(n_trace) > 0:
|
| 490 |
+
n_data = n_trace[0].data[start_idx:end_idx]
|
| 491 |
+
else:
|
| 492 |
+
n_data = z_data.copy()
|
| 493 |
+
|
| 494 |
+
if len(e_trace) > 0:
|
| 495 |
+
e_data = e_trace[0].data[start_idx:end_idx]
|
| 496 |
+
else:
|
| 497 |
+
e_data = z_data.copy()
|
| 498 |
+
|
| 499 |
+
# 訊號處理
|
| 500 |
+
z_data = signal_processing(z_data)
|
| 501 |
+
n_data = signal_processing(n_data)
|
| 502 |
+
e_data = signal_processing(e_data)
|
| 503 |
+
|
| 504 |
+
# 調整長度到 3000
|
| 505 |
+
for data in [z_data, n_data, e_data]:
|
| 506 |
+
if len(data) > target_length:
|
| 507 |
+
data = data[:target_length]
|
| 508 |
+
elif len(data) < target_length:
|
| 509 |
+
data = np.pad(data, (0, target_length - len(data)))
|
| 510 |
+
|
| 511 |
+
# 組合三分量 (3000, 3)
|
| 512 |
+
waveform_3c = np.stack([z_data[:target_length],
|
| 513 |
+
n_data[:target_length],
|
| 514 |
+
e_data[:target_length]], axis=1)
|
| 515 |
+
waveforms.append(waveform_3c)
|
| 516 |
+
|
| 517 |
+
# 準備測站資訊
|
| 518 |
+
vs30 = get_vs30(station_data["latitude"], station_data["longitude"], vs30_input)
|
| 519 |
+
station_info_list.append([
|
| 520 |
+
station_data["latitude"],
|
| 521 |
+
station_data["longitude"],
|
| 522 |
+
station_data["elevation"],
|
| 523 |
+
vs30
|
| 524 |
+
])
|
| 525 |
+
valid_stations.append(station_data)
|
| 526 |
+
|
| 527 |
+
except Exception as e:
|
| 528 |
+
logger.warning(f"測站 {station_code} 波形提取失敗: {e}")
|
| 529 |
+
continue
|
| 530 |
+
|
| 531 |
+
logger.info(f"成功提取 {len(waveforms)} 個測站的波形")
|
| 532 |
+
return waveforms, station_info_list, valid_stations
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def plot_waveform(st, selected_stations, start_time, end_time):
|
| 536 |
+
"""繪製選定測站的波形圖(堆疊顯示前 10 個測站)"""
|
| 537 |
+
fig, axes = plt.subplots(min(10, len(selected_stations)), 1, figsize=(12, 8), sharex=True)
|
| 538 |
+
|
| 539 |
+
if len(selected_stations) == 1:
|
| 540 |
+
axes = [axes]
|
| 541 |
+
|
| 542 |
+
for i, station_data in enumerate(selected_stations[:10]):
|
| 543 |
+
station_code = station_data["station"]
|
| 544 |
+
|
| 545 |
+
try:
|
| 546 |
+
st_station = st.select(station=station_code)
|
| 547 |
+
if len(st_station) > 0:
|
| 548 |
+
tr = st_station[0]
|
| 549 |
+
times = tr.times()
|
| 550 |
+
data = tr.data
|
| 551 |
+
|
| 552 |
+
axes[i].plot(times, data, 'black', linewidth=0.5)
|
| 553 |
+
|
| 554 |
+
# 標記選取範圍
|
| 555 |
+
axes[i].axvline(start_time, color='red', linestyle='--', linewidth=1, alpha=0.7)
|
| 556 |
+
axes[i].axvline(end_time, color='red', linestyle='--', linewidth=1, alpha=0.7)
|
| 557 |
+
axes[i].axvspan(start_time, end_time, alpha=0.2, color='blue')
|
| 558 |
+
|
| 559 |
+
axes[i].set_ylabel(f'{station_code}\n({station_data["distance"]:.2f}°)', fontsize=8)
|
| 560 |
+
axes[i].grid(True, alpha=0.3)
|
| 561 |
+
axes[i].tick_params(labelsize=8)
|
| 562 |
+
except Exception as e:
|
| 563 |
+
logger.warning(f"無法繪製測站 {station_code}: {e}")
|
| 564 |
+
|
| 565 |
+
axes[-1].set_xlabel('Time (s)')
|
| 566 |
+
fig.suptitle(f'波形記錄(前 10 個最近測站,共選擇 {len(selected_stations)} 個)', fontsize=12)
|
| 567 |
+
plt.tight_layout()
|
| 568 |
|
| 569 |
return fig
|
| 570 |
|
|
|
|
| 601 |
return fig
|
| 602 |
|
| 603 |
|
| 604 |
+
def predict_intensity(event_name, start_time, end_time, epicenter_lon, epicenter_lat, vs30_input):
|
| 605 |
+
"""執行震度預測"""
|
| 606 |
+
try:
|
| 607 |
+
# 1. 載入完整的 mseed 檔案
|
| 608 |
+
logger.info(f"載入地震事件: {event_name}")
|
| 609 |
+
st = load_waveform(event_name)
|
| 610 |
+
logger.info(f"載入了 {len(st)} 個 trace")
|
| 611 |
|
| 612 |
+
# 2. 根據震央距離選擇最近的 25 個測站
|
| 613 |
+
logger.info(f"選擇距離震央 ({epicenter_lat}, {epicenter_lon}) 最近的測站...")
|
| 614 |
+
selected_stations = select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25)
|
| 615 |
+
|
| 616 |
+
if len(selected_stations) == 0:
|
| 617 |
+
return None, None, "錯誤:找不到有效的測站資料"
|
| 618 |
+
|
| 619 |
+
# 3. 從選定的測站提取波形
|
| 620 |
+
logger.info(f"提取波形資料(時間範圍: {start_time}-{end_time} 秒)...")
|
| 621 |
+
waveforms, station_info_list, valid_stations = extract_waveforms_from_stream(
|
| 622 |
+
st, selected_stations, start_time, end_time, vs30_input
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
if len(waveforms) == 0:
|
| 626 |
+
return None, None, "錯誤:無法提取波形資料"
|
| 627 |
+
|
| 628 |
+
# 4. Padding 到 25 個測站(模型要求)
|
| 629 |
+
max_stations = 25
|
| 630 |
+
waveform_padded = np.zeros((max_stations, 3000, 3))
|
| 631 |
+
station_info_padded = np.zeros((max_stations, 4))
|
| 632 |
+
|
| 633 |
+
for i in range(min(len(waveforms), max_stations)):
|
| 634 |
+
waveform_padded[i] = waveforms[i]
|
| 635 |
+
station_info_padded[i] = station_info_list[i]
|
| 636 |
+
|
| 637 |
+
# 5. 準備目標測站資訊
|
| 638 |
+
target_list = []
|
| 639 |
+
target_names = []
|
| 640 |
+
for target in target_dict[:25]:
|
| 641 |
+
target_list.append([
|
| 642 |
+
target["latitude"],
|
| 643 |
+
target["longitude"],
|
| 644 |
+
target["elevation"],
|
| 645 |
+
get_vs30(target["latitude"], target["longitude"], vs30_input)
|
| 646 |
+
])
|
| 647 |
+
target_names.append(target["station"])
|
| 648 |
+
|
| 649 |
+
# 6. 組合成 tensor
|
| 650 |
+
tensor_data = {
|
| 651 |
+
"waveform": torch.tensor(waveform_padded).unsqueeze(0).double(),
|
| 652 |
+
"station": torch.tensor(station_info_padded).unsqueeze(0).double(),
|
| 653 |
+
"target": torch.tensor(target_list).unsqueeze(0).double(),
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
# 7. 執行預測
|
| 657 |
+
logger.info("執行模型預測...")
|
| 658 |
+
with torch.no_grad():
|
| 659 |
+
weight, sigma, mu = model(tensor_data)
|
| 660 |
+
pga_list = torch.sum(weight * mu, dim=2).cpu().detach().numpy().flatten().tolist()
|
| 661 |
+
|
| 662 |
+
# 8. 繪製結果
|
| 663 |
+
waveform_plot = plot_waveform(st, selected_stations, start_time, end_time)
|
| 664 |
+
intensity_plot = plot_intensity_map(pga_list, target_names)
|
| 665 |
+
|
| 666 |
+
# 9. 統計資訊
|
| 667 |
+
max_intensity = max([calculate_intensity(pga, label=True) for pga in pga_list])
|
| 668 |
+
stats = f"選取時間範圍: {start_time:.1f} - {end_time:.1f} 秒\n"
|
| 669 |
+
stats += f"震央位置: ({epicenter_lon:.4f}, {epicenter_lat:.4f})\n"
|
| 670 |
+
stats += f"使用測站數: {len(waveforms)} / 25\n"
|
| 671 |
+
stats += f"預測最大震度: {max_intensity}"
|
| 672 |
+
|
| 673 |
+
logger.info("預測完成!")
|
| 674 |
+
return waveform_plot, intensity_plot, stats
|
| 675 |
+
|
| 676 |
+
except Exception as e:
|
| 677 |
+
logger.error(f"預測過程發生錯誤: {e}")
|
| 678 |
+
import traceback
|
| 679 |
+
traceback.print_exc()
|
| 680 |
+
return None, None, f"錯誤: {str(e)}"
|
| 681 |
|
| 682 |
|
| 683 |
# ============ Gradio 介面 ============
|
|
|
|
| 699 |
start_slider = gr.Slider(0, 300, value=0, step=1, label="起始時間 (秒)")
|
| 700 |
end_slider = gr.Slider(0, 300, value=30, step=1, label="結束時間 (秒)")
|
| 701 |
|
| 702 |
+
gr.Markdown("### 震央位置")
|
| 703 |
with gr.Row():
|
| 704 |
+
epicenter_lon_input = gr.Number(value=121.5, label="震央經度")
|
| 705 |
+
epicenter_lat_input = gr.Number(value=24.0, label="震央緯度")
|
| 706 |
|
| 707 |
+
gr.Markdown("### 場址參數")
|
| 708 |
with gr.Row():
|
| 709 |
vs30_input = gr.Number(
|
| 710 |
value=600,
|
|
|
|
| 714 |
|
| 715 |
predict_btn = gr.Button("🔮 執行預測", variant="primary")
|
| 716 |
|
| 717 |
+
gr.Markdown("""
|
| 718 |
+
### 說明
|
| 719 |
+
- 系統會根據震央位置自動選擇最近的 25 個測站
|
| 720 |
+
- 從選定的時間範圍提取波形資料(30 秒)
|
| 721 |
+
- 預測全台灣目標測站的震度分布
|
| 722 |
+
""")
|
| 723 |
+
|
| 724 |
# 右側:震度分布圖
|
| 725 |
with gr.Column(scale=1):
|
| 726 |
gr.Markdown("## 預測震度分布")
|
|
|
|
| 736 |
# 綁定事件
|
| 737 |
predict_btn.click(
|
| 738 |
fn=predict_intensity,
|
| 739 |
+
inputs=[event_dropdown, start_slider, end_slider, epicenter_lon_input, epicenter_lat_input, vs30_input],
|
| 740 |
outputs=[waveform_plot, intensity_plot, stats_output]
|
| 741 |
)
|
| 742 |
|