jimmy60504 committed on
Commit
2d4076c
·
1 Parent(s): 363136b

refactor site info loading and prediction batching for improved clarity and performance

Browse files
Files changed (1) hide show
  1. app.py +60 -28
app.py CHANGED
@@ -58,12 +58,14 @@ try:
58
  except FileNotFoundError:
59
  logger.error(f"{target_file} 找不到")
60
 
61
- # 載入測站資訊
62
  site_info_file = "station/site_info.csv"
63
  try:
64
  logger.info(f"載入 {site_info_file}...")
65
  site_info = pd.read_csv(site_info_file)
66
- logger.info(f"{site_info_file} 載入完成")
 
 
67
  except FileNotFoundError:
68
  logger.warning(f"{site_info_file} 找不到")
69
 
@@ -415,7 +417,7 @@ def calculate_distance(lat1, lon1, lat2, lon2):
415
 
416
 
417
  def select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25):
418
- """選擇距離震央最近的 n 個測站"""
419
  station_distances = []
420
 
421
  # 計算每個測站到震央的距離
@@ -448,7 +450,7 @@ def select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25):
448
  station_distances.sort(key=lambda x: x["distance"])
449
  selected_stations = station_distances[:n_stations]
450
 
451
- logger.info(f"選擇了 {len(selected_stations)} 個最近的測站")
452
  return selected_stations
453
 
454
 
@@ -794,30 +796,60 @@ def predict_intensity(event_name, start_time, end_time, epicenter_lon, epicenter
794
  waveform_padded[i] = waveforms[i]
795
  station_info_padded[i] = station_info_list[i]
796
 
797
- # 5. 準備目標測站資訊
798
- target_list = []
799
- target_names = []
800
- for target in target_dict[:25]:
801
- target_list.append([
802
- target["latitude"],
803
- target["longitude"],
804
- target["elevation"],
805
- get_vs30(target["latitude"], target["longitude"], vs30_input)
806
- ])
807
- target_names.append(target["station"])
808
-
809
- # 6. 組合成 tensor
810
- tensor_data = {
811
- "waveform": torch.tensor(waveform_padded).unsqueeze(0).double(),
812
- "station": torch.tensor(station_info_padded).unsqueeze(0).double(),
813
- "target": torch.tensor(target_list).unsqueeze(0).double(),
814
- }
815
-
816
- # 7. 執行預測
817
- logger.info("執行模型預測...")
818
- with torch.no_grad():
819
- weight, sigma, mu = model(tensor_data)
820
- pga_list = torch.sum(weight * mu, dim=2).cpu().detach().numpy().flatten().tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
821
 
822
  # 8. 繪製互動式地圖
823
  intensity_map = create_intensity_map(pga_list, target_names, epicenter_lat, epicenter_lon)
 
58
  except FileNotFoundError:
59
  logger.error(f"{target_file} 找不到")
60
 
61
+ # 載入測站資訊(輸入測站,1000+ 個)
62
  site_info_file = "station/site_info.csv"
63
  try:
64
  logger.info(f"載入 {site_info_file}...")
65
  site_info = pd.read_csv(site_info_file)
66
+ # 只保留唯一的測站(去除重複的分量)
67
+ site_info = site_info.drop_duplicates(subset=['Station']).reset_index(drop=True)
68
+ logger.info(f"{site_info_file} 載入完成,共 {len(site_info)} 個測站")
69
  except FileNotFoundError:
70
  logger.warning(f"{site_info_file} 找不到")
71
 
 
417
 
418
 
419
  def select_nearest_stations(st, epicenter_lat, epicenter_lon, n_stations=25):
420
+ """ site_info(1000+ 個輸入測站)中選擇距離震央最近的 n 個測站"""
421
  station_distances = []
422
 
423
  # 計算每個測站到震央的距離
 
450
  station_distances.sort(key=lambda x: x["distance"])
451
  selected_stations = station_distances[:n_stations]
452
 
453
+ logger.info(f" {len(station_distances)} 個輸入測站中選擇了最近的 {len(selected_stations)} ")
454
  return selected_stations
455
 
456
 
 
796
  waveform_padded[i] = waveforms[i]
797
  station_info_padded[i] = station_info_list[i]
798
 
799
+ # 5. 準備所有目標測站資訊(分批處理)
800
+ all_pga_list = []
801
+ all_target_names = []
802
+
803
+ # 計算需要分幾批(每批 25 個測站)
804
+ batch_size = 25
805
+ total_targets = len(target_dict)
806
+ num_batches = (total_targets + batch_size - 1) // batch_size
807
+
808
+ logger.info(f"開始分批預測 {total_targets} 個目標測站(共 {num_batches} 批)...")
809
+
810
+ for batch_idx in range(num_batches):
811
+ start_idx = batch_idx * batch_size
812
+ end_idx = min((batch_idx + 1) * batch_size, total_targets)
813
+ batch_targets = target_dict[start_idx:end_idx]
814
+
815
+ logger.info(f"預測第 {batch_idx + 1}/{num_batches} 批(測站 {start_idx + 1}-{end_idx})...")
816
+
817
+ # 準備這批目標測站資訊
818
+ target_list = []
819
+ target_names = []
820
+ for target in batch_targets:
821
+ target_list.append([
822
+ target["latitude"],
823
+ target["longitude"],
824
+ target["elevation"],
825
+ get_vs30(target["latitude"], target["longitude"], vs30_input)
826
+ ])
827
+ target_names.append(target["station"])
828
+
829
+ # Padding 到 25 個(如果不足 25 個)
830
+ target_padded = np.zeros((batch_size, 4))
831
+ for i in range(len(target_list)):
832
+ target_padded[i] = target_list[i]
833
+
834
+ # 6. 組合成 tensor
835
+ tensor_data = {
836
+ "waveform": torch.tensor(waveform_padded).unsqueeze(0).double(),
837
+ "station": torch.tensor(station_info_padded).unsqueeze(0).double(),
838
+ "target": torch.tensor(target_padded).unsqueeze(0).double(),
839
+ }
840
+
841
+ # 7. 執行預測
842
+ with torch.no_grad():
843
+ weight, sigma, mu = model(tensor_data)
844
+ batch_pga = torch.sum(weight * mu, dim=2).cpu().detach().numpy().flatten().tolist()
845
+
846
+ # 只取實際有資料的部分
847
+ all_pga_list.extend(batch_pga[:len(target_names)])
848
+ all_target_names.extend(target_names)
849
+
850
+ logger.info(f"完成所有 {len(all_target_names)} 個測站的預測!")
851
+ pga_list = all_pga_list
852
+ target_names = all_target_names
853
 
854
  # 8. 繪製互動式地圖
855
  intensity_map = create_intensity_map(pga_list, target_names, epicenter_lat, epicenter_lon)