Spaces:
Running
Running
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from obspy import read | |
| import xarray as xr | |
| import torch | |
| import torch.nn as nn | |
| from scipy.signal import detrend, iirfilter, sosfilt, zpk2sos | |
| from scipy.spatial import cKDTree | |
| import pandas as pd | |
| from loguru import logger | |
# Font fallback so Chinese labels can be drawn ("Arial Unicode MS" first, then
# "DejaVu Sans"), and an ASCII hyphen for minus signs to avoid minus-sign
# display problems with these fonts.
plt.rcParams.update({
    "font.sans-serif": ["Arial Unicode MS", "DejaVu Sans"],
    "axes.unicode_minus": False,
})
# Run on the GPU when CUDA is available, otherwise on the CPU. The model
# classes below move their modules/tensors to this global `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info("使用 GPU" if device.type == "cuda" else "使用 CPU")
# Load the Taiwan Vs30 grid (downloaded from Hugging Face) and index it with a
# KD-tree for nearest-neighbour lookups in get_vs30(). If anything fails, both
# `tree` and `vs30_table` stay None and get_vs30() falls back to the
# caller-supplied Vs30 value.
from huggingface_hub import hf_hub_download

tree = None
vs30_table = None
try:
    logger.info("從 Hugging Face 載入 Vs30 資料...")
    vs30_file = hf_hub_download(
        repo_id="SeisBlue/TaiwanVs30",
        filename="Vs30ofTaiwan.nc",
        repo_type="dataset",
    )
    # Open the dataset as a context manager so the NetCDF file handle is
    # released once the arrays have been copied out with `.values`.
    with xr.open_dataset(vs30_file) as ds:
        lat_flat = ds["lat"].values.flatten()
        lon_flat = ds["lon"].values.flatten()
        vs30_flat = ds["vs30"].values.flatten()
    table = pd.DataFrame({"lat": lat_flat, "lon": lon_flat, "Vs30": vs30_flat})
    # Drop cells with non-finite values so the KD-tree only returns usable Vs30.
    table = table.replace([np.inf, -np.inf], np.nan).dropna()
    kd_tree = cKDTree(table[["lat", "lon"]].to_numpy())
    # Publish both globals together: either both are usable or both stay None.
    vs30_table = table
    tree = kd_tree
    logger.info("Vs30 資料載入完成")
except Exception as e:
    logger.warning(f"Vs30 資料載入失敗: {e}")
    logger.warning("將使用預設 Vs30 值 (600 m/s)")
# Load the target stations (where intensity is predicted) from a local CSV.
# Each record is used with the keys "station", "latitude", "longitude" and
# "elevation" by create_intensity_map / predict_intensity.
target_file = "station/eew_target.csv"
# Defaults so a missing CSV yields "no targets" instead of a NameError
# wherever target_dict / target_df are used later.
target_df = pd.DataFrame()
target_dict = []
try:
    logger.info(f"載入 {target_file}...")
    target_df = pd.read_csv(target_file)
    target_dict = target_df.to_dict(orient="records")
    logger.info(f"{target_file} 載入完成")
except FileNotFoundError:
    logger.error(f"{target_file} 找不到")
# Load input-station metadata (1000+ stations). The rest of the app reads the
# "Station", "Latitude", "Longitude" and "Elevation" columns.
site_info_file = "station/site_info.csv"
# Empty fallback with those columns so a missing CSV degrades to "no input
# stations" instead of a NameError wherever site_info is used.
site_info = pd.DataFrame(columns=["Station", "Latitude", "Longitude", "Elevation"])
try:
    logger.info(f"載入 {site_info_file}...")
    site_info = pd.read_csv(site_info_file)
    # Keep a single row per station (drop rows duplicated per component).
    site_info = site_info.drop_duplicates(subset=['Station']).reset_index(drop=True)
    logger.info(f"{site_info_file} 載入完成,共 {len(site_info)} 個測站")
except FileNotFoundError:
    logger.warning(f"{site_info_file} 找不到")
# Preset earthquake events: display name -> path of the event's MiniSEED file.
EARTHQUAKE_EVENTS = dict([
    ("0403花蓮地震 (2024)", "waveform/20240403.mseed"),
])
# ============ Model definitions (copied from ttsam_realtime.py) ============
class LambdaLayer(nn.Module):
    """Wrap a callable as a module and add a constant to its result.

    Args:
        lambd: callable applied to the input in ``forward``.
        eps: constant added to the callable's output (default ``1e-4``).
    """

    def __init__(self, lambd, eps=1e-4):
        super().__init__()
        self.lambd = lambd
        self.eps = eps

    def forward(self, x):
        transformed = self.lambd(x)
        return transformed + self.eps
class MLP(nn.Module):
    """Fully connected stack mapping ``input_shape[0] -> dims[0] -> ... -> dims[-1]``.

    Args:
        input_shape: sequence whose first element is the input feature size.
        dims: layer widths; needs at least two entries because the final
            layer maps ``dims[-2] -> dims[-1]``.
        activation: activation applied after the first layer. ``None`` (the
            default) creates a fresh ``nn.ReLU()`` for this instance.
        last_activation: activation applied after the final layer; defaults
            to ``activation``.
    """

    def __init__(self, input_shape, dims=(500, 300, 200, 150), activation=None,
                 last_activation=None):
        super().__init__()
        # Build the default per instance rather than sharing one module object
        # created once at function-definition time across every MLP.
        if activation is None:
            activation = nn.ReLU()
        if last_activation is None:
            last_activation = activation
        self.dims = dims
        self.first_fc = nn.Linear(input_shape[0], dims[0])
        self.first_activation = activation
        # Intermediate layers dims[i-1] -> dims[i] for i in 1..len(dims)-2
        # (none when len(dims) <= 2). NOTE: these always use ReLU regardless of
        # `activation`, as in the original model code. The (Linear, ReLU, ...)
        # order fixes the state-dict keys more_hidden.0, more_hidden.2, ...
        # that the loaded checkpoint uses.
        more_hidden = []
        for i in range(1, len(self.dims) - 1):
            more_hidden.append(nn.Linear(self.dims[i - 1], self.dims[i]))
            more_hidden.append(nn.ReLU())
        self.more_hidden = nn.ModuleList(more_hidden)
        self.last_fc = nn.Linear(dims[-2], dims[-1])
        self.last_activation = last_activation

    def forward(self, x):
        output = self.first_activation(self.first_fc(x))
        for layer in self.more_hidden:
            output = layer(output)
        return self.last_activation(self.last_fc(output))
class CNN(nn.Module):
    """Waveform feature extractor used by FullModel.

    Each input sample is peak-normalised, passed through 2-D then 1-D
    convolutions, flattened, concatenated with a log-peak-amplitude scale
    feature and fed to an MLP head.

    Args:
        input_shape: nominal input shape; stored only, not used by the layers.
        activation: stored as ``self.activation``; the conv blocks hard-code
            ``nn.ReLU``. ``None`` (the default) creates a fresh ``nn.ReLU()``
            for this instance.
        downsample: kernel size and stride of ``conv2d1`` along the time axis.
        mlp_input: input width of the MLP head. It must equal the flattened
            conv output plus one scale feature. For 3000-sample windows with
            ``downsample=1`` that is 16 * 354 + 1 = 5665.
        mlp_dims: layer widths of the MLP head.
        eps: added to the normalisation denominator and to the log argument.
    """

    def __init__(self, input_shape=(-1, 6000, 3), activation=None, downsample=1,
                 mlp_input=11665, mlp_dims=(500, 300, 200, 150), eps=1e-8):
        super().__init__()
        # Per-instance default instead of one module shared by every CNN.
        if activation is None:
            activation = nn.ReLU()
        self.input_shape = input_shape
        self.activation = activation
        self.downsample = downsample
        self.mlp_input = mlp_input
        self.mlp_dims = mlp_dims
        self.eps = eps
        # Divide each sample by its peak |amplitude| over time (dim 1) and
        # components (dim 2). LambdaLayer then adds its own eps (1e-4).
        self.lambda_layer_1 = LambdaLayer(
            lambda t: t / (
                torch.max(torch.max(torch.abs(t), dim=1, keepdim=True).values,
                          dim=2, keepdim=True).values + self.eps)
        )
        # Add a channel axis: (batch, time, 3) -> (batch, 1, time, 3).
        # NOTE: LambdaLayer's default eps (1e-4) is added here as well.
        self.unsqueeze_layer1 = LambdaLayer(lambda t: torch.unsqueeze(t, dim=1))
        # log(peak |amplitude|) / 100 per sample. It is concatenated to the conv
        # features in forward(), so the head still sees absolute amplitude.
        self.lambda_layer_2 = LambdaLayer(
            lambda t: torch.log(torch.max(torch.max(torch.abs(t), dim=1).values,
                                          dim=1).values + self.eps) / 100
        )
        self.unsqueeze_layer2 = LambdaLayer(lambda t: torch.unsqueeze(t, dim=1))
        self.conv2d1 = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=(1, downsample), stride=(1, downsample)),
            nn.ReLU())
        # Kernel width 3 with stride 3 collapses the 3 components to width 1.
        self.conv2d2 = nn.Sequential(
            nn.Conv2d(8, 32, kernel_size=(16, 3), stride=(1, 3)), nn.ReLU())
        self.conv1d1 = nn.Sequential(nn.Conv1d(32, 64, kernel_size=16), nn.ReLU())
        self.maxpooling = nn.MaxPool1d(2)
        self.conv1d2 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=16), nn.ReLU())
        self.conv1d3 = nn.Sequential(nn.Conv1d(128, 32, kernel_size=8), nn.ReLU())
        self.conv1d4 = nn.Sequential(nn.Conv1d(32, 32, kernel_size=8), nn.ReLU())
        self.conv1d5 = nn.Sequential(nn.Conv1d(32, 16, kernel_size=4), nn.ReLU())
        self.mlp = MLP((self.mlp_input,), dims=self.mlp_dims)

    def forward(self, x):
        # x: (batch, time, 3) -- FullModel reshapes to (-1, data_length, 3).
        output = self.lambda_layer_1(x)
        output = self.unsqueeze_layer1(output)
        scale = self.lambda_layer_2(x)          # (batch,)
        scale = self.unsqueeze_layer2(scale)    # (batch, 1)
        output = self.conv2d1(output)
        output = self.conv2d2(output)           # (batch, 32, time - 15, 1)
        output = torch.squeeze(output, dim=-1)  # (batch, 32, time - 15)
        output = self.conv1d1(output)
        output = self.maxpooling(output)
        output = self.conv1d2(output)
        output = self.maxpooling(output)
        output = self.conv1d3(output)
        output = self.maxpooling(output)
        output = self.conv1d4(output)
        output = self.conv1d5(output)
        output = torch.flatten(output, start_dim=1)
        output = torch.cat((output, scale), dim=1)
        output = self.mlp(output)
        return output
class PositionEmbeddingVs30(nn.Module):
    """Sinusoidal embedding of a 4-value (lat, lon, depth, vs30) vector.

    Each input channel is multiplied by a geometric sequence of angular
    frequencies running from 2*pi/min towards 2*pi/max of that channel's
    ``wavelengths`` range. The sin/cos features are then re-ordered into an
    ``emb_dim``-wide vector using the index table ``self.mask``.

    In this file the callers pass [latitude, longitude, elevation, vs30], so
    the "depth" channel presumably receives elevation -- TODO confirm units.

    Args:
        wavelengths: ((min, max), ...) wavelength range for lat, lon, depth, vs30.
        emb_dim: output width; must be a multiple of 10.
    """

    def __init__(self, wavelengths=((5, 30), (110, 123), (0.01, 5000), (100, 1600)),
                 emb_dim=500):
        super(PositionEmbeddingVs30, self).__init__()
        self.wavelengths = wavelengths
        self.emb_dim = emb_dim
        min_lat, max_lat = wavelengths[0]
        min_lon, max_lon = wavelengths[1]
        min_depth, max_depth = wavelengths[2]
        min_vs30, max_vs30 = wavelengths[3]
        assert emb_dim % 10 == 0
        # Number of frequencies per channel; each gets a sin and a cos feature.
        lat_dim = emb_dim // 5
        lon_dim = emb_dim // 5
        depth_dim = emb_dim // 10
        vs30_dim = emb_dim // 10
        # Angular frequencies: 2*pi/min * (min/max)**(k/dim), k = 0..dim-1.
        self.lat_coeff = 2 * np.pi * 1.0 / min_lat * (
            (min_lat / max_lat) ** (np.arange(lat_dim) / lat_dim))
        self.lon_coeff = 2 * np.pi * 1.0 / min_lon * (
            (min_lon / max_lon) ** (np.arange(lon_dim) / lon_dim))
        self.depth_coeff = 2 * np.pi * 1.0 / min_depth * (
            (min_depth / max_depth) ** (np.arange(depth_dim) / depth_dim))
        self.vs30_coeff = 2 * np.pi * 1.0 / min_vs30 * (
            (min_vs30 / max_vs30) ** (np.arange(vs30_dim) / vs30_dim))
        # Output position p is assigned: lat sin (p%5==0), lat cos (p%5==1),
        # lon sin (p%5==2), lon cos (p%5==3). The p%5==4 slots are split into
        # depth sin (p%10==4) and depth cos (p%10==9).
        lat_sin_mask = np.arange(emb_dim) % 5 == 0
        lat_cos_mask = np.arange(emb_dim) % 5 == 1
        lon_sin_mask = np.arange(emb_dim) % 5 == 2
        lon_cos_mask = np.arange(emb_dim) % 5 == 3
        depth_sin_mask = np.arange(emb_dim) % 10 == 4
        depth_cos_mask = np.arange(emb_dim) % 10 == 9
        # NOTE(review): the vs30 masks (p%10==5 and p%10==8) overlap the lat-sin
        # (p%5==0) and lon-cos (p%5==3) slots. They are assigned last below, so
        # vs30 overwrites those slots and only every other lat-sin / lon-cos
        # frequency reaches the output. The loaded checkpoint was presumably
        # trained with exactly this layout -- do not change it without retraining.
        vs30_sin_mask = np.arange(emb_dim) % 10 == 5
        vs30_cos_mask = np.arange(emb_dim) % 10 == 8
        # self.mask[p] = index into the concatenated tensor built in forward():
        # [sin lat | cos lat | sin lon | cos lon | sin depth | cos depth |
        #  sin vs30 | cos vs30], which holds 6/5 * emb_dim features in total.
        self.mask = np.zeros(emb_dim)
        self.mask[lat_sin_mask] = np.arange(lat_dim)
        self.mask[lat_cos_mask] = lat_dim + np.arange(lat_dim)
        self.mask[lon_sin_mask] = 2 * lat_dim + np.arange(lon_dim)
        self.mask[lon_cos_mask] = 2 * lat_dim + lon_dim + np.arange(lon_dim)
        self.mask[depth_sin_mask] = 2 * lat_dim + 2 * lon_dim + np.arange(depth_dim)
        self.mask[depth_cos_mask] = 2 * lat_dim + 2 * lon_dim + depth_dim + np.arange(
            depth_dim)
        self.mask[
            vs30_sin_mask] = 2 * lat_dim + 2 * lon_dim + 2 * depth_dim + np.arange(
            vs30_dim)
        self.mask[
            vs30_cos_mask] = 2 * lat_dim + 2 * lon_dim + 2 * depth_dim + vs30_dim + np.arange(
            vs30_dim)
        self.mask = self.mask.astype("int32")

    def forward(self, x):
        # x: (N, 1, 4) as built by FullModel via reshape(-1, 1, ...). The gather
        # index below is expanded to (N, 1, emb_dim), so this assumes x.shape[1] == 1.
        lat_base = x[:, :, 0:1].to(device) * torch.Tensor(self.lat_coeff).to(device)
        lon_base = x[:, :, 1:2].to(device) * torch.Tensor(self.lon_coeff).to(device)
        depth_base = x[:, :, 2:3].to(device) * torch.Tensor(self.depth_coeff).to(device)
        # Unlike the lines above, x is not moved here; FullModel already passes
        # x on `device`.
        vs30_base = x[:, :, 3:4] * torch.Tensor(self.vs30_coeff).to(device)
        output = torch.cat([
            torch.sin(lat_base), torch.cos(lat_base),
            torch.sin(lon_base), torch.cos(lon_base),
            torch.sin(depth_base), torch.cos(depth_base),
            torch.sin(vs30_base), torch.cos(vs30_base),
        ], dim=-1)
        # Select/reorder the emb_dim features given by self.mask.
        maskk = torch.from_numpy(np.array(self.mask)).long()
        index = (maskk.unsqueeze(0).unsqueeze(0)).expand(x.shape[0], 1,
                                                         self.emb_dim).to(device)
        output = torch.gather(output, -1, index).to(device)
        return output
class TransformerEncoder(nn.Module):
    """Stack of six ``nn.TransformerEncoderLayer`` blocks on the global ``device``.

    Constructor arguments are forwarded to ``nn.TransformerEncoderLayer``.
    ``forward(x, src_key_padding_mask)`` runs the stack, passing the padding
    mask through unchanged.
    """

    def __init__(self, d_model=150, nhead=10, batch_first=True, activation="gelu",
                 dropout=0.0, dim_feedforward=1000):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            batch_first=batch_first,
            activation=activation,
            dropout=dropout,
            dim_feedforward=dim_feedforward,
        )
        self.encoder_layer = layer.to(device)
        stack = nn.TransformerEncoder(self.encoder_layer, 6)
        self.transformer_encoder = stack.to(device)

    def forward(self, x, src_key_padding_mask=None):
        encoded = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)
        return encoded
class MDN(nn.Module):
    """Mixture-density head: a tanh hidden layer followed by three linear heads.

    Args:
        input_shape: sequence whose first element is the input feature size.
        n_hidden: width of the tanh hidden layer.
        n_gaussians: number of mixture components.

    ``forward`` returns ``(weight, sigma, mu)``, each with last dimension
    ``n_gaussians``. ``weight`` is softmax-normalised over that dimension and
    ``sigma`` is made positive with ``exp``.
    """

    def __init__(self, input_shape=(150,), n_hidden=20, n_gaussians=5):
        super().__init__()
        in_features = input_shape[0]
        self.z_h = nn.Sequential(nn.Linear(in_features, n_hidden), nn.Tanh())
        self.z_weight = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma = nn.Linear(n_hidden, n_gaussians)
        self.z_mu = nn.Linear(n_hidden, n_gaussians)

    def forward(self, x):
        hidden = self.z_h(x)
        weight = torch.softmax(self.z_weight(hidden), dim=-1)
        sigma = self.z_sigma(hidden).exp()
        mu = self.z_mu(hidden)
        return weight, sigma, mu
class FullModel(nn.Module):
    """TTSAM network: CNN + position embedding -> Transformer -> MLP -> MDN.

    Station tokens are CNN waveform features plus the station position
    embedding. Target tokens are the target position embedding alone. The
    Transformer outputs at the target positions go through the MLP and MDN
    heads.

    ``forward(data)`` reads a dict with:
      - "waveform": reshaped to (-1, data_length, 3), one row per input
        station (predict_intensity passes shape (1, 25, 3000, 3));
      - "station": 3-D (batch, max_station, features) input-station vectors;
        rows that are entirely zero are treated as padding;
      - "target": 3-D (batch, pga_targets, features) target-station vectors.
    It returns the MDN outputs ``(weight, sigma, mu)`` for the target positions.
    """

    def __init__(self, model_cnn, model_position, model_transformer, model_mlp,
                 model_mdn,
                 max_station=25, pga_targets=15, emb_dim=150, data_length=6000):
        super(FullModel, self).__init__()
        self.data_length = data_length
        self.model_CNN = model_cnn
        self.model_Position = model_position
        self.model_Transformer = model_transformer
        self.model_mlp = model_mlp
        self.model_MDN = model_mdn
        self.max_station = max_station
        self.pga_targets = pga_targets
        self.emb_dim = emb_dim

    def forward(self, data):
        # CNN features per station trace, regrouped to (batch, max_station, emb_dim).
        cnn_output = self.model_CNN(
            torch.DoubleTensor(
                data["waveform"].reshape(-1, self.data_length, 3)).float().to(device)
        )
        cnn_output_reshape = torch.reshape(cnn_output,
                                           (-1, self.max_station, self.emb_dim))
        emb_output = self.model_Position(
            torch.DoubleTensor(
                data["station"].reshape(-1, 1, data["station"].shape[2])).float().to(
                device)
        )
        emb_output = emb_output.reshape(-1, self.max_station, self.emb_dim)
        # True for padded (all-zero) station rows.
        station_pad_mask = data["station"] == 0
        station_pad_mask = torch.all(station_pad_mask, 2)
        pga_pos_emb_output = self.model_Position(
            torch.DoubleTensor(
                data["target"].reshape(-1, 1, data["target"].shape[2])).float().to(
                device)
        )
        pga_pos_emb_output = pga_pos_emb_output.reshape(-1, self.pga_targets,
                                                        self.emb_dim)
        # Every target token is flagged True. As src_key_padding_mask, True means
        # "ignore this key", so no token attends to target tokens: targets only
        # read from the non-padded stations. Presumably intentional (targets
        # act as queries only) -- the checkpoint was trained this way.
        target_pad_mask = torch.ones_like(data["target"], dtype=torch.bool)
        target_pad_mask = torch.all(target_pad_mask, 2)
        pad_mask = torch.cat((station_pad_mask, target_pad_mask), dim=1).to(device)
        # Token sequence: [stations (features + position) | targets (position only)].
        add_pe_cnn_output = torch.add(cnn_output_reshape, emb_output)
        transformer_input = torch.cat((add_pe_cnn_output, pga_pos_emb_output), dim=1)
        transformer_output = self.model_Transformer(transformer_input, pad_mask)
        # Keep only the outputs at the target positions (the last pga_targets tokens).
        mlp_input = transformer_output[:, -self.pga_targets:, :].to(device)
        mlp_output = self.model_mlp(mlp_input)
        weight, sigma, mu = self.model_MDN(mlp_output)
        return weight, sigma, mu
def get_full_model(model_path):
    """Assemble the TTSAM network and load the weights stored at ``model_path``.

    The architecture is fixed here:
      - 150-d embeddings;
      - CNN head input 5665 (3000-sample windows);
      - MLP widths (150, 100, 50, 30, 10);
      - 25 target slots per forward pass.
    The checkpoint is loaded with ``weights_only=True`` onto the global
    ``device``.

    Returns:
        The ``FullModel`` with its state dict loaded.
    """
    emb_dim = 150
    mlp_dims = (150, 100, 50, 30, 10)
    parts = {
        "model_cnn": CNN(mlp_input=5665).to(device),
        "model_position": PositionEmbeddingVs30(emb_dim=emb_dim).to(device),
        "model_transformer": TransformerEncoder(),
        "model_mlp": MLP(input_shape=(emb_dim,), dims=mlp_dims).to(device),
        "model_mdn": MDN(input_shape=(mlp_dims[-1],)).to(device),
    }
    full_model = FullModel(**parts, pga_targets=25, data_length=3000).to(device)
    state_dict = torch.load(model_path, weights_only=True, map_location=device)
    full_model.load_state_dict(state_dict)
    return full_model
# Download the trained TTSAM checkpoint from Hugging Face and build the model
# once at import time.
# NOTE(review): the model is never switched to eval(). All dropout here is 0.0,
# so training mode does not change the math. In eval mode, however,
# nn.TransformerEncoder may take its "fast path", which treats
# src_key_padding_mask=True positions as padding in its output -- and
# FullModel flags every target token that way. Verify before adding
# model.eval().
model_path = hf_hub_download(
    repo_id="SeisBlue/TTSAM",
    filename="ttsam_trained_model_11.pt",
)
model = get_full_model(model_path)
# ============ Helper functions ============
def lowpass(data, freq=10, df=100, corners=4):
    """Apply a causal Butterworth low-pass filter (second-order sections).

    Args:
        data: samples to filter (``sosfilt`` filters along the last axis).
        freq: corner frequency in Hz.
        df: sampling rate in Hz.
        corners: filter order.

    Returns:
        The filtered samples as a float array. When ``freq`` is at or above
        the Nyquist frequency (``df / 2``), a float copy of ``data`` is
        returned unfiltered, since every representable frequency is already
        in the pass band.
    """
    fe = 0.5 * df
    f = freq / fe
    # scipy requires digital critical frequencies strictly inside (0, 1).
    # Clamping f to 1.0 (as before) made iirfilter raise ValueError instead of
    # falling back to "no filtering".
    if f >= 1:
        return np.array(data, dtype=float)
    z, p, k = iirfilter(corners, f, btype="lowpass", ftype="butter", output="zpk")
    sos = zpk2sos(z, p, k)
    return sosfilt(sos, data)
def signal_processing(waveform):
    """Remove the mean of ``waveform``, then low-pass it at 10 Hz via ``lowpass``."""
    demeaned = detrend(waveform, type="constant")
    return lowpass(demeaned, freq=10)
def get_vs30(lat, lon, user_vs30=600):
    """Return the Vs30 (m/s) for a point, as a float.

    When the Vs30 grid was loaded (``tree`` and ``vs30_table`` are both set),
    the value of the nearest grid point in (lat, lon) space is returned.
    Otherwise ``user_vs30`` is returned. Both paths log the value used.
    """
    if tree is not None and vs30_table is not None:
        _, nearest = tree.query([float(lat), float(lon)])
        vs30 = vs30_table.iloc[nearest]["Vs30"]
        logger.info(f"從資料庫查詢到 Vs30 值 ({vs30} m/s) for ({lat}, {lon})")
        return float(vs30)
    # Vs30 grid unavailable: fall back to the caller-supplied value.
    logger.info(f"使用使用者輸入的 Vs30 值 ({user_vs30} m/s) for ({lat}, {lon})")
    return float(user_vs30)
def calculate_intensity(pga, label=False):
    """Map a log10 PGA value to an intensity level.

    The 10-level scale has labels "0".."7" with 5-/5+/6-/6+, as used in
    Taiwan. Its lower thresholds, in m/s², are 0.008, 0.025, 0.080, 0.250,
    0.80, 1.4, 2.5, 4.4 and 8.0.

    Args:
        pga: log10 of peak ground acceleration in m/s² (scalar).
        label: when True, return the level's label string instead of its index.

    Returns:
        An int in 0..9 indexing "0", "1", "2", "3", "4", "5-", "5+", "6-",
        "6+", "7", or that label when ``label`` is True.
    """
    intensity_label = ["0", "1", "2", "3", "4", "5-", "5+", "6-", "6+", "7"]
    pga_level = np.log10([1e-5, 0.008, 0.025, 0.080, 0.250, 0.80, 1.4, 2.5, 4.4, 8.0])
    # side="right": a value exactly on a threshold belongs to the level it
    # starts (thresholds are lower bounds). Clamp at 0 so values below the
    # first threshold give level "0" instead of index -1, which would wrap
    # around to label "7".
    intensity = max(int(np.searchsorted(pga_level, pga, side="right")) - 1, 0)
    if label:
        return intensity_label[intensity]
    return intensity
# ============ Gradio interface functions ============
def load_waveform(event_name):
    """Read the complete MiniSEED file (all stations) for a preset event.

    Args:
        event_name: key of ``EARTHQUAKE_EVENTS``; an unknown name raises KeyError.

    Returns:
        The stream returned by obspy ``read``.
    """
    return read(EARTHQUAKE_EVENTS[event_name])
def calculate_distance(lat1, lon1, lat2, lon2):
    """Planar distance between two points, in degrees.

    This is a simplification with no Earth-curvature or longitude scaling.
    """
    dlat = lat1 - lat2
    dlon = lon1 - lon2
    return np.sqrt(dlat**2 + dlon**2)
def select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25):
    """Select the ``n_stations`` stations of ``st`` nearest to the epicenter.

    Coordinates come from the global ``site_info`` table (1000+ input
    stations). Stations in ``st`` that are not listed there are ignored, and a
    station with several component traces is counted once.

    Returns:
        Up to ``n_stations`` dicts with keys "station", "distance" (planar,
        in degrees -- see ``calculate_distance``), "latitude", "longitude" and
        "elevation", ordered by increasing distance (ties keep stream order).
    """
    # Index site_info by station code once (first row per code), so each trace
    # is an O(1) lookup instead of a boolean-mask scan of the whole table.
    try:
        station_coords = {}
        for code, lat, lon, elev in zip(site_info["Station"], site_info["Latitude"],
                                        site_info["Longitude"], site_info["Elevation"]):
            station_coords.setdefault(code, (lat, lon, elev))
    except Exception as e:
        # e.g. site_info is missing or lacks a column: no station can be located.
        logger.warning(f"測站資訊查詢失敗: {e}")
        return []

    station_distances = {}  # keyed by station code so extra components are skipped
    for tr in st:
        station_code = tr.stats.station
        if station_code in station_distances or station_code not in station_coords:
            continue
        lat, lon, elev = station_coords[station_code]
        try:
            distance = calculate_distance(epicenter_lat, epicenter_lon, lat, lon)
        except Exception as e:
            logger.warning(f"測站 {station_code} 資訊查詢失敗: {e}")
            continue
        station_distances[station_code] = {
            "station": station_code,
            "distance": distance,
            "latitude": lat,
            "longitude": lon,
            "elevation": elev,
        }

    # Stable sort, then keep the n closest.
    station_list = sorted(station_distances.values(), key=lambda s: s["distance"])
    selected_stations = station_list[:n_stations]
    logger.info(f"從 {len(station_list)} 個輸入測站中選擇了最近的 {len(selected_stations)} 個")
    return selected_stations
def extract_waveforms_from_stream(st, selected_stations, start_time, end_time, vs30_input):
    """Cut, filter and pack 3-component waveforms for the selected stations.

    For each station:
      - samples ``[start_time, end_time)`` seconds are cut by index, assuming
        100 Hz;
      - each component is processed with ``signal_processing``;
      - the result is written into a zero-filled (3000, 3) array in Z, N, E
        column order. Longer windows are truncated and shorter ones
        zero-padded.

    Missing N/E components (looked up as "N"/"1" and "E"/"2") are replaced by
    a copy of the Z samples. Stations without traces or without Z are
    skipped. Errors are logged and the station is skipped.

    Returns:
        ``(waveforms, station_info_list, valid_stations)``. For each kept
        station, in input order, these hold the (3000, 3) array,
        ``[latitude, longitude, elevation, vs30]`` and the input station dict.
    """
    sampling_rate = 100  # assumed 100 Hz; not read from the traces
    target_length = 3000
    first_sample = int(start_time * sampling_rate)
    last_sample = int(end_time * sampling_rate)

    waveforms, station_info_list, valid_stations = [], [], []
    for station_data in selected_stations:
        code = station_data["station"]
        try:
            station_traces = st.select(station=code)
            if len(station_traces) == 0:
                continue
            z_sel = station_traces.select(component="Z")
            n_sel = station_traces.select(component="N") or station_traces.select(component="1")
            e_sel = station_traces.select(component="E") or station_traces.select(component="2")
            if len(z_sel) == 0:
                continue
            z_raw = z_sel[0].data[first_sample:last_sample]
            # Missing horizontals fall back to a copy of the vertical samples.
            n_raw = n_sel[0].data[first_sample:last_sample] if len(n_sel) > 0 else z_raw.copy()
            e_raw = e_sel[0].data[first_sample:last_sample] if len(e_sel) > 0 else z_raw.copy()

            processed = [signal_processing(raw) for raw in (z_raw, n_raw, e_raw)]

            packed = np.zeros((target_length, 3))
            for column, samples in enumerate(processed):
                usable = min(len(samples), target_length)
                packed[:usable, column] = samples[:usable]
            waveforms.append(packed)

            vs30 = get_vs30(station_data["latitude"], station_data["longitude"], vs30_input)
            station_info_list.append([
                station_data["latitude"],
                station_data["longitude"],
                station_data["elevation"],
                vs30,
            ])
            valid_stations.append(station_data)
        except Exception as e:
            logger.warning(f"測站 {code} 波形提取失敗: {e}")
            continue

    logger.info(f"成功提取 {len(waveforms)} 個測站的波形")
    return waveforms, station_info_list, valid_stations
def plot_waveform(st, selected_stations, start_time, end_time):
    """Draw a distance-vs-time record section of the selected stations.

    For each station, the first trace in ``st`` is peak-normalised, scaled by
    0.03 and drawn at y = the station's epicentral distance (degrees). The
    [start_time, end_time] window is marked with red dashed lines and a
    shaded span. A twin y-axis labels about every tenth station code.
    Stations without traces are skipped; plotting errors are logged.

    Returns:
        The matplotlib ``Figure``.
    """
    # NOTE(review): the figure is never closed, so pyplot keeps every figure
    # created here alive -- consider plt.close(fig) once it has been rendered.
    fig, ax = plt.subplots(figsize=(14, 10))
    amplitude_scale = 0.03  # trace height in distance units; tune to avoid overlap

    distances = []
    station_names = []
    for station_data in selected_stations:
        code = station_data["station"]
        offset = station_data["distance"]
        try:
            traces = st.select(station=code)
            if len(traces) == 0:
                continue
            trace = traces[0]
            samples = trace.data
            peak = np.max(np.abs(samples)) + 1e-10
            ax.plot(trace.times(), offset + (samples / peak) * amplitude_scale,
                    'black', linewidth=0.3, alpha=0.8)
            distances.append(offset)
            station_names.append(code)
        except Exception as e:
            logger.warning(f"無法繪製測站 {code}: {e}")

    # Selected time window.
    ax.axvline(start_time, color='red', linestyle='--', linewidth=2,
               alpha=0.7, label='選取範圍')
    ax.axvline(end_time, color='red', linestyle='--', linewidth=2, alpha=0.7)
    ax.axvspan(start_time, end_time, alpha=0.15, color='blue')

    ax.set_xlabel('Time (s)', fontsize=12)
    ax.set_ylabel('Distance from Epicenter (°)', fontsize=12)
    ax.set_title(f'Record Section - {len(distances)} Stations Sorted by Distance',
                 fontsize=14, fontweight='bold')

    # Station codes on a right-hand axis, thinned to about ten labels.
    if distances:
        ax2 = ax.twinx()
        ax2.set_ylim(ax.get_ylim())
        ax2.set_ylabel('Station Code', fontsize=12)
        stride = max(1, len(distances) // 10)
        ax2.set_yticks(distances[::stride])
        ax2.set_yticklabels(station_names[::stride], fontsize=8)

    ax.grid(True, alpha=0.3, axis='x')
    ax.legend(loc='upper right')
    plt.tight_layout()
    return fig
def get_intensity_color(intensity):
    """Return the hex fill colour for an intensity level index (0-9).

    The palette follows intensityMap.html. Any unknown level maps to white
    ("#ffffff").
    """
    palette = (
        "#ffffff",  # 0  white
        "#33FFDD",  # 1  cyan
        "#34ff32",  # 2  green
        "#fefd32",  # 3  yellow
        "#fe8532",  # 4  orange
        "#fd5233",  # 5  red-orange (5-)
        "#c43f3b",  # 6  dark red (5+)
        "#9d4646",  # 7  brownish red (6-)
        "#9a4c86",  # 8  purple-red (6+)
        "#b51fea",  # 9  purple (7)
    )
    return dict(enumerate(palette)).get(intensity, "#ffffff")
def create_intensity_map(pga_list, target_names, epicenter_lat=None, epicenter_lon=None):
    """Build an interactive Folium map of predicted intensity at target stations.

    Args:
        pga_list: predicted log10 PGA (m/s²) per target -- the scale
            ``calculate_intensity`` expects. The popup shows it converted back
            to m/s².
        target_names: station codes aligned with ``pga_list``. Coordinates are
            looked up in the global ``target_dict``; unknown codes are skipped.
        epicenter_lat, epicenter_lon: optional epicenter, marked with a red
            star when both are given.

    Returns:
        The ``folium.Map`` (centred on Taiwan, 600 px high, with a legend and a
        fullscreen button).
    """
    import folium
    from folium import plugins

    m = folium.Map(
        location=[23.5, 121],
        zoom_start=7,
        tiles='OpenStreetMap',
        width='100%',
        height='600px'  # fixed height to match the ground-truth image
    )

    # Compare with None so a coordinate of 0.0 still counts as given.
    if epicenter_lat is not None and epicenter_lon is not None:
        folium.Marker(
            [epicenter_lat, epicenter_lon],
            popup=f'震央<br>({epicenter_lat:.3f}, {epicenter_lon:.3f})',
            icon=folium.Icon(color='red', icon='star', prefix='fa'),
            tooltip='震央位置'
        ).add_to(m)

    # Index targets by station code once (first record wins, as the previous
    # linear search did) instead of scanning target_dict for every target.
    targets_by_code = {}
    for record in target_dict:
        targets_by_code.setdefault(record["station"], record)

    for target_name, pga in zip(target_names, pga_list):
        target = targets_by_code.get(target_name)
        if target is None:
            continue
        lat = target["latitude"]
        lon = target["longitude"]
        intensity = calculate_intensity(pga)
        intensity_label = calculate_intensity(pga, label=True)
        color = get_intensity_color(intensity)
        # pga is log10(PGA in m/s²), the scale of calculate_intensity's
        # thresholds, so convert it back before labelling it in m/s².
        pga_ms2 = 10 ** pga
        popup_html = f"""
<div style="font-family: Arial; min-width: 150px;">
<h4 style="margin: 0 0 10px 0;">{target_name}</h4>
<table style="width:100%;">
<tr><td><b>震度:</b></td><td style="color: {color}; font-weight: bold; font-size: 16px;">{intensity_label}</td></tr>
<tr><td><b>PGA:</b></td><td>{pga_ms2:.4f} m/s²</td></tr>
<tr><td><b>位置:</b></td><td>({lat:.3f}, {lon:.3f})</td></tr>
</table>
</div>
"""
        folium.CircleMarker(
            location=[lat, lon],
            radius=12,
            popup=folium.Popup(popup_html, max_width=250),
            tooltip=f'{target_name}: 震度 {intensity_label}',
            color='black',
            fillColor=color,
            fillOpacity=0.8,
            weight=2
        ).add_to(m)
        # Intensity label drawn at the circle's centre.
        folium.Marker(
            [lat, lon],
            icon=folium.DivIcon(html=f'''
<div style="
font-size: 10px;
font-weight: bold;
color: black;
text-align: center;
text-shadow: 1px 1px 2px white, -1px -1px 2px white;
">{intensity_label}</div>
''')
        ).add_to(m)

    # Legend
    legend_html = '''
<div style="
position: fixed;
top: 10px; left: 10px;
width: 180px;
background-color: white;
border: 2px solid grey;
z-index: 9999;
font-size: 14px;
padding: 10px;
border-radius: 5px;
box-shadow: 2px 2px 6px rgba(0,0,0,0.3);
">
<h4 style="margin: 0 0 10px 0;">震度等級 Intensity</h4>
<table style="width: 100%;">
'''
    intensity_levels = ["0", "1", "2", "3", "4", "5-", "5+", "6-", "6+", "7"]
    for idx, level in enumerate(intensity_levels):
        color = get_intensity_color(idx)
        legend_html += f'''
<tr>
<td style="width: 30px; height: 20px; background-color: {color}; border: 1px solid black;"></td>
<td style="padding-left: 5px;">{level}</td>
</tr>
'''
    legend_html += '''
</table>
</div>
'''
    m.get_root().html.add_child(folium.Element(legend_html))

    plugins.Fullscreen().add_to(m)
    return m
def load_ground_truth_image(event_name):
    """Return the path of the ground-truth image for ``event_name``, or None.

    The image is searched for in ``ground_truth/``, named after the event's
    MiniSEED file without ``.mseed`` (e.g. ``20240403``). The extensions
    ``.png``, ``.jpg``, ``.jpeg`` and ``.gif`` are tried in that order.
    """
    import os

    event_date = os.path.basename(EARTHQUAKE_EVENTS[event_name]).replace('.mseed', '')
    candidates = (
        os.path.join("ground_truth", event_date + ext)
        for ext in ('.png', '.jpg', '.jpeg', '.gif')
    )
    image_path = next((path for path in candidates if os.path.exists(path)), None)
    if image_path is None:
        logger.warning(f"找不到 Ground Truth 圖片: {event_date}")
        return None
    logger.info(f"載入 Ground Truth 圖片: {image_path}")
    return image_path
def create_input_station_map(selected_stations, epicenter_lat, epicenter_lon):
    """Map all input stations, highlighting the selected ones.

    The map is centred on the epicenter and draws:
      - every ``site_info`` station that was not selected, as a small grey dot;
      - a red star at the epicenter;
      - each selected station as a large dot coloured by rank (1-5 green,
        6-15 blue, the rest orange);
      - a legend and a fullscreen button.

    Returns:
        The ``folium.Map``.
    """
    import folium
    from folium import plugins

    m = folium.Map(
        location=[epicenter_lat, epicenter_lon],
        zoom_start=8,
        tiles='OpenStreetMap',
        width='100%',
        height='500px'
    )
    chosen_codes = {s["station"] for s in selected_stations}

    # 1. Unselected stations as small grey dots (the selected ones are drawn later).
    logger.info(f"繪製所有測站 ({len(site_info)} 個)...")
    for station_code, lat, lon in zip(site_info["Station"], site_info["Latitude"],
                                      site_info["Longitude"]):
        if station_code in chosen_codes:
            continue
        folium.CircleMarker(
            location=[lat, lon],
            radius=2,
            popup=f'{station_code}',
            tooltip=station_code,
            color='gray',
            fillColor='lightgray',
            fillOpacity=0.4,
            weight=1
        ).add_to(m)

    # 2. Epicenter star.
    folium.Marker(
        [epicenter_lat, epicenter_lon],
        popup=f'<b>震央</b><br>({epicenter_lat:.3f}, {epicenter_lon:.3f})',
        icon=folium.Icon(color='red', icon='star', prefix='fa'),
        tooltip='震央位置',
        zIndexOffset=1000
    ).add_to(m)

    # 3. Selected stations, coloured by distance rank (1 = nearest).
    for rank, station_data in enumerate(selected_stations, start=1):
        station_code = station_data["station"]
        lat = station_data["latitude"]
        lon = station_data["longitude"]
        distance = station_data["distance"]
        popup_html = f"""
<div style="font-family: Arial; min-width: 150px;">
<h4 style="margin: 0 0 10px 0; color: #d63031;">{station_code}</h4>
<table style="width:100%;">
<tr><td><b>狀態:</b></td><td><span style="color: #00b894;">✓ 已選中</span></td></tr>
<tr><td><b>順序:</b></td><td>第 {rank} 近</td></tr>
<tr><td><b>距離:</b></td><td>{distance:.2f}°</td></tr>
<tr><td><b>位置:</b></td><td>({lat:.3f}, {lon:.3f})</td></tr>
</table>
</div>
"""
        color = 'green' if rank <= 5 else ('blue' if rank <= 15 else 'orange')
        folium.CircleMarker(
            location=[lat, lon],
            radius=10,
            popup=folium.Popup(popup_html, max_width=250),
            tooltip=f'✓ {station_code} (第{rank}近)',
            color='black',
            fillColor=color,
            fillOpacity=0.8,
            weight=2,
            zIndexOffset=500
        ).add_to(m)

    # 4. Legend.
    total_stations = len(site_info)
    legend_html = f'''
<div style="
position: fixed;
top: 10px; left: 10px;
width: 220px;
background-color: white;
border: 2px solid grey;
z-index: 9999;
font-size: 13px;
padding: 10px;
border-radius: 5px;
box-shadow: 2px 2px 6px rgba(0,0,0,0.3);
">
<h4 style="margin: 0 0 10px 0;">測站分布</h4>
<p style="margin: 5px 0;"><span style="color: red; font-size: 18px;">★</span> 震央</p>
<p style="margin: 5px 0;"><span style="color: lightgray;">●</span> 所有測站 ({total_stations} 個)</p>
<hr style="margin: 8px 0; border: none; border-top: 1px solid #ddd;">
<p style="margin: 5px 0; font-weight: bold;">被選中的測站:</p>
<p style="margin: 5px 0;"><span style="color: green; font-size: 16px;">●</span> 前 5 近</p>
<p style="margin: 5px 0;"><span style="color: blue; font-size: 16px;">●</span> 6-15 近</p>
<p style="margin: 5px 0;"><span style="color: orange; font-size: 16px;">●</span> 16-25 近</p>
<p style="margin: 5px 0; font-size: 11px; color: #666;">共選擇 {len(selected_stations)} 個測站</p>
</div>
'''
    m.get_root().html.add_child(folium.Element(legend_html))

    # 5. Fullscreen button.
    plugins.Fullscreen().add_to(m)
    return m
def load_and_display_waveform(event_name, start_time, end_time, epicenter_lon, epicenter_lat):
    """Load an event, select the nearest input stations and preview them.

    Args:
        event_name: key of ``EARTHQUAKE_EVENTS``.
        start_time, end_time: selected window in seconds (marked on the plot).
        epicenter_lon, epicenter_lat: epicenter in degrees.

    Returns:
        ``(station_map_html, waveform_figure, info_text, button_update)``.
        Every path returns these four values. On failure the first two are
        None, ``info_text`` carries the error, and the predict button update
        disables it.
    """
    try:
        # 1. Load the complete MiniSEED file.
        logger.info(f"載入地震事件: {event_name}")
        st = load_waveform(event_name)
        logger.info(f"載入了 {len(st)} 個 trace")

        # 2. Pick the 25 stations nearest the epicenter.
        logger.info(f"選擇距離震央 ({epicenter_lat}, {epicenter_lon}) 最近的測站...")
        selected_stations = select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25)
        if len(selected_stations) == 0:
            # Four values like the other paths (this used to omit the map slot).
            return None, None, "錯誤:找不到有效的測站資料", gr.update(interactive=False)

        # 3. Record section of the selected stations.
        waveform_plot = plot_waveform(st, selected_stations, start_time, end_time)

        # 4. Map of the input stations.
        station_map = create_input_station_map(selected_stations, epicenter_lat, epicenter_lon)
        station_map_html = station_map._repr_html_()

        info_text = (
            f"✅ 已載入波形資料\n"
            f"選取時間範圍: {start_time:.1f} - {end_time:.1f} 秒\n"
            f"震央位置: ({epicenter_lon:.4f}, {epicenter_lat:.4f})\n"
            f"選擇了 {len(selected_stations)} 個最近的測站\n"
            f"請確認波形範圍後,點擊「執行預測」按鈕"
        )
        logger.info("波形載入完成")
        return station_map_html, waveform_plot, info_text, gr.update(interactive=True)
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"波形載入發生錯誤: {e}")
        return None, None, f"錯誤: {str(e)}", gr.update(interactive=False)
def _pad_rows(rows, n_rows, row_shape):
    """Return a zero array of shape ``(n_rows, *row_shape)`` filled from ``rows``.

    Only the first ``n_rows`` entries of ``rows`` are copied; remaining slots
    stay zero. This is how inputs are padded to the fixed size the model takes.
    """
    padded = np.zeros((n_rows, *row_shape))
    for i, row in enumerate(rows[:n_rows]):
        padded[i] = row
    return padded


def predict_intensity(event_name, start_time, end_time, epicenter_lon, epicenter_lat):
    """Run the intensity prediction for every target station of an event.

    Loads the event's waveforms, keeps the 25 input stations nearest to the
    epicenter, cuts the ``[start_time, end_time]`` window, then predicts a PGA
    value for every station in ``target_dict`` in batches of 25 targets.

    Args:
        event_name: Event key passed to ``load_waveform`` and
            ``load_ground_truth_image``.
        start_time: Start of the waveform window, in seconds.
        end_time: End of the waveform window, in seconds.
        epicenter_lon: Epicenter longitude.
        epicenter_lat: Epicenter latitude.

    Returns:
        ``(ground_truth_path, map_html, stats)``. This is always a 3-tuple,
        because the Gradio click handler is bound to three outputs. On failure,
        the first two items are ``None`` and ``stats`` holds the error message.
    """
    max_stations = 25  # fixed number of input stations the model takes
    batch_size = 25  # number of target stations per forward pass
    try:
        # 1. Load the complete mseed file of the event.
        logger.info(f"載入地震事件: {event_name}")
        st = load_waveform(event_name)
        logger.info(f"載入了 {len(st)} 個 trace")

        # 2. Choose the input stations closest to the epicenter.
        logger.info(f"選擇距離震央 ({epicenter_lat}, {epicenter_lon}) 最近的測站...")
        selected_stations = select_nearest_stations(
            st, epicenter_lat, epicenter_lon, n_stations=max_stations
        )
        if len(selected_stations) == 0:
            return None, None, "錯誤:找不到有效的測站資料"

        # 3. Extract the windowed waveforms. vs30_input=600 is only a default;
        #    per the original comment it is overridden by the database value.
        logger.info(f"提取波形資料(時間範圍: {start_time}-{end_time} 秒)...")
        waveforms, station_info_list, _valid_stations = extract_waveforms_from_stream(
            st, selected_stations, start_time, end_time, vs30_input=600
        )
        if len(waveforms) == 0:
            # Three values, one per bound Gradio output (previously only two).
            return None, None, "錯誤:無法提取波形資料"

        # 4. Zero-pad the input stations to the size the model requires.
        n_inputs = min(len(waveforms), max_stations)
        waveform_padded = _pad_rows(waveforms[:n_inputs], max_stations, (3000, 3))
        station_info_padded = _pad_rows(station_info_list[:n_inputs], max_stations, (4,))

        # The model may live on the GPU selected at start-up, so inputs go to the
        # device of the model's own weights (CPU if it has none).
        first_param = next(model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device("cpu")

        # The waveform/station inputs are identical for every batch: build once.
        waveform_tensor = torch.tensor(waveform_padded).unsqueeze(0).double().to(model_device)
        station_tensor = torch.tensor(station_info_padded).unsqueeze(0).double().to(model_device)

        # 5. Predict all target stations, batch_size targets per forward pass.
        all_pga_list = []
        all_target_names = []
        total_targets = len(target_dict)
        num_batches = (total_targets + batch_size - 1) // batch_size
        logger.info(f"開始分批預測 {total_targets} 個目標測站(共 {num_batches} 批)...")
        for batch_idx, start_idx in enumerate(range(0, total_targets, batch_size)):
            batch_targets = target_dict[start_idx:start_idx + batch_size]
            end_idx = start_idx + len(batch_targets)
            logger.info(f"預測第 {batch_idx + 1}/{num_batches} 批(測站 {start_idx + 1}-{end_idx})...")

            target_list = [
                [
                    target["latitude"],
                    target["longitude"],
                    target["elevation"],
                    get_vs30(target["latitude"], target["longitude"], user_vs30=600),
                ]
                for target in batch_targets
            ]
            target_names = [target["station"] for target in batch_targets]
            # Unused slots of a short final batch stay zero.
            target_padded = _pad_rows(target_list, batch_size, (4,))

            # 6. Assemble the model input.
            tensor_data = {
                "waveform": waveform_tensor,
                "station": station_tensor,
                "target": torch.tensor(target_padded).unsqueeze(0).double().to(model_device),
            }

            # 7. Inference; the prediction is sum(weight * mu) over dim 2.
            with torch.no_grad():
                weight, _sigma, mu = model(tensor_data)
            batch_pga = torch.sum(weight * mu, dim=2).cpu().detach().numpy().flatten().tolist()
            # Keep only the predictions for real (non-padded) targets.
            all_pga_list.extend(batch_pga[:len(target_names)])
            all_target_names.extend(target_names)

        logger.info(f"完成所有 {len(all_target_names)} 個測站的預測!")
        if not all_pga_list:
            # Nothing to map or summarize; max() below would fail on an empty list.
            return None, None, "錯誤:沒有可預測的目標測站"

        # 8. Interactive predicted-intensity map.
        predicted_map = create_intensity_map(
            all_pga_list, all_target_names, epicenter_lat, epicenter_lon
        )
        map_html = predicted_map._repr_html_()

        # 9. Observed (ground-truth) intensity image.
        ground_truth_path = load_ground_truth_image(event_name)

        # 10. Summary. Label the strongest PGA instead of taking max() over the
        #     labels: if labels are strings like "5-"/"5+", lexicographic max is
        #     wrong. Assumes calculate_intensity is non-decreasing in PGA.
        max_intensity = calculate_intensity(max(all_pga_list), label=True)
        stats = (
            "✅ 預測完成!\n"
            f"選取時間範圍: {start_time:.1f} - {end_time:.1f} 秒\n"
            f"震央位置: ({epicenter_lon:.4f}, {epicenter_lat:.4f})\n"
            f"使用測站數: {len(waveforms)} / {max_stations}\n"
            f"預測最大震度: {max_intensity}"
        )
        logger.info("預測完成!")
        return ground_truth_path, map_html, stats
    except Exception as e:
        # UI boundary: log with traceback and show the error in the stats box
        # instead of crashing the Gradio callback.
        logger.exception(f"預測過程發生錯誤: {e}")
        return None, None, f"錯誤: {str(e)}"
# ============ Gradio interface ============
with gr.Blocks(title="TTSAM 震度預測系統") as demo:
    gr.Markdown("# 🌏 TTSAM 震度預測系統")

    # ---------- Top row: usage steps / status (left), input parameters (right) ----------
    with gr.Row():
        # Top left: usage steps and status boxes.
        with gr.Column(scale=1):
            gr.Markdown("## 使用步驟")
            # Step 2 only mentions the epicenter: the UI has no site-parameter input.
            gr.Markdown("""
1. 選擇地震事件和時間範圍
2. 輸入震央位置
3. 點擊「載入波形」確認波形範圍
4. 確認無誤後,點擊「執行預測」
ℹ️ 系統會自動選擇距離震央最近的 25 個測站
""")
            info_output = gr.Textbox(label="狀態資訊", lines=6, interactive=False)
            stats_output = gr.Textbox(label="預測統計", lines=4, interactive=False)

        # Top right: input parameters.
        with gr.Column(scale=1):
            gr.Markdown("## 輸入參數")
            _event_names = list(EARTHQUAKE_EVENTS.keys())
            event_dropdown = gr.Dropdown(
                choices=_event_names,
                # Guard against an empty event table instead of raising IndexError.
                value=_event_names[0] if _event_names else None,
                label="選擇地震事件"
            )
            with gr.Row():
                start_slider = gr.Slider(0, 300, value=0, step=1, label="起始時間 (秒)")
                end_slider = gr.Slider(0, 300, value=30, step=1, label="結束時間 (秒)")
            gr.Markdown("### 震央位置")
            with gr.Row():
                epicenter_lon_input = gr.Number(value=121.57, label="震央經度")
                epicenter_lat_input = gr.Number(value=23.88, label="震央緯度")
            with gr.Row():
                load_waveform_btn = gr.Button("📊 載入波形", variant="secondary", scale=1)
                # Disabled until a waveform has been loaded successfully.
                predict_btn = gr.Button("🔮 執行預測", variant="primary", scale=1, interactive=False)

    # ---------- Middle row: input waveforms (left), input station map (right) ----------
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 輸入波形")
            waveform_plot = gr.Plot(label="地震波形(選定的 25 個測站)")
        with gr.Column(scale=1):
            gr.Markdown("## 輸入測站分布")
            input_station_map = gr.HTML(label="輸入測站地圖")

    # ---------- Bottom row: ground truth (left) vs. prediction (right) ----------
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Ground Truth 震度分布")
            ground_truth_image = gr.Image(label="實際觀測震度", type="filepath", height=600)
        with gr.Column(scale=1):
            gr.Markdown("## 預測震度分布")
            intensity_map = gr.HTML(label="互動式震度地圖", elem_id="intensity_map")

    # Event bindings (must be declared inside the Blocks context).
    # Step 1: load and display the waveforms; also toggles predict_btn.
    load_waveform_btn.click(
        fn=load_and_display_waveform,
        inputs=[event_dropdown, start_slider, end_slider, epicenter_lon_input, epicenter_lat_input],
        outputs=[input_station_map, waveform_plot, info_output, predict_btn]
    )
    # Step 2: run the prediction; the handler must return exactly three values.
    predict_btn.click(
        fn=predict_intensity,
        inputs=[event_dropdown, start_slider, end_slider, epicenter_lon_input, epicenter_lat_input],
        outputs=[ground_truth_image, intensity_map, stats_output]
    )

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()