Spaces:
Running
Running
Commit
·
2d4076c
1
Parent(s):
363136b
refactor site info loading and prediction batching for improved clarity and performance
Browse files
app.py
CHANGED
|
@@ -58,12 +58,14 @@ try:
|
|
| 58 |
except FileNotFoundError:
|
| 59 |
logger.error(f"{target_file} 找不到")
|
| 60 |
|
| 61 |
-
#
|
| 62 |
site_info_file = "station/site_info.csv"
|
| 63 |
try:
|
| 64 |
logger.info(f"載入 {site_info_file}...")
|
| 65 |
site_info = pd.read_csv(site_info_file)
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
except FileNotFoundError:
|
| 68 |
logger.warning(f"{site_info_file} 找不到")
|
| 69 |
|
|
@@ -415,7 +417,7 @@ def calculate_distance(lat1, lon1, lat2, lon2):
|
|
| 415 |
|
| 416 |
|
| 417 |
def select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25):
|
| 418 |
-
"""
|
| 419 |
station_distances = []
|
| 420 |
|
| 421 |
# 計算每個測站到震央的距離
|
|
@@ -448,7 +450,7 @@ def select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25):
|
|
| 448 |
station_distances.sort(key=lambda x: x["distance"])
|
| 449 |
selected_stations = station_distances[:n_stations]
|
| 450 |
|
| 451 |
-
logger.info(f"
|
| 452 |
return selected_stations
|
| 453 |
|
| 454 |
|
|
@@ -794,30 +796,60 @@ def predict_intensity(event_name, start_time, end_time, epicenter_lon, epicenter
|
|
| 794 |
waveform_padded[i] = waveforms[i]
|
| 795 |
station_info_padded[i] = station_info_list[i]
|
| 796 |
|
| 797 |
-
# 5.
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
"
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
|
| 819 |
-
|
| 820 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 821 |
|
| 822 |
# 8. 繪製互動式地圖
|
| 823 |
intensity_map = create_intensity_map(pga_list, target_names, epicenter_lat, epicenter_lon)
|
|
|
|
| 58 |
except FileNotFoundError:
|
| 59 |
logger.error(f"{target_file} 找不到")
|
| 60 |
|
| 61 |
+
# 載入測站資訊(輸入測站,1000+ 個)
|
| 62 |
site_info_file = "station/site_info.csv"
|
| 63 |
try:
|
| 64 |
logger.info(f"載入 {site_info_file}...")
|
| 65 |
site_info = pd.read_csv(site_info_file)
|
| 66 |
+
# 只保留唯一的測站(去除重複的分量)
|
| 67 |
+
site_info = site_info.drop_duplicates(subset=['Station']).reset_index(drop=True)
|
| 68 |
+
logger.info(f"{site_info_file} 載入完成,共 {len(site_info)} 個測站")
|
| 69 |
except FileNotFoundError:
|
| 70 |
logger.warning(f"{site_info_file} 找不到")
|
| 71 |
|
|
|
|
| 417 |
|
| 418 |
|
| 419 |
def select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25):
|
| 420 |
+
"""從 site_info(1000+ 個輸入測站)中選擇距離震央最近的 n 個測站"""
|
| 421 |
station_distances = []
|
| 422 |
|
| 423 |
# 計算每個測站到震央的距離
|
|
|
|
| 450 |
station_distances.sort(key=lambda x: x["distance"])
|
| 451 |
selected_stations = station_distances[:n_stations]
|
| 452 |
|
| 453 |
+
logger.info(f"從 {len(station_distances)} 個輸入測站中選擇了最近的 {len(selected_stations)} 個")
|
| 454 |
return selected_stations
|
| 455 |
|
| 456 |
|
|
|
|
| 796 |
waveform_padded[i] = waveforms[i]
|
| 797 |
station_info_padded[i] = station_info_list[i]
|
| 798 |
|
| 799 |
+
# 5. 準備所有目標測站資訊(分批處理)
|
| 800 |
+
all_pga_list = []
|
| 801 |
+
all_target_names = []
|
| 802 |
+
|
| 803 |
+
# 計算需要分幾批(每批 25 個測站)
|
| 804 |
+
batch_size = 25
|
| 805 |
+
total_targets = len(target_dict)
|
| 806 |
+
num_batches = (total_targets + batch_size - 1) // batch_size
|
| 807 |
+
|
| 808 |
+
logger.info(f"開始分批預測 {total_targets} 個目標測站(共 {num_batches} 批)...")
|
| 809 |
+
|
| 810 |
+
for batch_idx in range(num_batches):
|
| 811 |
+
start_idx = batch_idx * batch_size
|
| 812 |
+
end_idx = min((batch_idx + 1) * batch_size, total_targets)
|
| 813 |
+
batch_targets = target_dict[start_idx:end_idx]
|
| 814 |
+
|
| 815 |
+
logger.info(f"預測第 {batch_idx + 1}/{num_batches} 批(測站 {start_idx + 1}-{end_idx})...")
|
| 816 |
+
|
| 817 |
+
# 準備這批目標測站資訊
|
| 818 |
+
target_list = []
|
| 819 |
+
target_names = []
|
| 820 |
+
for target in batch_targets:
|
| 821 |
+
target_list.append([
|
| 822 |
+
target["latitude"],
|
| 823 |
+
target["longitude"],
|
| 824 |
+
target["elevation"],
|
| 825 |
+
get_vs30(target["latitude"], target["longitude"], vs30_input)
|
| 826 |
+
])
|
| 827 |
+
target_names.append(target["station"])
|
| 828 |
+
|
| 829 |
+
# Padding 到 25 個(如果不足 25 個)
|
| 830 |
+
target_padded = np.zeros((batch_size, 4))
|
| 831 |
+
for i in range(len(target_list)):
|
| 832 |
+
target_padded[i] = target_list[i]
|
| 833 |
+
|
| 834 |
+
# 6. 組合成 tensor
|
| 835 |
+
tensor_data = {
|
| 836 |
+
"waveform": torch.tensor(waveform_padded).unsqueeze(0).double(),
|
| 837 |
+
"station": torch.tensor(station_info_padded).unsqueeze(0).double(),
|
| 838 |
+
"target": torch.tensor(target_padded).unsqueeze(0).double(),
|
| 839 |
+
}
|
| 840 |
+
|
| 841 |
+
# 7. 執行預測
|
| 842 |
+
with torch.no_grad():
|
| 843 |
+
weight, sigma, mu = model(tensor_data)
|
| 844 |
+
batch_pga = torch.sum(weight * mu, dim=2).cpu().detach().numpy().flatten().tolist()
|
| 845 |
+
|
| 846 |
+
# 只取實際有資料的部分
|
| 847 |
+
all_pga_list.extend(batch_pga[:len(target_names)])
|
| 848 |
+
all_target_names.extend(target_names)
|
| 849 |
+
|
| 850 |
+
logger.info(f"完成所有 {len(all_target_names)} 個測站的預測!")
|
| 851 |
+
pga_list = all_pga_list
|
| 852 |
+
target_names = all_target_names
|
| 853 |
|
| 854 |
# 8. 繪製互動式地圖
|
| 855 |
intensity_map = create_intensity_map(pga_list, target_names, epicenter_lat, epicenter_lon)
|