jimmy60504 commited on
Commit
118c2ea
·
1 Parent(s): 200c9b4

feat: transition from matplotlib to Plotly for waveform visualization and enhance interactivity

Browse files
Files changed (1) hide show
  1. app.py +88 -56
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- import matplotlib.pyplot as plt
3
  import numpy as np
4
  import pandas as pd
5
  import plotly.graph_objs as go
@@ -14,10 +13,6 @@ from scipy.spatial import cKDTree
14
 
15
  from model import get_full_model
16
 
17
- # 設定 matplotlib 中文字體支援
18
- plt.rcParams["font.sans-serif"] = ["Arial Unicode MS", "DejaVu Sans"]
19
- plt.rcParams["axes.unicode_minus"] = False # 解決負號顯示問題
20
-
21
  tree = None
22
  vs30_table = None
23
 
@@ -564,7 +559,8 @@ def plot_waveform(st, selected_stations, first_pick, duration):
564
  # 計算結束時間
565
  end_time = first_pick + duration
566
 
567
- fig, ax = plt.subplots(figsize=(14, 4))
 
568
 
569
  # 設定振幅縮放比例(避免波形重疊)
570
  amplitude_scale = 0.03 # 可調整此值來控制波形大小
@@ -596,22 +592,27 @@ def plot_waveform(st, selected_stations, first_pick, duration):
596
  data_normalized = data / (np.max(np.abs(data)) + 1e-10)
597
 
598
  # 繪製波形,Y軸位置為距離
599
- ax.plot(
600
- times,
601
- distance + data_normalized * amplitude_scale,
602
- "black",
603
- linewidth=0.3,
604
- alpha=0.8,
605
- )
 
 
 
 
 
606
 
607
  # 記錄 P 波標記位置
608
  if p_arrival_time is not None:
609
  if 0 <= p_arrival_time <= end_time:
610
  # P 波在時間窗內(綠色)
611
- p_wave_markers_in.append((p_arrival_time, distance))
612
  else:
613
  # P 波在時間窗外(紅色)
614
- p_wave_markers_out.append((p_arrival_time, distance))
615
 
616
  distances.append(distance)
617
  station_names.append(station_code)
@@ -622,55 +623,86 @@ def plot_waveform(st, selected_stations, first_pick, duration):
622
 
623
  # 繪製 P 波標記
624
  if p_wave_markers_in:
625
- p_times_in, p_dists_in = zip(*p_wave_markers_in)
626
- ax.scatter(p_times_in, p_dists_in, color="green", marker="v", s=50,
627
- zorder=5, label="P-wave (in window)", alpha=0.7)
 
 
 
 
 
 
 
 
628
 
629
  if p_wave_markers_out:
630
- p_times_out, p_dists_out = zip(*p_wave_markers_out)
631
- ax.scatter(p_times_out, p_dists_out, color="red", marker="v", s=50,
632
- zorder=5, label="P-wave (out window)", alpha=0.7)
633
-
634
- ax.axvline(first_pick, color="blue", linestyle="--", linewidth=2, alpha=0.7,
635
- label="First Motion")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
 
637
  # 標記選取時間範圍
638
- ax.axvline(
639
- 0,
640
- color="red",
641
- linestyle="--",
642
- linewidth=2,
643
- alpha=0.7,
644
- label="Input Waveform",
645
  )
646
- ax.axvline(end_time, color="red", linestyle="--", linewidth=2, alpha=0.7)
647
- ax.axvspan(0, end_time, alpha=0.15, color="blue")
648
 
649
- # 設定軸標籤和標題
650
- ax.set_xlabel("Time (s)", fontsize=12)
651
- ax.set_ylabel("Distance from Epicenter (°)", fontsize=12)
652
- ax.set_title(
653
- f"Record Section - {plotted_count} Stations Sorted by Distance",
654
- fontsize=14,
655
- fontweight="bold",
656
  )
657
 
658
- # 在右側標註測站名稱
659
- if distances:
660
- ax2 = ax.twinx()
661
- ax2.set_ylim(ax.get_ylim())
662
- ax2.set_ylabel("Station Code", fontsize=12)
663
-
664
- # 每隔幾個測站標註一次(避免過於擁擠)
665
- step = max(1, len(distances) // 10)
666
- tick_positions = distances[::step]
667
- tick_labels = station_names[::step]
668
- ax2.set_yticks(tick_positions)
669
- ax2.set_yticklabels(tick_labels, fontsize=8)
670
 
671
- ax.grid(True, alpha=0.3, axis="x")
672
- ax.legend(loc="upper right")
673
- plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
674
 
675
  return fig
676
 
@@ -1149,7 +1181,7 @@ with gr.Blocks(title="TTSAM 震度預測系統", fill_height=True) as demo:
1149
  )
1150
 
1151
  waveform_plot = gr.Plot(
1152
- label="地震波形(選定的 25 個測站)",
1153
  )
1154
 
1155
  # ========== 下層:合併地圖 vs 實際觀測 ==========
 
1
  import gradio as gr
 
2
  import numpy as np
3
  import pandas as pd
4
  import plotly.graph_objs as go
 
13
 
14
  from model import get_full_model
15
 
 
 
 
 
16
  tree = None
17
  vs30_table = None
18
 
 
559
  # 計算結束時間
560
  end_time = first_pick + duration
561
 
562
+ # 創建 Plotly figure
563
+ fig = go.Figure()
564
 
565
  # 設定振幅縮放比例(避免波形重疊)
566
  amplitude_scale = 0.03 # 可調整此值來控制波形大小
 
592
  data_normalized = data / (np.max(np.abs(data)) + 1e-10)
593
 
594
  # 繪製波形,Y軸位置為距離
595
+ y_values = distance + data_normalized * amplitude_scale
596
+
597
+ fig.add_trace(go.Scatter(
598
+ x=times,
599
+ y=y_values,
600
+ mode='lines',
601
+ line=dict(color='black', width=0.5),
602
+ opacity=0.8,
603
+ name=station_code,
604
+ hovertemplate=f'{station_code}<br>Time: %{{x:.2f}}s<br>Distance: {distance:.3f}°<extra></extra>',
605
+ showlegend=False
606
+ ))
607
 
608
  # 記錄 P 波標記位置
609
  if p_arrival_time is not None:
610
  if 0 <= p_arrival_time <= end_time:
611
  # P 波在時間窗內(綠色)
612
+ p_wave_markers_in.append((p_arrival_time, distance, station_code))
613
  else:
614
  # P 波在時間窗外(紅色)
615
+ p_wave_markers_out.append((p_arrival_time, distance, station_code))
616
 
617
  distances.append(distance)
618
  station_names.append(station_code)
 
623
 
624
  # 繪製 P 波標記
625
  if p_wave_markers_in:
626
+ p_times_in, p_dists_in, p_names_in = zip(*p_wave_markers_in)
627
+ fig.add_trace(go.Scatter(
628
+ x=p_times_in,
629
+ y=p_dists_in,
630
+ mode='markers',
631
+ marker=dict(color='green', size=8, symbol='triangle-down'),
632
+ name='P-wave (in window)',
633
+ hovertemplate='P-wave<br>Station: %{text}<br>Time: %{x:.2f}s<extra></extra>',
634
+ text=p_names_in,
635
+ showlegend=True
636
+ ))
637
 
638
  if p_wave_markers_out:
639
+ p_times_out, p_dists_out, p_names_out = zip(*p_wave_markers_out)
640
+ fig.add_trace(go.Scatter(
641
+ x=p_times_out,
642
+ y=p_dists_out,
643
+ mode='markers',
644
+ marker=dict(color='red', size=8, symbol='triangle-down'),
645
+ name='P-wave (out window)',
646
+ hovertemplate='P-wave<br>Station: %{text}<br>Time: %{x:.2f}s<extra></extra>',
647
+ text=p_names_out,
648
+ showlegend=True
649
+ ))
650
+
651
+ # 添加垂直線標記
652
+ # First Motion
653
+ fig.add_vline(
654
+ x=first_pick,
655
+ line=dict(color='blue', dash='dash', width=2),
656
+ annotation_text='First Motion',
657
+ annotation_position='top',
658
+ opacity=0.7
659
+ )
660
 
661
  # 標記選取時間範圍
662
+ fig.add_vline(
663
+ x=0,
664
+ line=dict(color='red', dash='dash', width=2),
665
+ opacity=0.7
 
 
 
666
  )
 
 
667
 
668
+ fig.add_vline(
669
+ x=end_time,
670
+ line=dict(color='red', dash='dash', width=2),
671
+ opacity=0.7
 
 
 
672
  )
673
 
674
+ # 添加時間窗陰影
675
+ fig.add_vrect(
676
+ x0=0, x1=end_time,
677
+ fillcolor='blue', opacity=0.1,
678
+ layer='below', line_width=0,
679
+ )
 
 
 
 
 
 
680
 
681
+ # 設定軸標籤和標題
682
+ fig.update_layout(
683
+ xaxis=dict(
684
+ title=dict(text='Time (s)', font=dict(size=12)),
685
+ gridcolor='rgba(128, 128, 128, 0.2)',
686
+ showgrid=True,
687
+ ),
688
+ yaxis=dict(
689
+ title=dict(text='Distance (°)', font=dict(size=12)),
690
+ gridcolor='rgba(128, 128, 128, 0.2)',
691
+ showgrid=False
692
+ ),
693
+ hovermode='closest',
694
+ height=200,
695
+ plot_bgcolor='white',
696
+ margin=dict(l=0, r=10, t=30, b=0), # 緊凑的邊距設置
697
+ showlegend=True,
698
+ legend=dict(
699
+ yanchor="top",
700
+ y=0.99,
701
+ xanchor="right",
702
+ x=0.99,
703
+ bgcolor="rgba(255, 255, 255, 0.8)",
704
+ )
705
+ )
706
 
707
  return fig
708
 
 
1181
  )
1182
 
1183
  waveform_plot = gr.Plot(
1184
+ label="地震波形",
1185
  )
1186
 
1187
  # ========== 下層:合併地圖 vs 實際觀測 ==========