Ning311 commited on
Commit
0f48330
·
verified ·
1 Parent(s): ac47b9d

Update utils/model_loader.py

Browse files
Files changed (1) hide show
  1. utils/model_loader.py +164 -29
utils/model_loader.py CHANGED
@@ -7,11 +7,20 @@ import pickle
7
  import pandas as pd
8
  import numpy as np
9
  import logging
 
10
  from pathlib import Path
11
  from typing import Dict, Any, Optional
12
  import warnings
13
  warnings.filterwarnings('ignore')
14
 
 
 
 
 
 
 
 
 
15
  # 配置日志
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
@@ -24,20 +33,37 @@ class ModelManager:
24
  self.advisory_models = {}
25
  self.model_configs = {}
26
  self.thresholds = {}
27
-
28
  # 模型路径配置 - 支持本地和云端部署
29
  self.app_path = Path(__file__).parent.parent
30
-
31
- # 使用正确的模型路径
32
- self.base_path = Path("/Users/ning/Desktop/idea/代码forSarcoAdvisor")
33
- self.screening_paths = {
34
- 'sarcoI': self.base_path / "3.建模/SarcoI_results",
35
- 'sarcoII': self.base_path / "3.建模/SarcoII_results"
36
- }
37
- self.advisory_paths = {
38
- 'sarcoI': self.base_path / "4.DICE建模/SarcoI/individual_models",
39
- 'sarcoII': self.base_path / "4.DICE建模/SarcoII/individual_models"
40
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  def load_all_models(self):
43
  """加载所有模型"""
@@ -53,19 +79,40 @@ class ModelManager:
53
  def _load_screening_models(self):
54
  """加载筛查类模型"""
55
  try:
56
- # SarcoI筛查模型 - RandomForest (更新后的最佳模型)
57
- sarcoI_rf_path = self.screening_paths['sarcoI'] / "randomforest_model.pkl"
 
 
 
 
 
 
 
 
 
 
58
  with open(sarcoI_rf_path, 'rb') as f:
59
  self.screening_models['sarcoI'] = pickle.load(f)
60
-
 
61
  # SarcoII筛查模型 - CatBoost (.cbm格式)
62
- sarcoII_cat_path = self.screening_paths['sarcoII'] / "catboost_model.cbm"
63
-
 
 
 
 
 
 
 
 
 
64
  # 需要特殊处理CatBoost模型加载
65
  try:
66
  import catboost as cb
67
  self.screening_models['sarcoII'] = cb.CatBoostClassifier()
68
  self.screening_models['sarcoII'].load_model(str(sarcoII_cat_path))
 
69
  except ImportError:
70
  logger.error("CatBoost未安装,无法加载SarcoII筛查模型")
71
  raise
@@ -80,17 +127,35 @@ class ModelManager:
80
  """加载建议类模型(高精确率)"""
81
  try:
82
  # SarcoI建议模型 (CatBoost)
83
- sarcoI_cat_path = self.advisory_paths['sarcoI'] / "CatBoost_model.pkl"
 
 
 
 
 
 
 
 
84
  with open(sarcoI_cat_path, 'rb') as f:
85
  self.advisory_models['sarcoI'] = pickle.load(f)
86
-
 
87
  # SarcoII建议模型 (RandomForest)
88
- sarcoII_rf_path = self.advisory_paths['sarcoII'] / "RandomForest_model.pkl"
 
 
 
 
 
 
 
 
89
  with open(sarcoII_rf_path, 'rb') as f:
90
  self.advisory_models['sarcoII'] = pickle.load(f)
91
-
 
92
  logger.info("建议模型加载成功")
93
-
94
  except Exception as e:
95
  logger.error(f"建议模型加载失败: {str(e)}")
96
  raise
@@ -234,10 +299,10 @@ class ModelManager:
234
  """准备模型特征 - 基于实际训练数据的特征顺序"""
235
  if model_type == 'sarcoI':
236
  if mode == 'screening':
237
- # SarcoI筛查模型特征 - 必须与训练时的特征顺序完全一致
238
- # 从调试结果得知模型期望的顺序:body_mass_index, age_years, WWI
239
  features = [
240
- 'body_mass_index', 'age_years', 'WWI'
241
  ]
242
  else: # advisory
243
  # SarcoI建议模型特征 (来自/Users/ning/Desktop/idea/代码forSarcoAdvisor/4.DICE建模/预筛选/SarcoI_train_final.csv)
@@ -247,10 +312,10 @@ class ModelManager:
247
  ]
248
  else: # sarcoII
249
  if mode == 'screening':
250
- # SarcoII筛查模型特征 (来自/Users/ning/Desktop/idea/模型forSarcoAdvisor/SarcoII/SarcoII_train_selected.csv)
251
- # 必须与训练时的特征顺序完全一致:age_years, race_ethnicity, body_mass_index, WWI
252
  features = [
253
- 'age_years', 'race_ethnicity', 'body_mass_index', 'WWI'
254
  ]
255
  else: # advisory
256
  # SarcoII建议模型特征 (来自/Users/ning/Desktop/idea/代码forSarcoAdvisor/4.DICE建模/预筛选/SarcoII_train_final.csv)
@@ -334,9 +399,79 @@ class ModelManager:
334
  else:
335
  return 'low'
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  def get_overall_risk(self, sarcoI_result: Dict, sarcoII_result: Dict) -> str:
338
  """
339
- 计算综合风险等级
340
 
341
  基于两个模型的预测结果,使用更科学的综合评估方法
342
  """
 
7
  import pandas as pd
8
  import numpy as np
9
  import logging
10
+ import os
11
  from pathlib import Path
12
  from typing import Dict, Any, Optional
13
  import warnings
14
  warnings.filterwarnings('ignore')
15
 
16
+ # 安全模型加载 - 从私有HF仓库加载
17
+ try:
18
+ from huggingface_hub import hf_hub_download
19
+ HF_HUB_AVAILABLE = True
20
+ except ImportError:
21
+ HF_HUB_AVAILABLE = False
22
+ print("⚠️ huggingface_hub未安装,将使用本地模型文件")
23
+
24
  # 配置日志
25
  logging.basicConfig(level=logging.INFO)
26
  logger = logging.getLogger(__name__)
 
33
  self.advisory_models = {}
34
  self.model_configs = {}
35
  self.thresholds = {}
36
+
37
  # 模型路径配置 - 支持本地和云端部署
38
  self.app_path = Path(__file__).parent.parent
39
+
40
+ # 检查是否使用HF模型
41
+ self.use_hf_models = os.getenv("USE_HF_MODELS", "false").lower() == "true"
42
+ self.hf_model_repo = os.getenv("HF_MODEL_REPO", "Ning311/sarco-advisor-models")
43
+ self.hf_token = os.getenv("HF_TOKEN", None)
44
+
45
+ if self.use_hf_models and HF_HUB_AVAILABLE:
46
+ logger.info(f"🔒 使用HF私有仓库模型: {self.hf_model_repo}")
47
+ # HF模式下的模型路径
48
+ self.screening_paths = {
49
+ 'sarcoI': "models/screening/sarcoI",
50
+ 'sarcoII': "models/screening/sarcoII"
51
+ }
52
+ self.advisory_paths = {
53
+ 'sarcoI': "models/advisory/sarcoI",
54
+ 'sarcoII': "models/advisory/sarcoII"
55
+ }
56
+ else:
57
+ logger.info("📁 使用本地模型文件")
58
+ # 本地模式下的模型路径
59
+ self.screening_paths = {
60
+ 'sarcoI': self.app_path / "models/screening/sarcoI",
61
+ 'sarcoII': self.app_path / "models/screening/sarcoII"
62
+ }
63
+ self.advisory_paths = {
64
+ 'sarcoI': self.app_path / "models/advisory/sarcoI",
65
+ 'sarcoII': self.app_path / "models/advisory/sarcoII"
66
+ }
67
 
68
  def load_all_models(self):
69
  """加载所有模型"""
 
79
  def _load_screening_models(self):
80
  """加载筛查类模型"""
81
  try:
82
+ # SarcoI筛查模型 - RandomForest
83
+ if self.use_hf_models and HF_HUB_AVAILABLE:
84
+ # 从HF下载模型
85
+ sarcoI_rf_path = hf_hub_download(
86
+ repo_id=self.hf_model_repo,
87
+ filename=f"{self.screening_paths['sarcoI']}/randomforest_model.pkl",
88
+ token=self.hf_token
89
+ )
90
+ else:
91
+ # 使用本地模型
92
+ sarcoI_rf_path = self.screening_paths['sarcoI'] / "randomforest_model.pkl"
93
+
94
  with open(sarcoI_rf_path, 'rb') as f:
95
  self.screening_models['sarcoI'] = pickle.load(f)
96
+ logger.info("✅ SarcoI筛查模型加载成功")
97
+
98
  # SarcoII筛查模型 - CatBoost (.cbm格式)
99
+ if self.use_hf_models and HF_HUB_AVAILABLE:
100
+ # 从HF下载模型
101
+ sarcoII_cat_path = hf_hub_download(
102
+ repo_id=self.hf_model_repo,
103
+ filename=f"{self.screening_paths['sarcoII']}/catboost_model.cbm",
104
+ token=self.hf_token
105
+ )
106
+ else:
107
+ # 使用本地模型
108
+ sarcoII_cat_path = self.screening_paths['sarcoII'] / "catboost_model.cbm"
109
+
110
  # 需要特殊处理CatBoost模型加载
111
  try:
112
  import catboost as cb
113
  self.screening_models['sarcoII'] = cb.CatBoostClassifier()
114
  self.screening_models['sarcoII'].load_model(str(sarcoII_cat_path))
115
+ logger.info("✅ SarcoII筛查模型加载成功")
116
  except ImportError:
117
  logger.error("CatBoost未安装,无法加载SarcoII筛查模型")
118
  raise
 
127
  """加载建议类模型(高精确率)"""
128
  try:
129
  # SarcoI建议模型 (CatBoost)
130
+ if self.use_hf_models and HF_HUB_AVAILABLE:
131
+ sarcoI_cat_path = hf_hub_download(
132
+ repo_id=self.hf_model_repo,
133
+ filename=f"{self.advisory_paths['sarcoI']}/CatBoost_model.pkl",
134
+ token=self.hf_token
135
+ )
136
+ else:
137
+ sarcoI_cat_path = self.advisory_paths['sarcoI'] / "CatBoost_model.pkl"
138
+
139
  with open(sarcoI_cat_path, 'rb') as f:
140
  self.advisory_models['sarcoI'] = pickle.load(f)
141
+ logger.info("✅ SarcoI建议模型加载成功")
142
+
143
  # SarcoII建议模型 (RandomForest)
144
+ if self.use_hf_models and HF_HUB_AVAILABLE:
145
+ sarcoII_rf_path = hf_hub_download(
146
+ repo_id=self.hf_model_repo,
147
+ filename=f"{self.advisory_paths['sarcoII']}/RandomForest_model.pkl",
148
+ token=self.hf_token
149
+ )
150
+ else:
151
+ sarcoII_rf_path = self.advisory_paths['sarcoII'] / "RandomForest_model.pkl"
152
+
153
  with open(sarcoII_rf_path, 'rb') as f:
154
  self.advisory_models['sarcoII'] = pickle.load(f)
155
+ logger.info("✅ SarcoII建议模型加载成功")
156
+
157
  logger.info("建议模型加载成功")
158
+
159
  except Exception as e:
160
  logger.error(f"建议模型加载失败: {str(e)}")
161
  raise
 
299
  """准备模型特征 - 基于实际训练数据的特征顺序"""
300
  if model_type == 'sarcoI':
301
  if mode == 'screening':
302
+ # SarcoI筛查模型特征 - 基于实际模型期望的特征顺序
303
+ # 模型期望:['age_years', 'WWI', 'body_mass_index']
304
  features = [
305
+ 'age_years', 'WWI', 'body_mass_index'
306
  ]
307
  else: # advisory
308
  # SarcoI建议模型特征 (来自/Users/ning/Desktop/idea/代码forSarcoAdvisor/4.DICE建模/预筛选/SarcoI_train_final.csv)
 
312
  ]
313
  else: # sarcoII
314
  if mode == 'screening':
315
+ # SarcoII筛查模型特征 - 基于实际模型期望的特征顺序
316
+ # 模型期望:['age_years', 'race_ethnicity', 'WWI', 'body_mass_index']
317
  features = [
318
+ 'age_years', 'race_ethnicity', 'WWI', 'body_mass_index'
319
  ]
320
  else: # advisory
321
  # SarcoII建议模型特征 (来自/Users/ning/Desktop/idea/代码forSarcoAdvisor/4.DICE建模/预筛选/SarcoII_train_final.csv)
 
399
  else:
400
  return 'low'
401
 
402
+ def get_comprehensive_risk(self, sarcoI_screening_result: Dict, sarcoI_advisory_result: Dict = None,
403
+ sarcoII_screening_result: Dict = None, sarcoII_advisory_result: Dict = None) -> Dict:
404
+ """
405
+ 计算新的综合风险等级 - 基于建议模型优先的融合方案
406
+
407
+ Args:
408
+ sarcoI_screening_result: SarcoI筛查模型结果
409
+ sarcoI_advisory_result: SarcoI建议模型结果 (可选)
410
+ sarcoII_screening_result: SarcoII筛查模型结果 (可选)
411
+ sarcoII_advisory_result: SarcoII建议模型结果 (可选)
412
+
413
+ Returns:
414
+ Dict: 包含SarcoI和SarcoII综合风险的字典
415
+ """
416
+ results = {}
417
+
418
+ # SarcoI 综合风险判定
419
+ if sarcoI_screening_result:
420
+ P_recall_I = sarcoI_screening_result['probability']
421
+ P_precision_I = sarcoI_advisory_result['probability'] if sarcoI_advisory_result else 0.0
422
+
423
+ # 使用实际的模型阈值
424
+ sarcoI_advisory_threshold = self.thresholds['sarcoI']['advisory']
425
+ sarcoI_screening_threshold = self.thresholds['sarcoI']['screening']
426
+
427
+ if P_precision_I >= sarcoI_advisory_threshold: # 建议模型高风险阈值
428
+ sarcoI_comprehensive_risk = "high"
429
+ sarcoI_risk_reason = "advisory_model_high_risk"
430
+ elif P_recall_I >= sarcoI_screening_threshold: # 筛查模型高风险阈值
431
+ sarcoI_comprehensive_risk = "medium"
432
+ sarcoI_risk_reason = "screening_model_risk"
433
+ else:
434
+ sarcoI_comprehensive_risk = "low"
435
+ sarcoI_risk_reason = "both_models_low_risk"
436
+
437
+ results['sarcoI'] = {
438
+ 'comprehensive_risk': sarcoI_comprehensive_risk,
439
+ 'screening_probability': P_recall_I,
440
+ 'advisory_probability': P_precision_I,
441
+ 'risk_reason': sarcoI_risk_reason
442
+ }
443
+
444
+ # SarcoII 综合风险判定
445
+ if sarcoII_screening_result:
446
+ P_recall_II = sarcoII_screening_result['probability']
447
+ P_precision_II = sarcoII_advisory_result['probability'] if sarcoII_advisory_result else 0.0
448
+
449
+ # 使用实际的模型阈值
450
+ sarcoII_advisory_threshold = self.thresholds['sarcoII']['advisory']
451
+ sarcoII_screening_threshold = self.thresholds['sarcoII']['screening']
452
+
453
+ if P_precision_II >= sarcoII_advisory_threshold: # 建议模型高风险阈值
454
+ sarcoII_comprehensive_risk = "high"
455
+ sarcoII_risk_reason = "advisory_model_high_risk"
456
+ elif P_recall_II >= sarcoII_screening_threshold: # 筛查模型高风险阈值
457
+ sarcoII_comprehensive_risk = "medium"
458
+ sarcoII_risk_reason = "screening_model_risk"
459
+ else:
460
+ sarcoII_comprehensive_risk = "low"
461
+ sarcoII_risk_reason = "both_models_low_risk"
462
+
463
+ results['sarcoII'] = {
464
+ 'comprehensive_risk': sarcoII_comprehensive_risk,
465
+ 'screening_probability': P_recall_II,
466
+ 'advisory_probability': P_precision_II,
467
+ 'risk_reason': sarcoII_risk_reason
468
+ }
469
+
470
+ return results
471
+
472
  def get_overall_risk(self, sarcoI_result: Dict, sarcoII_result: Dict) -> str:
473
  """
474
+ 计算综合风险等级 (保持向后兼容)
475
 
476
  基于两个模型的预测结果,使用更科学的综合评估方法
477
  """