Ning311 commited on
Commit
47da25e
·
verified ·
1 Parent(s): 95d4581

Update utils/model_loader.py

Browse files
Files changed (1) hide show
  1. utils/model_loader.py +64 -39
utils/model_loader.py CHANGED
@@ -150,45 +150,61 @@ class ModelManager:
150
  sarcoI_cat_path = self.advisory_paths['sarcoI'] / "CatBoost_model.pkl"
151
  logger.info(f"使用本地SarcoI建议模型: {sarcoI_cat_path}")
152
 
153
- # SarcoI建议模型实际上是XGBoost模型(重命名为CatBoost_model.pkl)
154
- try:
155
- import xgboost as xgb
156
- # 尝试XGBoost加载方式
157
- try:
158
- self.advisory_models['sarcoI'] = xgb.XGBClassifier()
159
- self.advisory_models['sarcoI'].load_model(sarcoI_cat_path)
160
- logger.info("✅ SarcoI建议模型加载成功 (XGBoost格式)")
161
- except:
162
- # 如果XGBoost加载失败,尝试pickle加载
163
- with open(sarcoI_cat_path, 'rb') as f:
164
- loaded_model = pickle.load(f)
165
- logger.info(f"🔍 SarcoI建议模型类型: {type(loaded_model)}")
166
-
167
- # 检查是否是有效的机器学习模型
168
- if hasattr(loaded_model, 'predict_proba'):
169
- self.advisory_models['sarcoI'] = loaded_model
170
- logger.info("✅ SarcoI建议模型加载成功 (pickle格式)")
171
- else:
172
- # 尝试从字典中提取模型
173
- if isinstance(loaded_model, dict):
174
- for key in ['model', 'classifier', 'estimator']:
175
- if key in loaded_model and hasattr(loaded_model[key], 'predict_proba'):
176
- self.advisory_models['sarcoI'] = loaded_model[key]
177
- logger.info(f"✅ 从字典提取SarcoI建议模型成功 (键: {key})")
178
- break
179
- else:
180
- raise ValueError(f"无法从加载的对象中找到有效模型: {type(loaded_model)}")
181
- else:
182
- raise ValueError(f"加载的对象不是有效的机器学习模型: {type(loaded_model)}")
183
- except ImportError:
184
- logger.warning("XGBoost未安装,尝试pickle加载...")
185
- with open(sarcoI_cat_path, 'rb') as f:
186
- loaded_model = pickle.load(f)
187
- if hasattr(loaded_model, 'predict_proba'):
188
- self.advisory_models['sarcoI'] = loaded_model
189
- logger.info("✅ SarcoI建议模型加载成功 (pickle格式)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  else:
191
- raise ValueError(f"模型加载失败: {type(loaded_model)}")
192
 
193
  # SarcoII建议模型 (RandomForest)
194
  if self.use_hf_models and HF_HUB_AVAILABLE:
@@ -312,11 +328,20 @@ class ModelManager:
312
  def predict_advisory(self, user_data: Dict[str, float], model_type: str) -> Dict[str, Any]:
313
  """建议预测 - 高精确率"""
314
  try:
 
 
 
 
 
 
 
 
315
  # 准备特征数据
316
  features_df = self._prepare_features(user_data, model_type, mode='advisory')
317
-
318
  # 模型预测
319
  model = self.advisory_models[model_type]
 
320
  probability = model.predict_proba(features_df)[0, 1]
321
  threshold = self.thresholds[model_type]['advisory']
322
 
 
150
  sarcoI_cat_path = self.advisory_paths['sarcoI'] / "CatBoost_model.pkl"
151
  logger.info(f"使用本地SarcoI建议模型: {sarcoI_cat_path}")
152
 
153
+ # SarcoI建议模型是CatBoost模型 - 优先使用pickle加载
154
+ with open(sarcoI_cat_path, 'rb') as f:
155
+ loaded_model = pickle.load(f)
156
+ logger.info(f"🔍 SarcoI建议模型类型: {type(loaded_model)}")
157
+
158
+ # 检查是否是有效的机器学习模型
159
+ if hasattr(loaded_model, 'predict_proba'):
160
+ self.advisory_models['sarcoI'] = loaded_model
161
+ logger.info("✅ SarcoI建议模型加载成功 (pickle格式)")
162
+ else:
163
+ # 尝试从字典中提取模型
164
+ if isinstance(loaded_model, dict):
165
+ logger.info(f"🔍 字典键: {list(loaded_model.keys())}")
166
+
167
+ # 尝试常见的模型键,特别是CatBoost相关的键
168
+ model_keys = ['model', 'classifier', 'estimator', 'catboost_model', 'cb_model', 'best_model', 'trained_model', 'final_model', 'best_estimator']
169
+ found_model = False
170
+
171
+ for key in model_keys:
172
+ if key in loaded_model:
173
+ candidate_model = loaded_model[key]
174
+ logger.info(f"🔍 尝试键 '{key}': {type(candidate_model)}")
175
+
176
+ if hasattr(candidate_model, 'predict_proba'):
177
+ self.advisory_models['sarcoI'] = candidate_model
178
+ logger.info(f"✅ 从字典提取SarcoI建议模型成功 (键: {key})")
179
+ found_model = True
180
+ break
181
+
182
+ if not found_model:
183
+ # 如果没找到标准键,尝试所有值
184
+ for key, value in loaded_model.items():
185
+ logger.info(f"🔍 检查键 '{key}': {type(value)}")
186
+ if hasattr(value, 'predict_proba'):
187
+ self.advisory_models['sarcoI'] = value
188
+ logger.info(f"✅ 从字典提取SarcoI建议模型成功 (键: {key})")
189
+ found_model = True
190
+ break
191
+
192
+ if not found_model:
193
+ # 最后尝试:如果字典只有一个值,直接使���
194
+ if len(loaded_model) == 1:
195
+ single_key = list(loaded_model.keys())[0]
196
+ single_value = loaded_model[single_key]
197
+ logger.info(f"🔍 字典只有一个键 '{single_key}': {type(single_value)}")
198
+ if hasattr(single_value, 'predict_proba'):
199
+ self.advisory_models['sarcoI'] = single_value
200
+ logger.info(f"✅ 使用字典中唯一值作为SarcoI建议模型 (键: {single_key})")
201
+ found_model = True
202
+
203
+ if not found_model:
204
+ logger.error(f"❌ 字典内容详情: {[(k, type(v), hasattr(v, 'predict_proba') if hasattr(v, '__dict__') else 'N/A') for k, v in loaded_model.items()]}")
205
+ raise ValueError(f"字典中没有找到有效的机器学习模型")
206
  else:
207
+ raise ValueError(f"加载的对象不是有效的机器学习模型: {type(loaded_model)}")
208
 
209
  # SarcoII建议模型 (RandomForest)
210
  if self.use_hf_models and HF_HUB_AVAILABLE:
 
328
  def predict_advisory(self, user_data: Dict[str, float], model_type: str) -> Dict[str, Any]:
329
  """建议预测 - 高精确率"""
330
  try:
331
+ # 调试信息
332
+ logger.info(f"🔍 建议预测调试 - 模型类型: {model_type}")
333
+ logger.info(f"🔍 可用建议模型: {list(self.advisory_models.keys())}")
334
+
335
+ # 检查模型是否存在
336
+ if model_type not in self.advisory_models:
337
+ raise KeyError(f"建议模型 '{model_type}' 不存在,可用模型: {list(self.advisory_models.keys())}")
338
+
339
  # 准备特征数据
340
  features_df = self._prepare_features(user_data, model_type, mode='advisory')
341
+
342
  # 模型预测
343
  model = self.advisory_models[model_type]
344
+ logger.info(f"🔍 使用模型: {type(model)}")
345
  probability = model.predict_proba(features_df)[0, 1]
346
  threshold = self.thresholds[model_type]['advisory']
347