"""TTSAM intensity-prediction demo (Gradio app).

Loads Taiwan Vs30 data and a trained TTSAM model from Hugging Face, lets the
user pick an earthquake mseed file, selects the 25 stations nearest the
epicenter, and predicts peak ground acceleration / intensity at target sites.
"""
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from obspy import read
import xarray as xr
import torch
import torch.nn as nn
from scipy.signal import detrend, iirfilter, sosfilt, zpk2sos
from scipy.spatial import cKDTree
import pandas as pd
from loguru import logger

# Matplotlib font support for CJK labels.
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False  # render minus sign correctly

# GPU/CPU selection
if torch.cuda.is_available():
    device = torch.device("cuda")
    logger.info("使用 GPU")
else:
    device = torch.device("cpu")
    logger.info("使用 CPU")

# Load the Vs30 dataset (downloaded from Hugging Face).
from huggingface_hub import hf_hub_download

tree = None
vs30_table = None
try:
    logger.info("從 Hugging Face 載入 Vs30 資料...")
    vs30_file = hf_hub_download(
        repo_id="SeisBlue/TaiwanVs30",
        filename="Vs30ofTaiwan.nc",
        repo_type="dataset"
    )
    ds = xr.open_dataset(vs30_file)
    lat_flat = ds['lat'].values.flatten()
    lon_flat = ds['lon'].values.flatten()
    vs30_flat = ds['vs30'].values.flatten()
    vs30_table = pd.DataFrame({'lat': lat_flat, 'lon': lon_flat, 'Vs30': vs30_flat})
    vs30_table = vs30_table.replace([np.inf, -np.inf], np.nan).dropna()
    tree = cKDTree(vs30_table[["lat", "lon"]])
    logger.info("Vs30 資料載入完成")
except Exception as e:
    logger.warning(f"Vs30 資料載入失敗: {e}")
    logger.warning("將使用預設 Vs30 值 (600 m/s)")

# Load target stations (prediction targets).
# Initialize to safe defaults so a missing CSV does not cause a NameError later.
target_file = "station/eew_target.csv"
target_df = pd.DataFrame()
target_dict = []
try:
    logger.info(f"載入 {target_file}...")
    target_df = pd.read_csv(target_file)
    target_dict = target_df.to_dict(orient="records")
    logger.info(f"{target_file} 載入完成")
except FileNotFoundError:
    logger.error(f"{target_file} 找不到")

# Load input-station metadata (1000+ candidate stations).
site_info_file = "station/site_info.csv"
site_info = pd.DataFrame(columns=["Station", "Latitude", "Longitude", "Elevation"])
try:
    logger.info(f"載入 {site_info_file}...")
    site_info = pd.read_csv(site_info_file)
    # Keep one row per station (drop duplicate component rows).
    site_info = site_info.drop_duplicates(subset=['Station']).reset_index(drop=True)
    logger.info(f"{site_info_file} 載入完成,共 {len(site_info)} 個測站")
except FileNotFoundError:
    logger.warning(f"{site_info_file} 找不到")

# Preset earthquake events (display name -> mseed path).
EARTHQUAKE_EVENTS = {
    "0403花蓮地震 (2024)": "waveform/20240403.mseed",
}


# ============ Model definitions (copied from ttsam_realtime.py) ============
class LambdaLayer(nn.Module):
    """Wrap an arbitrary callable as a module; adds `eps` to its output."""

    def __init__(self, lambd, eps=1e-4):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd
        self.eps = eps

    def forward(self, x):
        return self.lambd(x) + self.eps


class MLP(nn.Module):
    """Simple multi-layer perceptron with configurable hidden dims."""

    def __init__(self, input_shape, dims=(500, 300, 200, 150),
                 activation=nn.ReLU(), last_activation=None):
        super(MLP, self).__init__()
        if last_activation is None:
            last_activation = activation
        self.dims = dims
        self.first_fc = nn.Linear(input_shape[0], dims[0])
        self.first_activation = activation
        more_hidden = []
        if len(self.dims) > 2:
            for i in range(1, len(self.dims) - 1):
                more_hidden.append(nn.Linear(self.dims[i - 1], self.dims[i]))
                more_hidden.append(nn.ReLU())
        self.more_hidden = nn.ModuleList(more_hidden)
        self.last_fc = nn.Linear(dims[-2], dims[-1])
        self.last_activation = last_activation

    def forward(self, x):
        output = self.first_fc(x)
        output = self.first_activation(output)
        if self.more_hidden:
            for layer in self.more_hidden:
                output = layer(output)
        output = self.last_fc(output)
        output = self.last_activation(output)
        return output


class CNN(nn.Module):
    """Waveform feature extractor: per-trace normalization, 2D/1D convs, MLP head.

    The input is first scaled by its per-station absolute max; the log of that
    max is re-appended as a scale feature before the final MLP.
    """

    def __init__(self, input_shape=(-1, 6000, 3), activation=nn.ReLU(),
                 downsample=1, mlp_input=11665, mlp_dims=(500, 300, 200, 150),
                 eps=1e-8):
        super(CNN, self).__init__()
        self.input_shape = input_shape
        self.activation = activation
        self.downsample = downsample
        self.mlp_input = mlp_input
        self.mlp_dims = mlp_dims
        self.eps = eps
        # Normalize each waveform by its max absolute amplitude.
        self.lambda_layer_1 = LambdaLayer(
            lambda t: t / (
                torch.max(torch.max(torch.abs(t), dim=1, keepdim=True).values,
                          dim=2, keepdim=True).values + self.eps)
        )
        self.unsqueeze_layer1 = LambdaLayer(lambda t: torch.unsqueeze(t, dim=1))
        # log-amplitude scale feature, divided by 100 to keep it small.
        self.lambda_layer_2 = LambdaLayer(
            lambda t: torch.log(
                torch.max(torch.max(torch.abs(t), dim=1).values,
                          dim=1).values + self.eps) / 100
        )
        self.unsqueeze_layer2 = LambdaLayer(lambda t: torch.unsqueeze(t, dim=1))
        self.conv2d1 = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=(1, downsample), stride=(1, downsample)),
            nn.ReLU())
        self.conv2d2 = nn.Sequential(
            nn.Conv2d(8, 32, kernel_size=(16, 3), stride=(1, 3)), nn.ReLU())
        self.conv1d1 = nn.Sequential(nn.Conv1d(32, 64, kernel_size=16), nn.ReLU())
        self.maxpooling = nn.MaxPool1d(2)
        self.conv1d2 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=16), nn.ReLU())
        self.conv1d3 = nn.Sequential(nn.Conv1d(128, 32, kernel_size=8), nn.ReLU())
        self.conv1d4 = nn.Sequential(nn.Conv1d(32, 32, kernel_size=8), nn.ReLU())
        self.conv1d5 = nn.Sequential(nn.Conv1d(32, 16, kernel_size=4), nn.ReLU())
        self.mlp = MLP((self.mlp_input,), dims=self.mlp_dims)

    def forward(self, x):
        output = self.lambda_layer_1(x)
        output = self.unsqueeze_layer1(output)
        scale = self.lambda_layer_2(x)
        scale = self.unsqueeze_layer2(scale)
        output = self.conv2d1(output)
        output = self.conv2d2(output)
        output = torch.squeeze(output, dim=-1)
        output = self.conv1d1(output)
        output = self.maxpooling(output)
        output = self.conv1d2(output)
        output = self.maxpooling(output)
        output = self.conv1d3(output)
        output = self.maxpooling(output)
        output = self.conv1d4(output)
        output = self.conv1d5(output)
        output = torch.flatten(output, start_dim=1)
        output = torch.cat((output, scale), dim=1)
        output = self.mlp(output)
        return output


class PositionEmbeddingVs30(nn.Module):
    """Sinusoidal embedding of (lat, lon, depth/elevation, Vs30).

    Each quantity gets sin/cos channels at geometrically-spaced wavelengths;
    `self.mask` interleaves the channels into a fixed emb_dim layout.
    """

    def __init__(self,
                 wavelengths=((5, 30), (110, 123), (0.01, 5000), (100, 1600)),
                 emb_dim=500):
        super(PositionEmbeddingVs30, self).__init__()
        self.wavelengths = wavelengths
        self.emb_dim = emb_dim
        min_lat, max_lat = wavelengths[0]
        min_lon, max_lon = wavelengths[1]
        min_depth, max_depth = wavelengths[2]
        min_vs30, max_vs30 = wavelengths[3]
        assert emb_dim % 10 == 0
        lat_dim = emb_dim // 5
        lon_dim = emb_dim // 5
        depth_dim = emb_dim // 10
        vs30_dim = emb_dim // 10
        self.lat_coeff = 2 * np.pi * 1.0 / min_lat * (
            (min_lat / max_lat) ** (np.arange(lat_dim) / lat_dim))
        self.lon_coeff = 2 * np.pi * 1.0 / min_lon * (
            (min_lon / max_lon) ** (np.arange(lon_dim) / lon_dim))
        self.depth_coeff = 2 * np.pi * 1.0 / min_depth * (
            (min_depth / max_depth) ** (np.arange(depth_dim) / depth_dim))
        self.vs30_coeff = 2 * np.pi * 1.0 / min_vs30 * (
            (min_vs30 / max_vs30) ** (np.arange(vs30_dim) / vs30_dim))
        # Interleaving masks: which output slots hold sin/cos of each quantity.
        lat_sin_mask = np.arange(emb_dim) % 5 == 0
        lat_cos_mask = np.arange(emb_dim) % 5 == 1
        lon_sin_mask = np.arange(emb_dim) % 5 == 2
        lon_cos_mask = np.arange(emb_dim) % 5 == 3
        depth_sin_mask = np.arange(emb_dim) % 10 == 4
        depth_cos_mask = np.arange(emb_dim) % 10 == 9
        vs30_sin_mask = np.arange(emb_dim) % 10 == 5
        vs30_cos_mask = np.arange(emb_dim) % 10 == 8
        self.mask = np.zeros(emb_dim)
        self.mask[lat_sin_mask] = np.arange(lat_dim)
        self.mask[lat_cos_mask] = lat_dim + np.arange(lat_dim)
        self.mask[lon_sin_mask] = 2 * lat_dim + np.arange(lon_dim)
        self.mask[lon_cos_mask] = 2 * lat_dim + lon_dim + np.arange(lon_dim)
        self.mask[depth_sin_mask] = (
            2 * lat_dim + 2 * lon_dim + np.arange(depth_dim))
        self.mask[depth_cos_mask] = (
            2 * lat_dim + 2 * lon_dim + depth_dim + np.arange(depth_dim))
        self.mask[vs30_sin_mask] = (
            2 * lat_dim + 2 * lon_dim + 2 * depth_dim + np.arange(vs30_dim))
        self.mask[vs30_cos_mask] = (
            2 * lat_dim + 2 * lon_dim + 2 * depth_dim + vs30_dim
            + np.arange(vs30_dim))
        self.mask = self.mask.astype("int32")

    def forward(self, x):
        lat_base = x[:, :, 0:1].to(device) * torch.Tensor(self.lat_coeff).to(device)
        lon_base = x[:, :, 1:2].to(device) * torch.Tensor(self.lon_coeff).to(device)
        depth_base = x[:, :, 2:3].to(device) * torch.Tensor(self.depth_coeff).to(device)
        vs30_base = x[:, :, 3:4] * torch.Tensor(self.vs30_coeff).to(device)
        output = torch.cat([
            torch.sin(lat_base), torch.cos(lat_base),
            torch.sin(lon_base), torch.cos(lon_base),
            torch.sin(depth_base), torch.cos(depth_base),
            torch.sin(vs30_base), torch.cos(vs30_base),
        ], dim=-1)
        # Re-order concatenated channels into the interleaved layout.
        maskk = torch.from_numpy(np.array(self.mask)).long()
        index = (maskk.unsqueeze(0).unsqueeze(0)).expand(
            x.shape[0], 1, self.emb_dim).to(device)
        output = torch.gather(output, -1, index).to(device)
        return output


class TransformerEncoder(nn.Module):
    """6-layer Transformer encoder over station + target embeddings."""

    def __init__(self, d_model=150, nhead=10, batch_first=True,
                 activation="gelu", dropout=0.0, dim_feedforward=1000):
        super(TransformerEncoder, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=batch_first,
            activation=activation, dropout=dropout,
            dim_feedforward=dim_feedforward
        ).to(device)
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer, 6).to(device)

    def forward(self, x, src_key_padding_mask=None):
        return self.transformer_encoder(
            x, src_key_padding_mask=src_key_padding_mask)


class MDN(nn.Module):
    """Mixture density head: returns (weight, sigma, mu) of n Gaussians."""

    def __init__(self, input_shape=(150,), n_hidden=20, n_gaussians=5):
        super(MDN, self).__init__()
        self.z_h = nn.Sequential(nn.Linear(input_shape[0], n_hidden), nn.Tanh())
        self.z_weight = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma = nn.Linear(n_hidden, n_gaussians)
        self.z_mu = nn.Linear(n_hidden, n_gaussians)

    def forward(self, x):
        z_h = self.z_h(x)
        weight = nn.functional.softmax(self.z_weight(z_h), -1)
        sigma = torch.exp(self.z_sigma(z_h))  # exp keeps sigma positive
        mu = self.z_mu(z_h)
        return weight, sigma, mu


class FullModel(nn.Module):
    """End-to-end TTSAM: CNN features + position embeddings -> Transformer -> MDN.

    `data` is a dict of tensors: "waveform" (batch, stations, length, 3),
    "station" (batch, stations, 4), "target" (batch, targets, 4).
    """

    def __init__(self, model_cnn, model_position, model_transformer, model_mlp,
                 model_mdn, max_station=25, pga_targets=15, emb_dim=150,
                 data_length=6000):
        super(FullModel, self).__init__()
        self.data_length = data_length
        self.model_CNN = model_cnn
        self.model_Position = model_position
        self.model_Transformer = model_transformer
        self.model_mlp = model_mlp
        self.model_MDN = model_mdn
        self.max_station = max_station
        self.pga_targets = pga_targets
        self.emb_dim = emb_dim

    def forward(self, data):
        cnn_output = self.model_CNN(
            torch.DoubleTensor(
                data["waveform"].reshape(-1, self.data_length, 3)
            ).float().to(device)
        )
        cnn_output_reshape = torch.reshape(
            cnn_output, (-1, self.max_station, self.emb_dim))
        emb_output = self.model_Position(
            torch.DoubleTensor(
                data["station"].reshape(-1, 1, data["station"].shape[2])
            ).float().to(device)
        )
        emb_output = emb_output.reshape(-1, self.max_station, self.emb_dim)
        # Stations whose metadata rows are all zero are padding.
        station_pad_mask = data["station"] == 0
        station_pad_mask = torch.all(station_pad_mask, 2)
        pga_pos_emb_output = self.model_Position(
            torch.DoubleTensor(
                data["target"].reshape(-1, 1, data["target"].shape[2])
            ).float().to(device)
        )
        pga_pos_emb_output = pga_pos_emb_output.reshape(
            -1, self.pga_targets, self.emb_dim)
        # NOTE(review): target_pad_mask is all-True, i.e. every target position
        # is marked as padding for self-attention. Kept as-is to match the
        # trained checkpoint's behavior — confirm against ttsam_realtime.py.
        target_pad_mask = torch.ones_like(data["target"], dtype=torch.bool)
        target_pad_mask = torch.all(target_pad_mask, 2)
        pad_mask = torch.cat((station_pad_mask, target_pad_mask), dim=1).to(device)
        add_pe_cnn_output = torch.add(cnn_output_reshape, emb_output)
        transformer_input = torch.cat(
            (add_pe_cnn_output, pga_pos_emb_output), dim=1)
        transformer_output = self.model_Transformer(transformer_input, pad_mask)
        # Only the trailing target positions feed the prediction head.
        mlp_input = transformer_output[:, -self.pga_targets:, :].to(device)
        mlp_output = self.model_mlp(mlp_input)
        weight, sigma, mu = self.model_MDN(mlp_output)
        return weight, sigma, mu


def get_full_model(model_path):
    """Build the TTSAM architecture and load trained weights from `model_path`."""
    emb_dim = 150
    mlp_dims = (150, 100, 50, 30, 10)
    cnn_model = CNN(mlp_input=5665).to(device)
    pos_emb_model = PositionEmbeddingVs30(emb_dim=emb_dim).to(device)
    transformer_model = TransformerEncoder()
    mlp_model = MLP(input_shape=(emb_dim,), dims=mlp_dims).to(device)
    mdn_model = MDN(input_shape=(mlp_dims[-1],)).to(device)
    full_model = FullModel(
        cnn_model, pos_emb_model, transformer_model, mlp_model, mdn_model,
        pga_targets=25, data_length=3000
    ).to(device)
    full_model.load_state_dict(
        torch.load(model_path, weights_only=True, map_location=device))
    return full_model


# Load the trained model from Hugging Face.
model_path = hf_hub_download(
    repo_id="SeisBlue/TTSAM",
    filename="ttsam_trained_model_11.pt"
)
model = get_full_model(model_path)


# ============ Helper functions ============
def lowpass(data, freq=10, df=100, corners=4):
    """Butterworth lowpass filter (SOS form) at `freq` Hz for sampling rate `df`."""
    fe = 0.5 * df
    f = freq / fe
    if f > 1:
        f = 1.0  # clamp to Nyquist
    z, p, k = iirfilter(corners, f, btype="lowpass", ftype="butter", output="zpk")
    sos = zpk2sos(z, p, k)
    return sosfilt(sos, data)


def signal_processing(waveform):
    """Demean then 10 Hz lowpass a single-component waveform."""
    data = detrend(waveform, type="constant")
    data = lowpass(data, freq=10)
    return data


def get_vs30(lat, lon, user_vs30=600):
    """Return Vs30 (m/s) for a location: nearest dataset value, else `user_vs30`."""
    if tree is None or vs30_table is None:
        # Vs30 dataset unavailable — fall back to the user-supplied value.
        logger.info(f"使用使用者輸入的 Vs30 值 ({user_vs30} m/s) for ({lat}, {lon})")
        return float(user_vs30)
    distance, i = tree.query([float(lat), float(lon)])
    vs30 = vs30_table.iloc[i]["Vs30"]
    logger.info(f"從資料庫查詢到 Vs30 值 ({vs30} m/s) for ({lat}, {lon})")
    return float(vs30)


def calculate_intensity(pga, label=False):
    """Map log10(PGA) to CWA intensity scale; return index or label string."""
    intensity_label = ["0", "1", "2", "3", "4", "5-", "5+", "6-", "6+", "7"]
    pga_level = np.log10(
        [1e-5, 0.008, 0.025, 0.080, 0.250, 0.80, 1.4, 2.5, 4.4, 8.0])
    pga_intensity = np.searchsorted(pga_level, pga) - 1
    intensity = pga_intensity
    if label:
        return intensity_label[intensity]
    else:
        return intensity


# ============ Gradio interface functions ============
def load_waveform(event_name):
    """Read the full mseed file (all stations) for the named event."""
    file_path = EARTHQUAKE_EVENTS[event_name]
    st = read(file_path)
    return st


def calculate_distance(lat1, lon1, lat2, lon2):
    """Simplified planar distance between two points, in degrees."""
    return np.sqrt((lat1 - lat2) ** 2 + (lon1 - lon2) ** 2)


def select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25):
    """Pick the `n_stations` stations in the stream nearest the epicenter."""
    station_distances = {}  # keyed by station code to dedupe components

    for tr in st:
        station_code = tr.stats.station
        # Skip stations already processed (other components of the same site).
        if station_code in station_distances:
            continue
        try:
            station_data = site_info[site_info["Station"] == station_code]
            if len(station_data) == 0:
                continue
            lat = station_data["Latitude"].values[0]
            lon = station_data["Longitude"].values[0]
            elev = station_data["Elevation"].values[0]
            distance = calculate_distance(epicenter_lat, epicenter_lon, lat, lon)
            station_distances[station_code] = {
                "station": station_code,
                "distance": distance,
                "latitude": lat,
                "longitude": lon,
                "elevation": elev,
            }
        except Exception as e:
            logger.warning(f"測站 {station_code} 資訊查詢失敗: {e}")
            continue

    station_list = list(station_distances.values())
    station_list.sort(key=lambda x: x["distance"])
    selected_stations = station_list[:n_stations]
    logger.info(
        f"從 {len(station_list)} 個輸入測站中選擇了最近的 {len(selected_stations)} 個")
    return selected_stations


def extract_waveforms_from_stream(st, selected_stations, start_time, end_time,
                                  vs30_input):
    """Extract 3-component, processed, zero-padded waveforms for each station.

    Returns (waveforms, station_info_list, valid_stations) where each waveform
    is a (3000, 3) array of [Z, N, E].
    """
    waveforms = []
    station_info_list = []
    valid_stations = []
    sampling_rate = 100  # assumed 100 Hz sampling
    start_idx = int(start_time * sampling_rate)
    end_idx = int(end_time * sampling_rate)
    target_length = 3000

    for station_data in selected_stations:
        station_code = station_data["station"]
        try:
            st_station = st.select(station=station_code)
            if len(st_station) == 0:
                continue
            # Z is mandatory; N/E fall back to 1/2 naming, then to a Z copy.
            z_trace = st_station.select(component="Z")
            n_trace = (st_station.select(component="N")
                       or st_station.select(component="1"))
            e_trace = (st_station.select(component="E")
                       or st_station.select(component="2"))
            if len(z_trace) > 0:
                z_data = z_trace[0].data[start_idx:end_idx]
            else:
                continue
            if len(n_trace) > 0:
                n_data = n_trace[0].data[start_idx:end_idx]
            else:
                n_data = z_data.copy()
            if len(e_trace) > 0:
                e_data = e_trace[0].data[start_idx:end_idx]
            else:
                e_data = z_data.copy()

            z_data = signal_processing(z_data)
            n_data = signal_processing(n_data)
            e_data = signal_processing(e_data)

            # Zero-pad / truncate to the model's fixed window length.
            waveform_3c = np.zeros((target_length, 3))
            z_len = min(len(z_data), target_length)
            n_len = min(len(n_data), target_length)
            e_len = min(len(e_data), target_length)
            waveform_3c[:z_len, 0] = z_data[:z_len]
            waveform_3c[:n_len, 1] = n_data[:n_len]
            waveform_3c[:e_len, 2] = e_data[:e_len]
            waveforms.append(waveform_3c)

            vs30 = get_vs30(station_data["latitude"],
                            station_data["longitude"], vs30_input)
            station_info_list.append([
                station_data["latitude"],
                station_data["longitude"],
                station_data["elevation"],
                vs30,
            ])
            valid_stations.append(station_data)
        except Exception as e:
            logger.warning(f"測站 {station_code} 波形提取失敗: {e}")
            continue

    logger.info(f"成功提取 {len(waveforms)} 個測站的波形")
    return waveforms, station_info_list, valid_stations


def plot_waveform(st, selected_stations, start_time, end_time):
    """Record-section plot (distance vs. time) of the selected stations."""
    fig, ax = plt.subplots(figsize=(14, 10))
    amplitude_scale = 0.03  # trace amplitude in degrees; tune to avoid overlap
    plotted_count = 0
    distances = []
    station_names = []

    for station_data in selected_stations:
        station_code = station_data["station"]
        distance = station_data["distance"]
        try:
            st_station = st.select(station=station_code)
            if len(st_station) > 0:
                tr = st_station[0]
                times = tr.times()
                data = tr.data
                # Normalize so every trace has comparable visual amplitude.
                data_normalized = data / (np.max(np.abs(data)) + 1e-10)
                ax.plot(times, distance + data_normalized * amplitude_scale,
                        'black', linewidth=0.3, alpha=0.8)
                distances.append(distance)
                station_names.append(station_code)
                plotted_count += 1
        except Exception as e:
            logger.warning(f"無法繪製測站 {station_code}: {e}")

    # Highlight the selected time window.
    ax.axvline(start_time, color='red', linestyle='--', linewidth=2,
               alpha=0.7, label='選取範圍')
    ax.axvline(end_time, color='red', linestyle='--', linewidth=2, alpha=0.7)
    ax.axvspan(start_time, end_time, alpha=0.15, color='blue')

    ax.set_xlabel('Time (s)', fontsize=12)
    ax.set_ylabel('Distance from Epicenter (°)', fontsize=12)
    ax.set_title(f'Record Section - {plotted_count} Stations Sorted by Distance',
                 fontsize=14, fontweight='bold')

    # Label station codes on a twin axis, thinned to avoid crowding.
    if distances:
        ax2 = ax.twinx()
        ax2.set_ylim(ax.get_ylim())
        ax2.set_ylabel('Station Code', fontsize=12)
        step = max(1, len(distances) // 10)
        ax2.set_yticks(distances[::step])
        ax2.set_yticklabels(station_names[::step], fontsize=8)

    ax.grid(True, alpha=0.3, axis='x')
    ax.legend(loc='upper right')
    plt.tight_layout()
    return fig


def get_intensity_color(intensity):
    """Color for an intensity index (palette from intensityMap.html)."""
    color_map = {
        0: "#ffffff",  # white
        1: "#33FFDD",  # cyan
        2: "#34ff32",  # green
        3: "#fefd32",  # yellow
        4: "#fe8532",  # orange
        5: "#fd5233",  # red-orange (5-)
        6: "#c43f3b",  # dark red (5+)
        7: "#9d4646",  # darker red (6-)
        8: "#9a4c86",  # purple-red (6+)
        9: "#b51fea",  # purple (7)
    }
    return color_map.get(intensity, "#ffffff")


def create_intensity_map(pga_list, target_names, epicenter_lat=None,
                         epicenter_lon=None):
    """Build an interactive Folium map of predicted intensities."""
    import folium
    from folium import plugins

    # NOTE(review): the HTML fragments below (popup, DivIcon, legend) were
    # mangled in the original source; markup is reconstructed — verify the
    # styling against the original UI.
    m = folium.Map(
        location=[23.5, 121],
        zoom_start=7,
        tiles='OpenStreetMap',
        width='100%',
        height='600px'  # fixed height to match the Ground Truth image
    )

    # `is not None` (not truthiness) so a 0.0 coordinate is still accepted.
    if epicenter_lat is not None and epicenter_lon is not None:
        folium.Marker(
            [epicenter_lat, epicenter_lon],
            popup=f'震央<br>({epicenter_lat:.3f}, {epicenter_lon:.3f})',
            icon=folium.Icon(color='red', icon='star', prefix='fa'),
            tooltip='震央位置'
        ).add_to(m)

    for i, target_name in enumerate(target_names):
        target = next(
            (t for t in target_dict if t["station"] == target_name), None)
        if target:
            lat = target["latitude"]
            lon = target["longitude"]
            intensity = calculate_intensity(pga_list[i])
            intensity_label = calculate_intensity(pga_list[i], label=True)
            color = get_intensity_color(intensity)
            pga = pga_list[i]

            popup_html = f"""
            <div style="font-family: sans-serif; font-size: 13px;">
                <b>{target_name}</b>
                <hr style="margin: 4px 0;">
                震度:{intensity_label}<br>
                PGA:{pga:.4f} m/s²<br>
                位置:({lat:.3f}, {lon:.3f})
            </div>
            """

            folium.CircleMarker(
                location=[lat, lon],
                radius=12,
                popup=folium.Popup(popup_html, max_width=250),
                tooltip=f'{target_name}: 震度 {intensity_label}',
                color='black',
                fillColor=color,
                fillOpacity=0.8,
                weight=2
            ).add_to(m)

            # Intensity label rendered at the circle center.
            folium.Marker(
                [lat, lon],
                icon=folium.DivIcon(html=f'''
                    <div style="font-size: 10px; font-weight: bold;
                                color: black; text-align: center;
                                width: 24px; margin-left: -12px;
                                margin-top: -6px;">
                        {intensity_label}
                    </div>
                ''')
            ).add_to(m)

    # Legend (fixed overlay, bottom-left).
    legend_html = '''
    <div style="position: fixed; bottom: 30px; left: 30px; z-index: 9999;
                background-color: white; border: 2px solid gray;
                border-radius: 6px; padding: 10px; font-size: 13px;">
        <b>震度等級 Intensity</b><br>
    '''
    intensity_levels = ["0", "1", "2", "3", "4", "5-", "5+", "6-", "6+", "7"]
    for idx, level in enumerate(intensity_levels):
        color = get_intensity_color(idx)
        legend_html += f'''
        <div style="margin: 2px 0;">
            <span style="display: inline-block; width: 14px; height: 14px;
                         background-color: {color}; border: 1px solid black;
                         margin-right: 6px; vertical-align: middle;"></span>
            {level}
        </div>
        '''
    legend_html += '''
    </div>
    '''
    m.get_root().html.add_child(folium.Element(legend_html))
    plugins.Fullscreen().add_to(m)
    return m


def load_ground_truth_image(event_name):
    """Return the path of the ground-truth intensity image for an event, or None."""
    import os
    # Images are assumed to be named after the event date, e.g. 20240403.png.
    event_file = EARTHQUAKE_EVENTS[event_name]
    event_date = os.path.basename(event_file).replace('.mseed', '')
    ground_truth_dir = "ground_truth"
    possible_extensions = ['.png', '.jpg', '.jpeg', '.gif']
    for ext in possible_extensions:
        image_path = os.path.join(ground_truth_dir, f"{event_date}{ext}")
        if os.path.exists(image_path):
            logger.info(f"載入 Ground Truth 圖片: {image_path}")
            return image_path
    logger.warning(f"找不到 Ground Truth 圖片: {event_date}")
    return None


def create_input_station_map(selected_stations, epicenter_lat, epicenter_lon):
    """Folium map of all input stations, highlighting the 25 selected ones."""
    import folium
    from folium import plugins

    # NOTE(review): popup/legend HTML below is reconstructed (original markup
    # was mangled) — verify styling against the original UI.
    m = folium.Map(
        location=[epicenter_lat, epicenter_lon],
        zoom_start=8,
        tiles='OpenStreetMap',
        width='100%',
        height='500px'
    )

    selected_station_codes = {s["station"] for s in selected_stations}

    # 1. All non-selected stations as small gray dots.
    logger.info(f"繪製所有測站 ({len(site_info)} 個)...")
    for idx, row in site_info.iterrows():
        station_code = row["Station"]
        lat = row["Latitude"]
        lon = row["Longitude"]
        if station_code in selected_station_codes:
            continue  # selected stations get their own styled marker below
        folium.CircleMarker(
            location=[lat, lon],
            radius=2,
            popup=f'{station_code}',
            tooltip=station_code,
            color='gray',
            fillColor='lightgray',
            fillOpacity=0.4,
            weight=1
        ).add_to(m)

    # 2. Epicenter marker (red star).
    folium.Marker(
        [epicenter_lat, epicenter_lon],
        popup=f'震央<br>({epicenter_lat:.3f}, {epicenter_lon:.3f})',
        icon=folium.Icon(color='red', icon='star', prefix='fa'),
        tooltip='震央位置',
        zIndexOffset=1000
    ).add_to(m)

    # 3. The selected stations, color-coded by closeness rank.
    for i, station_data in enumerate(selected_stations):
        station_code = station_data["station"]
        lat = station_data["latitude"]
        lon = station_data["longitude"]
        distance = station_data["distance"]

        popup_html = f"""
        <div style="font-family: sans-serif; font-size: 13px;">
            <b>{station_code}</b>
            <hr style="margin: 4px 0;">
            狀態:✓ 已選中<br>
            順序:第 {i + 1} 近<br>
            距離:{distance:.2f}°<br>
            位置:({lat:.3f}, {lon:.3f})
        </div>
        """

        if i < 5:
            color = 'green'
        elif i < 15:
            color = 'blue'
        else:
            color = 'orange'

        folium.CircleMarker(
            location=[lat, lon],
            radius=10,
            popup=folium.Popup(popup_html, max_width=250),
            tooltip=f'✓ {station_code} (第{i + 1}近)',
            color='black',
            fillColor=color,
            fillOpacity=0.8,
            weight=2,
            zIndexOffset=500
        ).add_to(m)

    # 4. Legend overlay.
    total_stations = len(site_info)
    legend_html = f'''
    <div style="position: fixed; bottom: 30px; left: 30px; z-index: 9999;
                background-color: white; border: 2px solid gray;
                border-radius: 6px; padding: 10px; font-size: 13px;">
        <b>測站分布</b><br>
        <span style="color: red;">★</span> 震央<br>
        <span style="color: gray;">●</span> 所有測站 ({total_stations} 個)<br>
        <hr style="margin: 4px 0;">
        被選中的測站:<br>
        <span style="color: green;">●</span> 前 5 近<br>
        <span style="color: blue;">●</span> 6-15 近<br>
        <span style="color: orange;">●</span> 16-25 近<br>
        <hr style="margin: 4px 0;">
        共選擇 {len(selected_stations)} 個測站
    </div>
    '''
    m.get_root().html.add_child(folium.Element(legend_html))

    # 5. Fullscreen control.
    plugins.Fullscreen().add_to(m)
    return m


def load_and_display_waveform(event_name, start_time, end_time,
                              epicenter_lon, epicenter_lat):
    """Load waveforms and show them so the user can confirm the time window.

    Returns (station_map_html, waveform_figure, info_text, predict_btn_update)
    — always 4 values, matching the Gradio output binding.
    """
    try:
        # 1. Read the full mseed file.
        logger.info(f"載入地震事件: {event_name}")
        st = load_waveform(event_name)
        logger.info(f"載入了 {len(st)} 個 trace")

        # 2. Pick the 25 stations nearest the epicenter.
        logger.info(f"選擇距離震央 ({epicenter_lat}, {epicenter_lon}) 最近的測站...")
        selected_stations = select_nearest_stations(
            st, epicenter_lat, epicenter_lon, n_stations=25)
        if len(selected_stations) == 0:
            # BUGFIX: this path previously returned only 3 values for 4 outputs.
            return (None, None, "錯誤:找不到有效的測站資料",
                    gr.update(interactive=False))

        # 3. Record-section plot.
        waveform_fig = plot_waveform(st, selected_stations, start_time, end_time)

        # 4. Input-station map.
        station_map = create_input_station_map(
            selected_stations, epicenter_lat, epicenter_lon)
        station_map_html = station_map._repr_html_()

        info_text = f"✅ 已載入波形資料\n"
        info_text += f"選取時間範圍: {start_time:.1f} - {end_time:.1f} 秒\n"
        info_text += f"震央位置: ({epicenter_lon:.4f}, {epicenter_lat:.4f})\n"
        info_text += f"選擇了 {len(selected_stations)} 個最近的測站\n"
        info_text += f"請確認波形範圍後,點擊「執行預測」按鈕"
        logger.info("波形載入完成")
        return (station_map_html, waveform_fig, info_text,
                gr.update(interactive=True))
    except Exception as e:
        logger.error(f"波形載入發生錯誤: {e}")
        import traceback
        traceback.print_exc()
        return None, None, f"錯誤: {str(e)}", gr.update(interactive=False)


def predict_intensity(event_name, start_time, end_time,
                      epicenter_lon, epicenter_lat):
    """Run the TTSAM prediction pipeline.

    Returns (ground_truth_path, map_html, stats_text) — always 3 values,
    matching the Gradio output binding.
    """
    try:
        # 1. Read the full mseed file.
        logger.info(f"載入地震事件: {event_name}")
        st = load_waveform(event_name)
        logger.info(f"載入了 {len(st)} 個 trace")

        # 2. Pick the 25 stations nearest the epicenter.
        logger.info(f"選擇距離震央 ({epicenter_lat}, {epicenter_lon}) 最近的測站...")
        selected_stations = select_nearest_stations(
            st, epicenter_lat, epicenter_lon, n_stations=25)
        if len(selected_stations) == 0:
            return None, None, "錯誤:找不到有效的測站資料"

        # 3. Extract waveforms (vs30_input=600 is only a fallback; database
        #    values take precedence inside get_vs30).
        logger.info(f"提取波形資料(時間範圍: {start_time}-{end_time} 秒)...")
        waveforms, station_info_list, valid_stations = (
            extract_waveforms_from_stream(
                st, selected_stations, start_time, end_time, vs30_input=600))
        if len(waveforms) == 0:
            # BUGFIX: this path previously returned only 2 values for 3 outputs.
            return None, None, "錯誤:無法提取波形資料"

        # 4. Pad to 25 stations (model requirement).
        max_stations = 25
        waveform_padded = np.zeros((max_stations, 3000, 3))
        station_info_padded = np.zeros((max_stations, 4))
        for i in range(min(len(waveforms), max_stations)):
            waveform_padded[i] = waveforms[i]
            station_info_padded[i] = station_info_list[i]

        # 5. Predict all target stations in batches of 25.
        all_pga_list = []
        all_target_names = []
        batch_size = 25
        total_targets = len(target_dict)
        num_batches = (total_targets + batch_size - 1) // batch_size
        logger.info(
            f"開始分批預測 {total_targets} 個目標測站(共 {num_batches} 批)...")

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, total_targets)
            batch_targets = target_dict[start_idx:end_idx]
            logger.info(
                f"預測第 {batch_idx + 1}/{num_batches} 批"
                f"(測站 {start_idx + 1}-{end_idx})...")

            target_list = []
            target_names = []
            for target in batch_targets:
                target_list.append([
                    target["latitude"],
                    target["longitude"],
                    target["elevation"],
                    get_vs30(target["latitude"], target["longitude"],
                             user_vs30=600),
                ])
                target_names.append(target["station"])

            # Pad the batch to 25 targets.
            target_padded = np.zeros((batch_size, 4))
            for i in range(len(target_list)):
                target_padded[i] = target_list[i]

            # 6. Assemble model input tensors (batch dim of 1).
            tensor_data = {
                "waveform": torch.tensor(waveform_padded).unsqueeze(0).double(),
                "station": torch.tensor(station_info_padded).unsqueeze(0).double(),
                "target": torch.tensor(target_padded).unsqueeze(0).double(),
            }

            # 7. Inference; MDN mixture mean = sum(weight * mu).
            with torch.no_grad():
                weight, sigma, mu = model(tensor_data)
            batch_pga = torch.sum(
                weight * mu, dim=2).cpu().detach().numpy().flatten().tolist()

            # Keep only the real (non-padded) targets of this batch.
            all_pga_list.extend(batch_pga[:len(target_names)])
            all_target_names.extend(target_names)

        logger.info(f"完成所有 {len(all_target_names)} 個測站的預測!")
        pga_list = all_pga_list
        target_names = all_target_names

        # 8. Interactive intensity map.
        intensity_map = create_intensity_map(
            pga_list, target_names, epicenter_lat, epicenter_lon)
        map_html = intensity_map._repr_html_()

        # 9. Ground truth image.
        ground_truth_path = load_ground_truth_image(event_name)

        # 10. Stats. BUGFIX: the max intensity was previously taken as max()
        # over label strings, which sorts lexically (e.g. "6+" > "7"); take
        # the label of the maximum PGA instead (intensity is monotone in PGA).
        max_intensity = calculate_intensity(max(pga_list), label=True)
        stats = f"✅ 預測完成!\n"
        stats += f"選取時間範圍: {start_time:.1f} - {end_time:.1f} 秒\n"
        stats += f"震央位置: ({epicenter_lon:.4f}, {epicenter_lat:.4f})\n"
        stats += f"使用測站數: {len(waveforms)} / 25\n"
        stats += f"預測最大震度: {max_intensity}"
        logger.info("預測完成!")
        return ground_truth_path, map_html, stats
    except Exception as e:
        logger.error(f"預測過程發生錯誤: {e}")
        import traceback
        traceback.print_exc()
        return None, None, f"錯誤: {str(e)}"


# ============ Gradio UI ============
with gr.Blocks(title="TTSAM 震度預測系統") as demo:
    gr.Markdown("# 🌏 TTSAM 震度預測系統")

    # ===== Top row: instructions/status + input parameters =====
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 使用步驟")
            gr.Markdown("""
            1. 選擇地震事件和時間範圍
            2. 輸入震央位置和場址參數
            3. 點擊「載入波形」確認波形範圍
            4. 確認無誤後,點擊「執行預測」

            ℹ️ 系統會自動選擇距離震央最近的 25 個測站
            """)
            info_output = gr.Textbox(label="狀態資訊", lines=6, interactive=False)
            stats_output = gr.Textbox(label="預測統計", lines=4, interactive=False)

        with gr.Column(scale=1):
            gr.Markdown("## 輸入參數")
            event_dropdown = gr.Dropdown(
                choices=list(EARTHQUAKE_EVENTS.keys()),
                value=list(EARTHQUAKE_EVENTS.keys())[0],
                label="選擇地震事件"
            )
            with gr.Row():
                start_slider = gr.Slider(0, 300, value=0, step=1,
                                         label="起始時間 (秒)")
                end_slider = gr.Slider(0, 300, value=30, step=1,
                                       label="結束時間 (秒)")
            gr.Markdown("### 震央位置")
            with gr.Row():
                epicenter_lon_input = gr.Number(value=121.57, label="震央經度")
                epicenter_lat_input = gr.Number(value=23.88, label="震央緯度")
            with gr.Row():
                load_waveform_btn = gr.Button("📊 載入波形", variant="secondary",
                                              scale=1)
                predict_btn = gr.Button("🔮 執行預測", variant="primary",
                                        scale=1, interactive=False)

    # ===== Middle row: input waveforms + input-station map =====
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 輸入波形")
            waveform_plot = gr.Plot(label="地震波形(選定的 25 個測站)")
        with gr.Column(scale=1):
            gr.Markdown("## 輸入測站分布")
            input_station_map = gr.HTML(label="輸入測站地圖")

    # ===== Bottom row: ground truth vs prediction =====
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Ground Truth 震度分布")
            ground_truth_image = gr.Image(label="實際觀測震度", type="filepath",
                                          height=600)
        with gr.Column(scale=1):
            gr.Markdown("## 預測震度分布")
            intensity_map = gr.HTML(label="互動式震度地圖",
                                    elem_id="intensity_map")

    # Step 1: load waveforms for confirmation.
    load_waveform_btn.click(
        fn=load_and_display_waveform,
        inputs=[event_dropdown, start_slider, end_slider,
                epicenter_lon_input, epicenter_lat_input],
        outputs=[input_station_map, waveform_plot, info_output, predict_btn]
    )

    # Step 2: run the prediction.
    predict_btn.click(
        fn=predict_intensity,
        inputs=[event_dropdown, start_slider, end_slider,
                epicenter_lon_input, epicenter_lat_input],
        outputs=[ground_truth_image, intensity_map, stats_output]
    )

demo.launch()