import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from obspy import read
import xarray as xr
import torch
import torch.nn as nn
from scipy.signal import detrend, iirfilter, sosfilt, zpk2sos
from scipy.spatial import cKDTree
import pandas as pd
from loguru import logger
# Matplotlib font setup so CJK labels and minus signs render correctly.
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False  # keep minus signs renderable with CJK fonts

# Pick the compute device once at import time and log which one was chosen.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_ok else torch.device("cpu")
logger.info("使用 GPU" if _cuda_ok else "使用 CPU")
# Load the Taiwan Vs30 dataset (downloaded from Hugging Face) and build a
# KD-tree for nearest-neighbour lookup of Vs30 at arbitrary (lat, lon).
from huggingface_hub import hf_hub_download
tree = None
vs30_table = None
try:
    logger.info("從 Hugging Face 載入 Vs30 資料...")
    vs30_file = hf_hub_download(
        repo_id="SeisBlue/TaiwanVs30",
        filename="Vs30ofTaiwan.nc",
        repo_type="dataset"
    )
    ds = xr.open_dataset(vs30_file)
    # Flatten the grids into parallel 1-D arrays so each DataFrame row is
    # one grid point.
    lat_flat = ds['lat'].values.flatten()
    lon_flat = ds['lon'].values.flatten()
    vs30_flat = ds['vs30'].values.flatten()
    vs30_table = pd.DataFrame({'lat': lat_flat, 'lon': lon_flat, 'Vs30': vs30_flat})
    # Drop grid points with missing/infinite Vs30 before building the index.
    vs30_table = vs30_table.replace([np.inf, -np.inf], np.nan).dropna()
    tree = cKDTree(vs30_table[["lat", "lon"]])
    logger.info("Vs30 資料載入完成")
except Exception as e:
    # Best effort: get_vs30() falls back to a user-supplied value while
    # tree/vs30_table stay None.
    logger.warning(f"Vs30 資料載入失敗: {e}")
    logger.warning("將使用預設 Vs30 值 (600 m/s)")
# Load target stations (the fixed set of sites the model predicts PGA for).
target_file = "station/eew_target.csv"
# Fix: define a fallback so a missing CSV no longer leaves `target_dict`
# undefined (the original except branch caused a NameError later on).
target_dict = []
try:
    logger.info(f"載入 {target_file}...")
    target_df = pd.read_csv(target_file)
    target_dict = target_df.to_dict(orient="records")
    logger.info(f"{target_file} 載入完成")
except FileNotFoundError:
    logger.error(f"{target_file} 找不到")

# Load station metadata for the 1000+ candidate input stations.
site_info_file = "station/site_info.csv"
# Fix: fallback table with the columns downstream lookups select on, so a
# missing CSV yields "no stations found" instead of a NameError.
site_info = pd.DataFrame(columns=["Station", "Latitude", "Longitude", "Elevation"])
try:
    logger.info(f"載入 {site_info_file}...")
    site_info = pd.read_csv(site_info_file)
    # Keep one row per station (drop duplicate component rows).
    site_info = site_info.drop_duplicates(subset=['Station']).reset_index(drop=True)
    logger.info(f"{site_info_file} 載入完成,共 {len(site_info)} 個測站")
except FileNotFoundError:
    logger.warning(f"{site_info_file} 找不到")

# Preset earthquake events selectable in the UI.
EARTHQUAKE_EVENTS = {
    "0403花蓮地震 (2024)": "waveform/20240403.mseed",
}
# ============ 模型定義(從 ttsam_realtime.py 複製) ============
class LambdaLayer(nn.Module):
def __init__(self, lambd, eps=1e-4):
super(LambdaLayer, self).__init__()
self.lambd = lambd
self.eps = eps
def forward(self, x):
return self.lambd(x) + self.eps
class MLP(nn.Module):
    """Fully connected head.

    Layout: Linear -> activation, an optional stack of hidden
    (Linear -> activation) pairs, then Linear -> last_activation.

    Args:
        input_shape: tuple whose first element is the input feature size.
        dims: output width of each Linear layer, in order.
        activation: module applied after the first and hidden layers.
        last_activation: module applied after the final layer
            (defaults to ``activation``).
    """

    def __init__(self, input_shape, dims=(500, 300, 200, 150), activation=nn.ReLU(),
                 last_activation=None):
        super(MLP, self).__init__()
        if last_activation is None:
            last_activation = activation
        self.dims = dims
        self.first_fc = nn.Linear(input_shape[0], dims[0])
        self.first_activation = activation
        more_hidden = []
        if len(self.dims) > 2:
            for i in range(1, len(self.dims) - 1):
                more_hidden.append(nn.Linear(self.dims[i - 1], self.dims[i]))
                # Fix: honour the configured activation; the original
                # hard-coded nn.ReLU() here, silently ignoring `activation`
                # for hidden layers. Default behaviour is unchanged, and
                # activations carry no parameters so checkpoints still load.
                more_hidden.append(activation)
        self.more_hidden = nn.ModuleList(more_hidden)
        self.last_fc = nn.Linear(dims[-2], dims[-1])
        self.last_activation = last_activation

    def forward(self, x):
        output = self.first_fc(x)
        output = self.first_activation(output)
        if self.more_hidden:
            for layer in self.more_hidden:
                output = layer(output)
        output = self.last_fc(output)
        output = self.last_activation(output)
        return output
class CNN(nn.Module):
    """CNN feature extractor for 3-component seismic waveforms.

    Input is a batch of (samples, 3) traces; output is one embedding vector
    per trace, produced by 2-D then 1-D conv layers followed by an MLP.
    A per-trace log amplitude scalar is appended before the MLP so the
    peak-normalised shape features keep absolute-amplitude information.
    """

    def __init__(self, input_shape=(-1, 6000, 3), activation=nn.ReLU(), downsample=1,
                 mlp_input=11665, mlp_dims=(500, 300, 200, 150), eps=1e-8):
        super(CNN, self).__init__()
        self.input_shape = input_shape
        self.activation = activation
        self.downsample = downsample
        # mlp_input = flattened conv features + 1 appended scale value.
        self.mlp_input = mlp_input
        self.mlp_dims = mlp_dims
        self.eps = eps
        # Peak-normalise each trace by its absolute maximum over time and
        # components (self.eps guards against division by zero).
        self.lambda_layer_1 = LambdaLayer(
            lambda t: t / (
                torch.max(torch.max(torch.abs(t), dim=1, keepdim=True).values,
                          dim=2, keepdim=True).values + self.eps)
        )
        self.unsqueeze_layer1 = LambdaLayer(lambda t: torch.unsqueeze(t, dim=1))
        # Per-trace scalar: log of the peak absolute amplitude, scaled by
        # 1/100; re-attached after the conv stack.
        self.lambda_layer_2 = LambdaLayer(
            lambda t: torch.log(torch.max(torch.max(torch.abs(t), dim=1).values,
                                          dim=1).values + self.eps) / 100
        )
        self.unsqueeze_layer2 = LambdaLayer(lambda t: torch.unsqueeze(t, dim=1))
        self.conv2d1 = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=(1, downsample), stride=(1, downsample)),
            nn.ReLU())
        # Kernel/stride of 3 on the last axis collapses the 3 components.
        self.conv2d2 = nn.Sequential(
            nn.Conv2d(8, 32, kernel_size=(16, 3), stride=(1, 3)), nn.ReLU())
        self.conv1d1 = nn.Sequential(nn.Conv1d(32, 64, kernel_size=16), nn.ReLU())
        self.maxpooling = nn.MaxPool1d(2)
        self.conv1d2 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=16), nn.ReLU())
        self.conv1d3 = nn.Sequential(nn.Conv1d(128, 32, kernel_size=8), nn.ReLU())
        self.conv1d4 = nn.Sequential(nn.Conv1d(32, 32, kernel_size=8), nn.ReLU())
        self.conv1d5 = nn.Sequential(nn.Conv1d(32, 16, kernel_size=4), nn.ReLU())
        self.mlp = MLP((self.mlp_input,), dims=self.mlp_dims)

    def forward(self, x):
        # x: (batch, samples, 3) waveform tensor.
        output = self.lambda_layer_1(x)          # peak-normalised waveform
        output = self.unsqueeze_layer1(output)   # add channel dim for Conv2d
        scale = self.lambda_layer_2(x)           # per-trace log amplitude
        scale = self.unsqueeze_layer2(scale)
        output = self.conv2d1(output)
        output = self.conv2d2(output)            # collapses the 3 components
        output = torch.squeeze(output, dim=-1)   # drop the collapsed axis
        output = self.conv1d1(output)
        output = self.maxpooling(output)
        output = self.conv1d2(output)
        output = self.maxpooling(output)
        output = self.conv1d3(output)
        output = self.maxpooling(output)
        output = self.conv1d4(output)
        output = self.conv1d5(output)
        output = torch.flatten(output, start_dim=1)
        # Re-attach the absolute amplitude scale before the MLP head.
        output = torch.cat((output, scale), dim=1)
        output = self.mlp(output)
        return output
class PositionEmbeddingVs30(nn.Module):
    """Sinusoidal embedding of (lat, lon, depth, Vs30) metadata.

    Each quantity gets a geometric series of frequencies spanning its
    (min, max) wavelength pair; sin/cos features are interleaved into a
    single vector of length ``emb_dim`` via a precomputed gather mask.
    """

    def __init__(self, wavelengths=((5, 30), (110, 123), (0.01, 5000), (100, 1600)),
                 emb_dim=500):
        super(PositionEmbeddingVs30, self).__init__()
        self.wavelengths = wavelengths
        self.emb_dim = emb_dim
        min_lat, max_lat = wavelengths[0]
        min_lon, max_lon = wavelengths[1]
        min_depth, max_depth = wavelengths[2]
        min_vs30, max_vs30 = wavelengths[3]
        assert emb_dim % 10 == 0
        # lat/lon each take 2/10 of the embedding (sin + cos pairs); depth
        # and vs30 take 1/10 each.
        lat_dim = emb_dim // 5
        lon_dim = emb_dim // 5
        depth_dim = emb_dim // 10
        vs30_dim = emb_dim // 10
        # Geometric frequency ladders from shortest to longest wavelength.
        self.lat_coeff = 2 * np.pi * 1.0 / min_lat * (
            (min_lat / max_lat) ** (np.arange(lat_dim) / lat_dim))
        self.lon_coeff = 2 * np.pi * 1.0 / min_lon * (
            (min_lon / max_lon) ** (np.arange(lon_dim) / lon_dim))
        self.depth_coeff = 2 * np.pi * 1.0 / min_depth * (
            (min_depth / max_depth) ** (np.arange(depth_dim) / depth_dim))
        self.vs30_coeff = 2 * np.pi * 1.0 / min_vs30 * (
            (min_vs30 / max_vs30) ** (np.arange(vs30_dim) / vs30_dim))
        # Interleaving pattern: positions mod 5 select lat/lon slots,
        # positions mod 10 select depth/vs30 slots (later assignments
        # overwrite overlapping slots, as in the reference implementation).
        lat_sin_mask = np.arange(emb_dim) % 5 == 0
        lat_cos_mask = np.arange(emb_dim) % 5 == 1
        lon_sin_mask = np.arange(emb_dim) % 5 == 2
        lon_cos_mask = np.arange(emb_dim) % 5 == 3
        depth_sin_mask = np.arange(emb_dim) % 10 == 4
        depth_cos_mask = np.arange(emb_dim) % 10 == 9
        vs30_sin_mask = np.arange(emb_dim) % 10 == 5
        vs30_cos_mask = np.arange(emb_dim) % 10 == 8
        # self.mask[i] = index (into the concatenated sin/cos feature vector
        # built in forward()) that lands at output position i.
        self.mask = np.zeros(emb_dim)
        self.mask[lat_sin_mask] = np.arange(lat_dim)
        self.mask[lat_cos_mask] = lat_dim + np.arange(lat_dim)
        self.mask[lon_sin_mask] = 2 * lat_dim + np.arange(lon_dim)
        self.mask[lon_cos_mask] = 2 * lat_dim + lon_dim + np.arange(lon_dim)
        self.mask[depth_sin_mask] = 2 * lat_dim + 2 * lon_dim + np.arange(depth_dim)
        self.mask[depth_cos_mask] = 2 * lat_dim + 2 * lon_dim + depth_dim + np.arange(
            depth_dim)
        self.mask[
            vs30_sin_mask] = 2 * lat_dim + 2 * lon_dim + 2 * depth_dim + np.arange(
            vs30_dim)
        self.mask[
            vs30_cos_mask] = 2 * lat_dim + 2 * lon_dim + 2 * depth_dim + vs30_dim + np.arange(
            vs30_dim)
        self.mask = self.mask.astype("int32")

    def forward(self, x):
        # x: (batch, n, 4) rows of (lat, lon, depth, vs30). Callers reshape
        # to n == 1 per row; the hard-coded middle dim of 1 below relies on
        # that.
        lat_base = x[:, :, 0:1].to(device) * torch.Tensor(self.lat_coeff).to(device)
        lon_base = x[:, :, 1:2].to(device) * torch.Tensor(self.lon_coeff).to(device)
        depth_base = x[:, :, 2:3].to(device) * torch.Tensor(self.depth_coeff).to(device)
        # Fix: move the vs30 slice to `device` like the other three slices;
        # the original left it on the input tensor's device, which raises a
        # device-mismatch error when the model runs on CUDA with CPU input.
        vs30_base = x[:, :, 3:4].to(device) * torch.Tensor(self.vs30_coeff).to(device)
        output = torch.cat([
            torch.sin(lat_base), torch.cos(lat_base),
            torch.sin(lon_base), torch.cos(lon_base),
            torch.sin(depth_base), torch.cos(depth_base),
            torch.sin(vs30_base), torch.cos(vs30_base),
        ], dim=-1)
        index_map = torch.from_numpy(np.array(self.mask)).long()
        index = (index_map.unsqueeze(0).unsqueeze(0)).expand(x.shape[0], 1,
                                                             self.emb_dim).to(device)
        output = torch.gather(output, -1, index).to(device)
        return output
class TransformerEncoder(nn.Module):
    """Six-layer transformer encoder (batch-first, GELU, no dropout)."""

    def __init__(self, d_model=150, nhead=10, batch_first=True, activation="gelu",
                 dropout=0.0, dim_feedforward=1000):
        super(TransformerEncoder, self).__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            batch_first=batch_first,
            activation=activation,
            dropout=dropout,
            dim_feedforward=dim_feedforward,
        ).to(device)
        self.encoder_layer = layer
        self.transformer_encoder = nn.TransformerEncoder(layer, 6).to(device)

    def forward(self, x, src_key_padding_mask=None):
        # src_key_padding_mask: True marks positions to exclude as keys.
        return self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)
class MDN(nn.Module):
    """Mixture density head: mixture weights, sigmas, and means."""

    def __init__(self, input_shape=(150,), n_hidden=20, n_gaussians=5):
        super(MDN, self).__init__()
        self.z_h = nn.Sequential(nn.Linear(input_shape[0], n_hidden), nn.Tanh())
        self.z_weight = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma = nn.Linear(n_hidden, n_gaussians)
        self.z_mu = nn.Linear(n_hidden, n_gaussians)

    def forward(self, x):
        hidden = self.z_h(x)
        mu = self.z_mu(hidden)
        # exp keeps the standard deviations strictly positive.
        sigma = torch.exp(self.z_sigma(hidden))
        # Softmax makes the mixture weights sum to 1.
        weight = nn.functional.softmax(self.z_weight(hidden), -1)
        return weight, sigma, mu
class FullModel(nn.Module):
    """End-to-end TT-SAM: CNN trace encoder + positional embedding +
    transformer + MLP + mixture density output.

    forward() expects a dict ``data`` with keys "waveform", "station" and
    "target" and returns the MDN mixture (weight, sigma, mu) per target.
    NOTE(review): torch.all / torch.ones_like are applied directly to
    data["station"] / data["target"], which implies callers pass torch
    tensors (not numpy arrays) for those keys — confirm against callers.
    """

    def __init__(self, model_cnn, model_position, model_transformer, model_mlp,
                 model_mdn,
                 max_station=25, pga_targets=15, emb_dim=150, data_length=6000):
        super(FullModel, self).__init__()
        self.data_length = data_length
        self.model_CNN = model_cnn
        self.model_Position = model_position
        self.model_Transformer = model_transformer
        self.model_mlp = model_mlp
        self.model_MDN = model_mdn
        self.max_station = max_station
        self.pga_targets = pga_targets
        self.emb_dim = emb_dim

    def forward(self, data):
        # Encode every trace independently with the CNN.
        cnn_output = self.model_CNN(
            torch.DoubleTensor(
                data["waveform"].reshape(-1, self.data_length, 3)).float().to(device)
        )
        cnn_output_reshape = torch.reshape(cnn_output,
                                           (-1, self.max_station, self.emb_dim))
        # Positional embedding of each station's (lat, lon, elev, vs30) row.
        emb_output = self.model_Position(
            torch.DoubleTensor(
                data["station"].reshape(-1, 1, data["station"].shape[2])).float().to(
                device)
        )
        emb_output = emb_output.reshape(-1, self.max_station, self.emb_dim)
        # A station row that is all zeros is treated as padding.
        station_pad_mask = data["station"] == 0
        station_pad_mask = torch.all(station_pad_mask, 2)
        # Positional embedding of the target sites (query tokens).
        pga_pos_emb_output = self.model_Position(
            torch.DoubleTensor(
                data["target"].reshape(-1, 1, data["target"].shape[2])).float().to(
                device)
        )
        pga_pos_emb_output = pga_pos_emb_output.reshape(-1, self.pga_targets,
                                                        self.emb_dim)
        # Target tokens are always flagged as padding, so no token attends
        # to them as keys; they still act as queries over station tokens.
        target_pad_mask = torch.ones_like(data["target"], dtype=torch.bool)
        target_pad_mask = torch.all(target_pad_mask, 2)
        pad_mask = torch.cat((station_pad_mask, target_pad_mask), dim=1).to(device)
        # Add positional information to the CNN features, then append the
        # target query tokens to form the transformer sequence.
        add_pe_cnn_output = torch.add(cnn_output_reshape, emb_output)
        transformer_input = torch.cat((add_pe_cnn_output, pga_pos_emb_output), dim=1)
        transformer_output = self.model_Transformer(transformer_input, pad_mask)
        # Only the trailing target-token outputs feed the MDN head.
        mlp_input = transformer_output[:, -self.pga_targets:, :].to(device)
        mlp_output = self.model_mlp(mlp_input)
        weight, sigma, mu = self.model_MDN(mlp_output)
        return weight, sigma, mu
def get_full_model(model_path):
    """Build the TT-SAM architecture and load trained weights.

    The hyperparameters here (mlp_input=5665, emb_dim=150, 25 PGA targets,
    3000-sample waveforms) must match the checkpoint exactly or
    load_state_dict will fail.

    Args:
        model_path: path to the trained ``.pt`` state dict.

    Returns:
        FullModel on the module-level ``device``, with weights loaded.
    """
    emb_dim = 150
    mlp_dims = (150, 100, 50, 30, 10)
    cnn_model = CNN(mlp_input=5665).to(device)
    pos_emb_model = PositionEmbeddingVs30(emb_dim=emb_dim).to(device)
    transformer_model = TransformerEncoder()
    mlp_model = MLP(input_shape=(emb_dim,), dims=mlp_dims).to(device)
    mdn_model = MDN(input_shape=(mlp_dims[-1],)).to(device)
    full_model = FullModel(
        cnn_model, pos_emb_model, transformer_model, mlp_model, mdn_model,
        pga_targets=25, data_length=3000
    ).to(device)
    # weights_only=True avoids unpickling arbitrary objects from the file.
    full_model.load_state_dict(
        torch.load(model_path, weights_only=True, map_location=device))
    return full_model
# Download the trained TT-SAM checkpoint from Hugging Face and build the
# model once at import time.
model_path = hf_hub_download(
    repo_id="SeisBlue/TTSAM",
    filename="ttsam_trained_model_11.pt"
)
model = get_full_model(model_path)
# ============ 輔助函數 ============
def lowpass(data, freq=10, df=100, corners=4):
    """Butterworth low-pass filter.

    Args:
        data: 1-D array of samples.
        freq: corner frequency in Hz.
        df: sampling rate in Hz.
        corners: filter order.

    Returns:
        The filtered array, or ``data`` unchanged when the corner frequency
        is at or above the Nyquist frequency.
    """
    fe = 0.5 * df  # Nyquist frequency
    f = freq / fe
    if f >= 1:
        # Fix: the original clamped f to 1.0 and passed it to iirfilter,
        # but scipy.signal.iirfilter requires 0 < Wn < 1 and raises
        # ValueError at Wn == 1. A corner at/above Nyquist removes nothing,
        # so return the data unfiltered.
        return data
    z, p, k = iirfilter(corners, f, btype="lowpass", ftype="butter", output="zpk")
    sos = zpk2sos(z, p, k)
    return sosfilt(sos, data)
def signal_processing(waveform):
    """Demean a raw waveform, then apply the 10 Hz low-pass filter."""
    demeaned = detrend(waveform, type="constant")
    return lowpass(demeaned, freq=10)
def get_vs30(lat, lon, user_vs30=600):
    """Look up Vs30 (m/s) at a coordinate, falling back to a user value.

    Uses the module-level KD-tree over the Taiwan Vs30 grid when it loaded
    successfully; otherwise returns ``user_vs30``.
    """
    if tree is None or vs30_table is None:
        # Vs30 dataset unavailable: honour the user-supplied value.
        logger.info(f"使用使用者輸入的 Vs30 值 ({user_vs30} m/s) for ({lat}, {lon})")
        return float(user_vs30)
    _, nearest_idx = tree.query([float(lat), float(lon)])
    vs30 = vs30_table.iloc[nearest_idx]["Vs30"]
    logger.info(f"從資料庫查詢到 Vs30 值 ({vs30} m/s) for ({lat}, {lon})")
    return float(vs30)
def calculate_intensity(pga, label=False):
    """Map a PGA value to the CWA intensity scale.

    Args:
        pga: PGA value compared against log10 thresholds (callers appear to
            pass log10(PGA) — the levels below are np.log10 of m/s² values).
        label: if True return the label string, else the integer index.

    Returns:
        Intensity index 0-9, or its label ("0" ... "7").
    """
    intensity_label = ["0", "1", "2", "3", "4", "5-", "5+", "6-", "6+", "7"]
    pga_level = np.log10([1e-5, 0.008, 0.025, 0.080, 0.250, 0.80, 1.4, 2.5, 4.4, 8.0])
    # Fix: clamp to 0. The original returned -1 for a pga below the lowest
    # threshold, which silently negative-indexed intensity_label[-1] == "7".
    intensity = max(int(np.searchsorted(pga_level, pga)) - 1, 0)
    if label:
        return intensity_label[intensity]
    return intensity
# ============ Gradio 介面函數 ============
def load_waveform(event_name):
    """Read the full mseed file (all stations) for a preset event.

    Args:
        event_name: key into EARTHQUAKE_EVENTS.

    Returns:
        obspy Stream with every trace in the file.
    """
    file_path = EARTHQUAKE_EVENTS[event_name]
    st = read(file_path)
    return st
def calculate_distance(lat1, lon1, lat2, lon2):
    """Planar distance between two points in degrees (flat-earth approx)."""
    dlat = lat1 - lat2
    dlon = lon1 - lon2
    return np.sqrt(dlat * dlat + dlon * dlon)
def select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25):
    """Pick the n_stations input stations closest to the epicenter.

    Station coordinates come from the module-level ``site_info`` table;
    traces whose station is missing from it are skipped, and the multiple
    components of one station are counted only once.

    Returns:
        List of dicts (station, distance, latitude, longitude, elevation)
        sorted by distance, truncated to ``n_stations``.
    """
    by_station = {}  # keyed by station code so repeated components are skipped
    for trace in st:
        code = trace.stats.station
        if code in by_station:
            continue
        try:
            rows = site_info[site_info["Station"] == code]
            if len(rows) == 0:
                continue
            lat = rows["Latitude"].values[0]
            lon = rows["Longitude"].values[0]
            elev = rows["Elevation"].values[0]
            by_station[code] = {
                "station": code,
                "distance": calculate_distance(epicenter_lat, epicenter_lon, lat, lon),
                "latitude": lat,
                "longitude": lon,
                "elevation": elev,
            }
        except Exception as e:
            logger.warning(f"測站 {code} 資訊查詢失敗: {e}")
            continue
    ranked = sorted(by_station.values(), key=lambda s: s["distance"])
    selected = ranked[:n_stations]
    logger.info(f"從 {len(ranked)} 個輸入測站中選擇了最近的 {len(selected)} 個")
    return selected
def extract_waveforms_from_stream(st, selected_stations, start_time, end_time, vs30_input):
    """Slice, filter and pad 3-component waveforms for the selected stations.

    Args:
        st: obspy Stream containing all traces of the event.
        selected_stations: dicts from select_nearest_stations().
        start_time, end_time: window in seconds from trace start
            (converted to sample indices assuming 100 Hz).
        vs30_input: user fallback Vs30 passed through to get_vs30().

    Returns:
        (waveforms, station_info_list, valid_stations) where each waveform
        is a (3000, 3) float array with columns [Z, N, E], each
        station_info_list row is [lat, lon, elev, vs30], and valid_stations
        are the station dicts that yielded usable data.
    """
    waveforms = []
    station_info_list = []
    valid_stations = []
    sampling_rate = 100  # assumed 100 Hz for all traces
    start_idx = int(start_time * sampling_rate)
    end_idx = int(end_time * sampling_rate)
    target_length = 3000  # model input length (30 s @ 100 Hz)
    for station_data in selected_stations:
        station_code = station_data["station"]
        try:
            # All components recorded by this station.
            st_station = st.select(station=station_code)
            if len(st_station) == 0:
                continue
            # Prefer Z/N/E; fall back to numbered horizontals 1/2.
            z_trace = st_station.select(component="Z")
            n_trace = st_station.select(component="N") or st_station.select(component="1")
            e_trace = st_station.select(component="E") or st_station.select(component="2")
            # Z is mandatory; missing horizontals are substituted with Z.
            if len(z_trace) > 0:
                z_data = z_trace[0].data[start_idx:end_idx]
            else:
                continue
            if len(n_trace) > 0:
                n_data = n_trace[0].data[start_idx:end_idx]
            else:
                n_data = z_data.copy()
            if len(e_trace) > 0:
                e_data = e_trace[0].data[start_idx:end_idx]
            else:
                e_data = z_data.copy()
            # Demean + 10 Hz low-pass on each component.
            z_data = signal_processing(z_data)
            n_data = signal_processing(n_data)
            e_data = signal_processing(e_data)
            # Zero-pad/truncate every component into a fixed (3000, 3) array.
            waveform_3c = np.zeros((target_length, 3))
            z_len = min(len(z_data), target_length)
            n_len = min(len(n_data), target_length)
            e_len = min(len(e_data), target_length)
            waveform_3c[:z_len, 0] = z_data[:z_len]
            waveform_3c[:n_len, 1] = n_data[:n_len]
            waveform_3c[:e_len, 2] = e_data[:e_len]
            waveforms.append(waveform_3c)
            # Station metadata row in the order the position embedding expects.
            vs30 = get_vs30(station_data["latitude"], station_data["longitude"], vs30_input)
            station_info_list.append([
                station_data["latitude"],
                station_data["longitude"],
                station_data["elevation"],
                vs30
            ])
            valid_stations.append(station_data)
        except Exception as e:
            logger.warning(f"測站 {station_code} 波形提取失敗: {e}")
            continue
    logger.info(f"成功提取 {len(waveforms)} 個測站的波形")
    return waveforms, station_info_list, valid_stations
def plot_waveform(st, selected_stations, start_time, end_time):
    """Record-section plot: normalised waveforms offset by epicentral distance.

    Draws the first trace of every selected station at a y-offset equal to
    its distance, marks the analysis window, and labels a thinned subset of
    station codes on a twin y-axis.

    Returns:
        matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(14, 10))
    # Amplitude scale keeps adjacent traces from overlapping visually.
    amplitude_scale = 0.03
    plotted_count = 0
    distances = []
    station_names = []
    # Idiom fix: the original looped with enumerate() but never used the index.
    for station_data in selected_stations:
        station_code = station_data["station"]
        distance = station_data["distance"]
        try:
            st_station = st.select(station=station_code)
            if len(st_station) > 0:
                tr = st_station[0]
                times = tr.times()
                data = tr.data
                # Normalise so every trace has comparable visual amplitude.
                data_normalized = data / (np.max(np.abs(data)) + 1e-10)
                ax.plot(times, distance + data_normalized * amplitude_scale,
                        'black', linewidth=0.3, alpha=0.8)
                distances.append(distance)
                station_names.append(station_code)
                plotted_count += 1
        except Exception as e:
            logger.warning(f"無法繪製測站 {station_code}: {e}")
    # Mark the selected analysis window.
    ax.axvline(start_time, color='red', linestyle='--', linewidth=2,
               alpha=0.7, label='選取範圍')
    ax.axvline(end_time, color='red', linestyle='--', linewidth=2, alpha=0.7)
    ax.axvspan(start_time, end_time, alpha=0.15, color='blue')
    ax.set_xlabel('Time (s)', fontsize=12)
    ax.set_ylabel('Distance from Epicenter (°)', fontsize=12)
    ax.set_title(f'Record Section - {plotted_count} Stations Sorted by Distance',
                 fontsize=14, fontweight='bold')
    # Station codes on a secondary axis, thinned to at most ~10 labels.
    if distances:
        ax2 = ax.twinx()
        ax2.set_ylim(ax.get_ylim())
        ax2.set_ylabel('Station Code', fontsize=12)
        step = max(1, len(distances) // 10)
        tick_positions = distances[::step]
        tick_labels = station_names[::step]
        ax2.set_yticks(tick_positions)
        ax2.set_yticklabels(tick_labels, fontsize=8)
    ax.grid(True, alpha=0.3, axis='x')
    ax.legend(loc='upper right')
    plt.tight_layout()
    return fig
def get_intensity_color(intensity):
    """Hex colour for an intensity index 0-9 (palette from intensityMap.html).

    Unknown levels fall back to white.
    """
    palette = {
        0: "#ffffff",  # white
        1: "#33FFDD",  # cyan
        2: "#34ff32",  # green
        3: "#fefd32",  # yellow
        4: "#fe8532",  # orange
        5: "#fd5233",  # red-orange (5-)
        6: "#c43f3b",  # dark red (5+)
        7: "#9d4646",  # maroon (6-)
        8: "#9a4c86",  # magenta (6+)
        9: "#b51fea",  # purple (7)
    }
    return palette.get(intensity, "#ffffff")
def create_intensity_map(pga_list, target_names, epicenter_lat=None, epicenter_lon=None):
"""使用 Folium 創建互動式震度分布地圖"""
import folium
from folium import plugins
# 創建地圖,中心點設在台灣中心,設定地圖尺寸
m = folium.Map(
location=[23.5, 121],
zoom_start=7,
tiles='OpenStreetMap',
width='100%',
height='600px' # 設定固定高度,與 Ground Truth 圖片匹配
)
# 如果有震央位置,標記震央
if epicenter_lat and epicenter_lon:
folium.Marker(
[epicenter_lat, epicenter_lon],
popup=f'震央
({epicenter_lat:.3f}, {epicenter_lon:.3f})',
icon=folium.Icon(color='red', icon='star', prefix='fa'),
tooltip='震央位置'
).add_to(m)
# 添加震度測站標記
for i, target_name in enumerate(target_names):
target = next((t for t in target_dict if t["station"] == target_name), None)
if target:
lat = target["latitude"]
lon = target["longitude"]
intensity = calculate_intensity(pga_list[i])
intensity_label = calculate_intensity(pga_list[i], label=True)
color = get_intensity_color(intensity)
pga = pga_list[i]
# 創建 HTML popup 內容
popup_html = f"""
| 震度: | {intensity_label} |
| PGA: | {pga:.4f} m/s² |
| 位置: | ({lat:.3f}, {lon:.3f}) |
| {level} |
| 狀態: | ✓ 已選中 |
| 順序: | 第 {i+1} 近 |
| 距離: | {distance:.2f}° |
| 位置: | ({lat:.3f}, {lon:.3f}) |
★ 震央
● 所有測站 ({total_stations} 個)
被選中的測站:
● 前 5 近
● 6-15 近
● 16-25 近
共選擇 {len(selected_stations)} 個測站