Spaces:
Sleeping
Sleeping
Update utils/model_loader.py
Browse files- utils/model_loader.py +64 -39
utils/model_loader.py
CHANGED
|
@@ -150,45 +150,61 @@ class ModelManager:
|
|
| 150 |
sarcoI_cat_path = self.advisory_paths['sarcoI'] / "CatBoost_model.pkl"
|
| 151 |
logger.info(f"使用本地SarcoI建议模型: {sarcoI_cat_path}")
|
| 152 |
|
| 153 |
-
# SarcoI
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
logger.info(f"🔍
|
| 166 |
-
|
| 167 |
-
#
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
else:
|
| 191 |
-
raise ValueError(f"
|
| 192 |
|
| 193 |
# SarcoII建议模型 (RandomForest)
|
| 194 |
if self.use_hf_models and HF_HUB_AVAILABLE:
|
|
@@ -312,11 +328,20 @@ class ModelManager:
|
|
| 312 |
def predict_advisory(self, user_data: Dict[str, float], model_type: str) -> Dict[str, Any]:
|
| 313 |
"""建议预测 - 高精确率"""
|
| 314 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
# 准备特征数据
|
| 316 |
features_df = self._prepare_features(user_data, model_type, mode='advisory')
|
| 317 |
-
|
| 318 |
# 模型预测
|
| 319 |
model = self.advisory_models[model_type]
|
|
|
|
| 320 |
probability = model.predict_proba(features_df)[0, 1]
|
| 321 |
threshold = self.thresholds[model_type]['advisory']
|
| 322 |
|
|
|
|
| 150 |
sarcoI_cat_path = self.advisory_paths['sarcoI'] / "CatBoost_model.pkl"
|
| 151 |
logger.info(f"使用本地SarcoI建议模型: {sarcoI_cat_path}")
|
| 152 |
|
| 153 |
+
# SarcoI建议模型是CatBoost模型 - 优先使用pickle加载
|
| 154 |
+
with open(sarcoI_cat_path, 'rb') as f:
|
| 155 |
+
loaded_model = pickle.load(f)
|
| 156 |
+
logger.info(f"🔍 SarcoI建议模型类型: {type(loaded_model)}")
|
| 157 |
+
|
| 158 |
+
# 检查是否是有效的机器学习模型
|
| 159 |
+
if hasattr(loaded_model, 'predict_proba'):
|
| 160 |
+
self.advisory_models['sarcoI'] = loaded_model
|
| 161 |
+
logger.info("✅ SarcoI建议模型加载成功 (pickle格式)")
|
| 162 |
+
else:
|
| 163 |
+
# 尝试从字典中提取模型
|
| 164 |
+
if isinstance(loaded_model, dict):
|
| 165 |
+
logger.info(f"🔍 字典键: {list(loaded_model.keys())}")
|
| 166 |
+
|
| 167 |
+
# 尝试常见的模型键,特别是CatBoost相关的键
|
| 168 |
+
model_keys = ['model', 'classifier', 'estimator', 'catboost_model', 'cb_model', 'best_model', 'trained_model', 'final_model', 'best_estimator']
|
| 169 |
+
found_model = False
|
| 170 |
+
|
| 171 |
+
for key in model_keys:
|
| 172 |
+
if key in loaded_model:
|
| 173 |
+
candidate_model = loaded_model[key]
|
| 174 |
+
logger.info(f"🔍 尝试键 '{key}': {type(candidate_model)}")
|
| 175 |
+
|
| 176 |
+
if hasattr(candidate_model, 'predict_proba'):
|
| 177 |
+
self.advisory_models['sarcoI'] = candidate_model
|
| 178 |
+
logger.info(f"✅ 从字典提取SarcoI建议模型成功 (键: {key})")
|
| 179 |
+
found_model = True
|
| 180 |
+
break
|
| 181 |
+
|
| 182 |
+
if not found_model:
|
| 183 |
+
# 如果没找到标准键,尝试所有值
|
| 184 |
+
for key, value in loaded_model.items():
|
| 185 |
+
logger.info(f"🔍 检查键 '{key}': {type(value)}")
|
| 186 |
+
if hasattr(value, 'predict_proba'):
|
| 187 |
+
self.advisory_models['sarcoI'] = value
|
| 188 |
+
logger.info(f"✅ 从字典提取SarcoI建议模型成功 (键: {key})")
|
| 189 |
+
found_model = True
|
| 190 |
+
break
|
| 191 |
+
|
| 192 |
+
if not found_model:
|
| 193 |
+
# 最后尝试:如果字典只有一个值,直接使���
|
| 194 |
+
if len(loaded_model) == 1:
|
| 195 |
+
single_key = list(loaded_model.keys())[0]
|
| 196 |
+
single_value = loaded_model[single_key]
|
| 197 |
+
logger.info(f"🔍 字典只有一个键 '{single_key}': {type(single_value)}")
|
| 198 |
+
if hasattr(single_value, 'predict_proba'):
|
| 199 |
+
self.advisory_models['sarcoI'] = single_value
|
| 200 |
+
logger.info(f"✅ 使用字典中唯一值作为SarcoI建议模型 (键: {single_key})")
|
| 201 |
+
found_model = True
|
| 202 |
+
|
| 203 |
+
if not found_model:
|
| 204 |
+
logger.error(f"❌ 字典内容详情: {[(k, type(v), hasattr(v, 'predict_proba') if hasattr(v, '__dict__') else 'N/A') for k, v in loaded_model.items()]}")
|
| 205 |
+
raise ValueError(f"字典中没有找到有效的机器学习模型")
|
| 206 |
else:
|
| 207 |
+
raise ValueError(f"加载的对象不是有效的机器学习模型: {type(loaded_model)}")
|
| 208 |
|
| 209 |
# SarcoII建议模型 (RandomForest)
|
| 210 |
if self.use_hf_models and HF_HUB_AVAILABLE:
|
|
|
|
| 328 |
def predict_advisory(self, user_data: Dict[str, float], model_type: str) -> Dict[str, Any]:
|
| 329 |
"""建议预测 - 高精确率"""
|
| 330 |
try:
|
| 331 |
+
# 调试信息
|
| 332 |
+
logger.info(f"🔍 建议预测调试 - 模型类型: {model_type}")
|
| 333 |
+
logger.info(f"🔍 可用建议模型: {list(self.advisory_models.keys())}")
|
| 334 |
+
|
| 335 |
+
# 检查模型是否存在
|
| 336 |
+
if model_type not in self.advisory_models:
|
| 337 |
+
raise KeyError(f"建议模型 '{model_type}' 不存在,可用模型: {list(self.advisory_models.keys())}")
|
| 338 |
+
|
| 339 |
# 准备特征数据
|
| 340 |
features_df = self._prepare_features(user_data, model_type, mode='advisory')
|
| 341 |
+
|
| 342 |
# 模型预测
|
| 343 |
model = self.advisory_models[model_type]
|
| 344 |
+
logger.info(f"🔍 使用模型: {type(model)}")
|
| 345 |
probability = model.predict_proba(features_df)[0, 1]
|
| 346 |
threshold = self.thresholds[model_type]['advisory']
|
| 347 |
|