jimmy60504 commited on
Commit
12f3211
·
1 Parent(s): cbfa1d2

refactor intensity prediction functions and update Gradio interface for epicenter input

Browse files
Files changed (1) hide show
  1. app.py +241 -73
app.py CHANGED
@@ -405,28 +405,166 @@ def calculate_intensity(pga, label=False):
405
  # ============ Gradio 介面函數 ============
406
 
407
  def load_waveform(event_name):
 
408
  file_path = EARTHQUAKE_EVENTS[event_name]
409
  st = read(file_path)
410
- tr = st[0]
411
- times = tr.times()
412
- data = tr.data
413
- return times, data, tr.stats.sampling_rate
414
 
415
 
416
- def plot_waveform(times, data, start_time, end_time, sampling_rate):
417
- fig, ax = plt.subplots(figsize=(12, 3))
418
- ax.plot(times, data, 'gray', linewidth=0.5, alpha=0.6)
419
 
420
- mask = (times >= start_time) & (times <= end_time)
421
- ax.plot(times[mask], data[mask], 'blue', linewidth=1)
422
 
423
- ax.axvline(start_time, color='red', linestyle='--', linewidth=1)
424
- ax.axvline(end_time, color='red', linestyle='--', linewidth=1)
 
425
 
426
- ax.set_xlabel('Time (s)')
427
- ax.set_ylabel('Amplitude')
428
- ax.set_title('Seismic Waveform')
429
- ax.grid(True, alpha=0.3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
  return fig
432
 
@@ -463,61 +601,83 @@ def plot_intensity_map(pga_list, target_names):
463
  return fig
464
 
465
 
466
- def predict_intensity(event_name, start_time, end_time, lon, lat, vs30_input):
467
- # 1. 載入波形
468
- times, data, sampling_rate = load_waveform(event_name)
 
 
 
 
469
 
470
- # 2. 切片波形
471
- start_idx = int(start_time * sampling_rate)
472
- end_idx = int(end_time * sampling_rate)
473
- waveform_slice = data[start_idx:end_idx]
474
-
475
- # 3. 訊號處理
476
- waveform_processed = signal_processing(waveform_slice)
477
-
478
- # 4. 準備模型輸入
479
- # 假設單測站三軸資料(這裡簡化為重複使用Z軸)
480
- waveform_3c = np.array(
481
- [[waveform_processed, waveform_processed, waveform_processed]])
482
- waveform_3c = waveform_3c.transpose(0, 2, 1) # (1, 3000, 3)
483
-
484
- # 準備測站資訊(使用使用者輸入的 Vs30 或資料庫值)
485
- vs30 = get_vs30(lat, lon, vs30_input)
486
- station_info_input = np.array([[lat, lon, 100, vs30]]) # elevation 假設 100m
487
-
488
- # 準備目標測站資訊
489
- target_list = []
490
- target_names = []
491
- for target in target_dict[:25]: # 限制25個目標
492
- target_list.append([target["latitude"], target["longitude"],
493
- target["elevation"],
494
- get_vs30(target["latitude"], target["longitude"], vs30_input)])
495
- target_names.append(target["station"])
496
-
497
- # 組合成 tensor
498
- tensor_data = {
499
- "waveform": torch.tensor(waveform_3c).unsqueeze(0).double(),
500
- "station": torch.tensor(station_info_input).unsqueeze(0).double(),
501
- "target": torch.tensor(target_list).unsqueeze(0).double(),
502
- }
503
-
504
- # 5. 執行預測
505
- with torch.no_grad():
506
- weight, sigma, mu = model(tensor_data)
507
- pga_list = torch.sum(weight * mu,
508
- dim=2).cpu().detach().numpy().flatten().tolist()
509
-
510
- # 6. 繪製結果
511
- waveform_plot = plot_waveform(times, data, start_time, end_time, sampling_rate)
512
- intensity_plot = plot_intensity_map(pga_list, target_names)
513
-
514
- # 統計資訊
515
- max_intensity = max([calculate_intensity(pga, label=True) for pga in pga_list])
516
- stats = f"選取時間範圍: {start_time:.1f} - {end_time:.1f} 秒\n"
517
- stats += f"測站位置: ({lon:.4f}, {lat:.4f})\n"
518
- stats += f"預測最大震度: {max_intensity}"
519
-
520
- return waveform_plot, intensity_plot, stats
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
 
522
 
523
  # ============ Gradio 介面 ============
@@ -539,11 +699,12 @@ with gr.Blocks(title="TTSAM 震度預測系統") as demo:
539
  start_slider = gr.Slider(0, 300, value=0, step=1, label="起始時間 (秒)")
540
  end_slider = gr.Slider(0, 300, value=30, step=1, label="結束時間 (秒)")
541
 
542
- gr.Markdown("### 測站位置")
543
  with gr.Row():
544
- lon_input = gr.Number(value=121.5, label="經度")
545
- lat_input = gr.Number(value=24.0, label="緯度")
546
 
 
547
  with gr.Row():
548
  vs30_input = gr.Number(
549
  value=600,
@@ -553,6 +714,13 @@ with gr.Blocks(title="TTSAM 震度預測系統") as demo:
553
 
554
  predict_btn = gr.Button("🔮 執行預測", variant="primary")
555
 
 
 
 
 
 
 
 
556
  # 右側:震度分布圖
557
  with gr.Column(scale=1):
558
  gr.Markdown("## 預測震度分布")
@@ -568,7 +736,7 @@ with gr.Blocks(title="TTSAM 震度預測系統") as demo:
568
  # 綁定事件
569
  predict_btn.click(
570
  fn=predict_intensity,
571
- inputs=[event_dropdown, start_slider, end_slider, lon_input, lat_input, vs30_input],
572
  outputs=[waveform_plot, intensity_plot, stats_output]
573
  )
574
 
 
405
  # ============ Gradio 介面函數 ============
406
 
407
  def load_waveform(event_name):
408
+ """載入完整的 mseed 檔案(包含所有測站)"""
409
  file_path = EARTHQUAKE_EVENTS[event_name]
410
  st = read(file_path)
411
+ return st
 
 
 
412
 
413
 
414
+ def calculate_distance(lat1, lon1, lat2, lon2):
415
+ """計算兩點間的距離(簡化的平面距離,單位:度)"""
416
+ return np.sqrt((lat1 - lat2)**2 + (lon1 - lon2)**2)
417
 
 
 
418
 
419
+ def select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25):
420
+ """選擇距離震央最近的 n 個測站"""
421
+ station_distances = []
422
 
423
+ # 計算每個測站到震央的距離
424
+ for tr in st:
425
+ station_code = tr.stats.station
426
+
427
+ # 從 site_info 中查詢測站位置
428
+ try:
429
+ station_data = site_info[site_info["Station"] == station_code]
430
+ if len(station_data) == 0:
431
+ continue
432
+
433
+ lat = station_data["Latitude"].values[0]
434
+ lon = station_data["Longitude"].values[0]
435
+ elev = station_data["Elevation"].values[0]
436
+
437
+ distance = calculate_distance(epicenter_lat, epicenter_lon, lat, lon)
438
+ station_distances.append({
439
+ "station": station_code,
440
+ "distance": distance,
441
+ "latitude": lat,
442
+ "longitude": lon,
443
+ "elevation": elev
444
+ })
445
+ except Exception as e:
446
+ logger.warning(f"測站 {station_code} 資訊查詢失敗: {e}")
447
+ continue
448
+
449
+ # 按距離排序並選擇最近的 n 個
450
+ station_distances.sort(key=lambda x: x["distance"])
451
+ selected_stations = station_distances[:n_stations]
452
+
453
+ logger.info(f"選擇了 {len(selected_stations)} 個最近的測站")
454
+ return selected_stations
455
+
456
+
457
+ def extract_waveforms_from_stream(st, selected_stations, start_time, end_time, vs30_input):
458
+ """從 Stream 中提取選定測站的波形資料"""
459
+ waveforms = []
460
+ station_info_list = []
461
+ valid_stations = []
462
+
463
+ sampling_rate = 100 # 假設 100 Hz
464
+ start_idx = int(start_time * sampling_rate)
465
+ end_idx = int(end_time * sampling_rate)
466
+ target_length = 3000
467
+
468
+ for station_data in selected_stations:
469
+ station_code = station_data["station"]
470
+
471
+ try:
472
+ # 選擇該測站的所有分量
473
+ st_station = st.select(station=station_code)
474
+
475
+ if len(st_station) == 0:
476
+ continue
477
+
478
+ # 嘗試取得 Z, N, E 分量
479
+ z_trace = st_station.select(component="Z")
480
+ n_trace = st_station.select(component="N") or st_station.select(component="1")
481
+ e_trace = st_station.select(component="E") or st_station.select(component="2")
482
+
483
+ # 如果沒有三分量,使用 Z 分量重複
484
+ if len(z_trace) > 0:
485
+ z_data = z_trace[0].data[start_idx:end_idx]
486
+ else:
487
+ continue
488
+
489
+ if len(n_trace) > 0:
490
+ n_data = n_trace[0].data[start_idx:end_idx]
491
+ else:
492
+ n_data = z_data.copy()
493
+
494
+ if len(e_trace) > 0:
495
+ e_data = e_trace[0].data[start_idx:end_idx]
496
+ else:
497
+ e_data = z_data.copy()
498
+
499
+ # 訊號處理
500
+ z_data = signal_processing(z_data)
501
+ n_data = signal_processing(n_data)
502
+ e_data = signal_processing(e_data)
503
+
504
+ # 調整長度到 3000
505
+ for data in [z_data, n_data, e_data]:
506
+ if len(data) > target_length:
507
+ data = data[:target_length]
508
+ elif len(data) < target_length:
509
+ data = np.pad(data, (0, target_length - len(data)))
510
+
511
+ # 組合三分量 (3000, 3)
512
+ waveform_3c = np.stack([z_data[:target_length],
513
+ n_data[:target_length],
514
+ e_data[:target_length]], axis=1)
515
+ waveforms.append(waveform_3c)
516
+
517
+ # 準備測站資訊
518
+ vs30 = get_vs30(station_data["latitude"], station_data["longitude"], vs30_input)
519
+ station_info_list.append([
520
+ station_data["latitude"],
521
+ station_data["longitude"],
522
+ station_data["elevation"],
523
+ vs30
524
+ ])
525
+ valid_stations.append(station_data)
526
+
527
+ except Exception as e:
528
+ logger.warning(f"測站 {station_code} 波形提取失敗: {e}")
529
+ continue
530
+
531
+ logger.info(f"成功提取 {len(waveforms)} 個測站的波形")
532
+ return waveforms, station_info_list, valid_stations
533
+
534
+
535
+ def plot_waveform(st, selected_stations, start_time, end_time):
536
+ """繪製選定測站的波形圖(堆疊顯示前 10 個測站)"""
537
+ fig, axes = plt.subplots(min(10, len(selected_stations)), 1, figsize=(12, 8), sharex=True)
538
+
539
+ if len(selected_stations) == 1:
540
+ axes = [axes]
541
+
542
+ for i, station_data in enumerate(selected_stations[:10]):
543
+ station_code = station_data["station"]
544
+
545
+ try:
546
+ st_station = st.select(station=station_code)
547
+ if len(st_station) > 0:
548
+ tr = st_station[0]
549
+ times = tr.times()
550
+ data = tr.data
551
+
552
+ axes[i].plot(times, data, 'black', linewidth=0.5)
553
+
554
+ # 標記選取範圍
555
+ axes[i].axvline(start_time, color='red', linestyle='--', linewidth=1, alpha=0.7)
556
+ axes[i].axvline(end_time, color='red', linestyle='--', linewidth=1, alpha=0.7)
557
+ axes[i].axvspan(start_time, end_time, alpha=0.2, color='blue')
558
+
559
+ axes[i].set_ylabel(f'{station_code}\n({station_data["distance"]:.2f}°)', fontsize=8)
560
+ axes[i].grid(True, alpha=0.3)
561
+ axes[i].tick_params(labelsize=8)
562
+ except Exception as e:
563
+ logger.warning(f"無法繪製測站 {station_code}: {e}")
564
+
565
+ axes[-1].set_xlabel('Time (s)')
566
+ fig.suptitle(f'波形記錄(前 10 個最近測站,共選擇 {len(selected_stations)} 個)', fontsize=12)
567
+ plt.tight_layout()
568
 
569
  return fig
570
 
 
601
  return fig
602
 
603
 
604
+ def predict_intensity(event_name, start_time, end_time, epicenter_lon, epicenter_lat, vs30_input):
605
+ """執行震度預測"""
606
+ try:
607
+ # 1. 載入完整的 mseed 檔案
608
+ logger.info(f"載入地震事件: {event_name}")
609
+ st = load_waveform(event_name)
610
+ logger.info(f"載入了 {len(st)} 個 trace")
611
 
612
+ # 2. 根據震央距離選擇最近的 25 個測站
613
+ logger.info(f"選擇距離震央 ({epicenter_lat}, {epicenter_lon}) 最近的測站...")
614
+ selected_stations = select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25)
615
+
616
+ if len(selected_stations) == 0:
617
+ return None, None, "錯誤:找不到有效的測站資料"
618
+
619
+ # 3. 從選定的測站提取波形
620
+ logger.info(f"提取波形資料(時間範圍: {start_time}-{end_time} 秒)...")
621
+ waveforms, station_info_list, valid_stations = extract_waveforms_from_stream(
622
+ st, selected_stations, start_time, end_time, vs30_input
623
+ )
624
+
625
+ if len(waveforms) == 0:
626
+ return None, None, "錯誤:無法提取波形資料"
627
+
628
+ # 4. Padding 25 個測站(模型要求)
629
+ max_stations = 25
630
+ waveform_padded = np.zeros((max_stations, 3000, 3))
631
+ station_info_padded = np.zeros((max_stations, 4))
632
+
633
+ for i in range(min(len(waveforms), max_stations)):
634
+ waveform_padded[i] = waveforms[i]
635
+ station_info_padded[i] = station_info_list[i]
636
+
637
+ # 5. 準備目標測站資訊
638
+ target_list = []
639
+ target_names = []
640
+ for target in target_dict[:25]:
641
+ target_list.append([
642
+ target["latitude"],
643
+ target["longitude"],
644
+ target["elevation"],
645
+ get_vs30(target["latitude"], target["longitude"], vs30_input)
646
+ ])
647
+ target_names.append(target["station"])
648
+
649
+ # 6. 組合成 tensor
650
+ tensor_data = {
651
+ "waveform": torch.tensor(waveform_padded).unsqueeze(0).double(),
652
+ "station": torch.tensor(station_info_padded).unsqueeze(0).double(),
653
+ "target": torch.tensor(target_list).unsqueeze(0).double(),
654
+ }
655
+
656
+ # 7. 執行預測
657
+ logger.info("執行模型預測...")
658
+ with torch.no_grad():
659
+ weight, sigma, mu = model(tensor_data)
660
+ pga_list = torch.sum(weight * mu, dim=2).cpu().detach().numpy().flatten().tolist()
661
+
662
+ # 8. 繪製結果
663
+ waveform_plot = plot_waveform(st, selected_stations, start_time, end_time)
664
+ intensity_plot = plot_intensity_map(pga_list, target_names)
665
+
666
+ # 9. 統計資訊
667
+ max_intensity = max([calculate_intensity(pga, label=True) for pga in pga_list])
668
+ stats = f"選取時間範圍: {start_time:.1f} - {end_time:.1f} 秒\n"
669
+ stats += f"震央位置: ({epicenter_lon:.4f}, {epicenter_lat:.4f})\n"
670
+ stats += f"使用測站數: {len(waveforms)} / 25\n"
671
+ stats += f"預測最大震度: {max_intensity}"
672
+
673
+ logger.info("預測完成!")
674
+ return waveform_plot, intensity_plot, stats
675
+
676
+ except Exception as e:
677
+ logger.error(f"預測過程發生錯誤: {e}")
678
+ import traceback
679
+ traceback.print_exc()
680
+ return None, None, f"錯誤: {str(e)}"
681
 
682
 
683
  # ============ Gradio 介面 ============
 
699
  start_slider = gr.Slider(0, 300, value=0, step=1, label="起始時間 (秒)")
700
  end_slider = gr.Slider(0, 300, value=30, step=1, label="結束時間 (秒)")
701
 
702
+ gr.Markdown("### 震央位置")
703
  with gr.Row():
704
+ epicenter_lon_input = gr.Number(value=121.5, label="震央經度")
705
+ epicenter_lat_input = gr.Number(value=24.0, label="震央緯度")
706
 
707
+ gr.Markdown("### 場址參數")
708
  with gr.Row():
709
  vs30_input = gr.Number(
710
  value=600,
 
714
 
715
  predict_btn = gr.Button("🔮 執行預測", variant="primary")
716
 
717
+ gr.Markdown("""
718
+ ### 說明
719
+ - 系統會根據震央位置自動選擇最近的 25 個測站
720
+ - 從選定的時間範圍提取波形資料(30 秒)
721
+ - 預測全台灣目標測站的震度分布
722
+ """)
723
+
724
  # 右側:震度分布圖
725
  with gr.Column(scale=1):
726
  gr.Markdown("## 預測震度分布")
 
736
  # 綁定事件
737
  predict_btn.click(
738
  fn=predict_intensity,
739
+ inputs=[event_dropdown, start_slider, end_slider, epicenter_lon_input, epicenter_lat_input, vs30_input],
740
  outputs=[waveform_plot, intensity_plot, stats_output]
741
  )
742