Spaces:
Sleeping
Sleeping
| """ | |
| 安全模型加载器 - 从私有HF仓库加载模型 | |
| 用于公开Space但保护模型文件 | |
| """ | |
| import pickle | |
| import pandas as pd | |
| import numpy as np | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, Any, Optional | |
| import warnings | |
# NOTE(review): this silences *every* warning process-wide at import time,
# including e.g. version-mismatch warnings emitted when unpickling models.
warnings.filterwarnings('ignore')

# Secure model loading: huggingface_hub is optional. Without it the manager
# falls back to local model files.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_HUB_AVAILABLE = False
    print("⚠️ huggingface_hub未安装,将使用本地模型文件")
else:
    HF_HUB_AVAILABLE = True

# Logging configuration (INFO level on the root logger) and the module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SecureModelManager:
    """Model manager that loads models from a private Hugging Face repo.

    If ``huggingface_hub`` is importable and the ``HF_TOKEN`` environment
    variable is set, models and screening thresholds are downloaded from the
    repo named by ``HF_MODEL_REPO``. Otherwise a local ``ModelManager`` (from
    ``.model_loader``) is created and every public call is delegated to it.

    Attributes populated during construction:
        screening_models / advisory_models: model_type -> model object.
        thresholds: model_type -> {'screening': ..., 'advisory': ...}.
    """

    # Cache directory for downloaded files. The original note says files here
    # "will not be downloaded" -- presumably meaning they are not exposed to
    # visitors of the public Space (confirm).
    HF_CACHE_DIR = "/tmp/hf_models"

    # Repo-relative model pickles, keyed "<model_type>_<kind>".
    # NOTE(review): there is no 'sarcoII_screening' entry although a sarcoII
    # screening threshold file is listed below and predict_screening accepts
    # 'sarcoII'; in HF mode that call therefore raises ValueError. Add the
    # entry once the model's filename in the repo is confirmed.
    MODEL_FILES = {
        'sarcoI_screening': 'models/screening/sarcoI/randomforest_model.pkl',
        'sarcoI_advisory': 'models/advisory/sarcoI/CatBoost_model.pkl',
        'sarcoII_advisory': 'models/advisory/sarcoII/RandomForest_model.pkl',
    }

    # Repo-relative pickles holding threshold-optimisation results.
    THRESHOLD_FILES = {
        'sarcoI_screening': 'models/screening/sarcoI/optimization_results.pkl',
        'sarcoII_screening': 'models/screening/sarcoII/optimization_results.pkl',
    }

    # Key inside each optimisation-results object that holds the tuned
    # screening threshold for that model type.
    SCREENING_THRESHOLD_KEYS = {
        'sarcoI': 'rf_best_threshold',
        'sarcoII': 'catboost_best_threshold',
    }

    # Fallback thresholds used whenever no tuned value is loaded. Shared by the
    # predict_* methods and get_comprehensive_risk so that both agree.
    DEFAULT_THRESHOLDS = {
        'sarcoI': {'screening': 0.23, 'advisory': 0.36},
        'sarcoII': {'screening': 0.15, 'advisory': 0.52},
    }

    # Feature order fed to each model (presumably the training order -- the
    # models receive a bare array with no column names).
    SCREENING_FEATURES = {
        'sarcoI': ('age_years', 'WWI', 'body_mass_index'),
        'sarcoII': ('age_years', 'race_ethnicity', 'WWI', 'body_mass_index'),
    }
    ADVISORY_FEATURES = {
        'sarcoI': ('body_mass_index', 'race_ethnicity', 'WWI', 'age_years',
                   'Activity_Sedentary_Ratio', 'Total_Moderate_Minutes_week',
                   'Vigorous_MET_Ratio'),
        'sarcoII': ('body_mass_index', 'race_ethnicity', 'age_years',
                    'Activity_Sedentary_Ratio', 'Activity_Diversity_Index',
                    'WWI', 'Vigorous_MET_Ratio', 'sedentary_minutes'),
    }

    def __init__(self):
        """Read the HF repo configuration from the environment and load all models."""
        self.screening_models = {}
        self.advisory_models = {}
        self.thresholds = {}

        # Private HF repo configuration, taken from environment variables.
        self.hf_repo_id = os.getenv("HF_MODEL_REPO", "YOUR_USERNAME/sarco-advisor-models")
        self.hf_token = os.getenv("HF_TOKEN")
        # bool(): `HF_HUB_AVAILABLE and self.hf_token` would otherwise store the
        # token string itself in this flag, risking a leak if it is ever
        # logged or serialised.
        self.use_hf_models = bool(HF_HUB_AVAILABLE and self.hf_token)

        if self.use_hf_models:
            logger.info(f"🔒 使用HF私有仓库加载模型: {self.hf_repo_id}")
        else:
            logger.info("📁 回退到本地模型文件")
            # The original local-file manager serves as the fallback.
            from .model_loader import ModelManager
            self.fallback_manager = ModelManager()

        self.load_all_models()

    def load_model_from_hf(self, model_path: str):
        """Download ``model_path`` from the private repo and unpickle it.

        Args:
            model_path: Repo-relative path of the pickle file.

        Returns:
            The unpickled object, or None if the download or unpickling fails
            (the error is logged, not raised, so one missing file does not
            abort startup).
        """
        try:
            local_path = hf_hub_download(
                repo_id=self.hf_repo_id,
                filename=model_path,
                token=self.hf_token,
                cache_dir=self.HF_CACHE_DIR,
            )
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. This is presumably acceptable only because the repo is
            # private and trusted; never point HF_MODEL_REPO at an untrusted repo.
            with open(local_path, 'rb') as f:
                model = pickle.load(f)
        except Exception as e:
            logger.error(f"❌ HF模型加载失败 {model_path}: {str(e)}")
            return None
        logger.info(f"✅ 从HF仓库加载模型: {model_path}")
        return model

    def load_all_models(self):
        """Load all models from the HF repo or, if HF is not configured, locally."""
        if self.use_hf_models:
            self._load_models_from_hf()
        else:
            self._load_models_locally()

    def _load_models_from_hf(self):
        """Download every model and screening threshold listed in the class tables."""
        logger.info("🔒 从HF私有仓库加载模型...")

        for model_name, model_path in self.MODEL_FILES.items():
            model = self.load_model_from_hf(model_path)
            # `is None` rather than truthiness: model objects may define
            # __len__/__bool__, so truthiness is not a reliable "loaded" check.
            if model is None:
                continue
            model_type, _, kind = model_name.rpartition('_')
            if kind == 'screening':
                self.screening_models[model_type] = model
            elif kind == 'advisory':
                self.advisory_models[model_type] = model

        for threshold_name, threshold_path in self.THRESHOLD_FILES.items():
            model_type = threshold_name.replace('_screening', '')
            threshold_data = self.load_model_from_hf(threshold_path)
            screening = self._parse_screening_threshold(model_type, threshold_data)
            if screening is None:
                # Lookups fall back to DEFAULT_THRESHOLDS via _get_threshold.
                continue
            advisory = self.DEFAULT_THRESHOLDS.get(model_type, {}).get('advisory', 0.5)
            self.thresholds[model_type] = {'screening': screening, 'advisory': advisory}

        logger.info("✅ HF模型加载完成")

    @classmethod
    def _parse_screening_threshold(cls, model_type: str, threshold_data: Any) -> Optional[float]:
        """Extract the tuned screening threshold for ``model_type``.

        Returns None when ``threshold_data`` is missing, cannot be indexed by
        the expected key, lacks that key, or stores None under it.
        """
        key = cls.SCREENING_THRESHOLD_KEYS.get(model_type)
        if key is None or threshold_data is None:
            return None
        try:
            return threshold_data[key]
        except (KeyError, IndexError, TypeError):
            return None

    def _load_models_locally(self):
        """Load models through the local fallback manager and share its tables."""
        logger.info("📁 使用本地模型文件...")
        if hasattr(self, 'fallback_manager'):
            # NOTE(review): ModelManager() was just constructed in __init__; if
            # its constructor already loads models, they are loaded twice here.
            # Confirm against model_loader.
            self.fallback_manager.load_all_models()
            # Share (not copy) the fallback's model and threshold dicts.
            self.screening_models = self.fallback_manager.screening_models
            self.advisory_models = self.fallback_manager.advisory_models
            self.thresholds = self.fallback_manager.thresholds
        logger.info("✅ 本地模型加载完成")

    def _get_threshold(self, model_type: str, kind: str):
        """Return the loaded ``kind`` threshold for ``model_type``.

        Falls back to DEFAULT_THRESHOLDS, then to 0.5 for unknown model types.
        """
        loaded = self.thresholds.get(model_type, {})
        if kind in loaded:
            return loaded[kind]
        return self.DEFAULT_THRESHOLDS.get(model_type, {}).get(kind, 0.5)

    @staticmethod
    def _build_prediction(model, X, threshold, model_type: str) -> Dict:
        """Score the single row in ``X`` and package the result dict."""
        # Assumes a binary classifier whose predict_proba column 1 is the
        # positive class. float() turns numpy scalars (np.float32 is not
        # JSON-serialisable) into plain Python floats.
        probability = float(model.predict_proba(X)[0][1])
        return {
            'probability': probability,
            'risk_level': 'high' if probability >= threshold else 'low',
            'threshold': threshold,
            'model_type': model_type,
        }

    def predict_screening(self, user_data: Dict, model_type: str) -> Dict:
        """Run the screening model for ``model_type`` on one user record.

        Args:
            user_data: Feature name -> value; must contain every screening
                feature for ``model_type``.
            model_type: 'sarcoI' or 'sarcoII'.

        Returns:
            Dict with 'probability', 'risk_level' ('high' if probability >=
            threshold else 'low'), 'threshold' and 'model_type'.

        Raises:
            ValueError: no screening model is loaded for ``model_type``.
            KeyError: ``user_data`` lacks a required feature.
        """
        if hasattr(self, 'fallback_manager') and not self.use_hf_models:
            return self.fallback_manager.predict_screening(user_data, model_type)

        if model_type not in self.screening_models:
            raise ValueError(f"筛查模型 {model_type} 未找到")
        model = self.screening_models[model_type]
        threshold = self._get_threshold(model_type, 'screening')

        feature_names = self.SCREENING_FEATURES.get(model_type, self.SCREENING_FEATURES['sarcoII'])
        X = np.array([[user_data[name] for name in feature_names]])
        return self._build_prediction(model, X, threshold, model_type)

    def predict_advisory(self, user_data: Dict, model_type: str) -> Dict:
        """Run the advisory model for ``model_type`` on one user record.

        Features absent from ``user_data`` are filled with 0.0 and a warning
        is logged.

        Returns:
            Same dict shape as predict_screening.

        Raises:
            ValueError: no advisory model is loaded for ``model_type``.
        """
        if hasattr(self, 'fallback_manager') and not self.use_hf_models:
            return self.fallback_manager.predict_advisory(user_data, model_type)

        if model_type not in self.advisory_models:
            raise ValueError(f"建议模型 {model_type} 未找到")
        model = self.advisory_models[model_type]
        threshold = self._get_threshold(model_type, 'advisory')

        feature_names = self.ADVISORY_FEATURES.get(model_type, self.ADVISORY_FEATURES['sarcoII'])
        missing = [name for name in feature_names if name not in user_data]
        if missing:
            logger.warning(f"⚠️ 建议模型 {model_type} 缺少特征 {missing},使用默认值 0.0")
        X = np.array([[user_data.get(name, 0.0) for name in feature_names]])
        return self._build_prediction(model, X, threshold, model_type)

    def _combine_risk(self, model_type: str, screening_result: Dict,
                      advisory_result: Optional[Dict]) -> Dict:
        """Merge one model family's screening and advisory outputs into a risk tier.

        The advisory (precision) model takes priority: at/above its threshold
        the result is 'high'. Otherwise, the screening (recall) model at/above
        its threshold gives 'medium'. Anything else is 'low'.
        """
        p_recall = screening_result['probability']
        p_precision = advisory_result['probability'] if advisory_result else 0.0

        if p_precision >= self._get_threshold(model_type, 'advisory'):
            risk, reason = "high", "advisory_model_high_risk"
        elif p_recall >= self._get_threshold(model_type, 'screening'):
            risk, reason = "medium", "screening_model_risk"
        else:
            risk, reason = "low", "both_models_low_risk"

        return {
            'comprehensive_risk': risk,
            'screening_probability': p_recall,
            'advisory_probability': p_precision,
            'risk_reason': reason,
        }

    def get_comprehensive_risk(self, sarcoI_screening_result: Dict,
                               sarcoI_advisory_result: Optional[Dict] = None,
                               sarcoII_screening_result: Optional[Dict] = None,
                               sarcoII_advisory_result: Optional[Dict] = None) -> Dict:
        """Combine screening and advisory results per model family.

        A family is included only if its screening result is truthy; a
        missing advisory result counts as probability 0.0.

        Returns:
            Dict keyed by 'sarcoI' / 'sarcoII', each with 'comprehensive_risk',
            'screening_probability', 'advisory_probability', 'risk_reason'.
        """
        if hasattr(self, 'fallback_manager') and not self.use_hf_models:
            return self.fallback_manager.get_comprehensive_risk(
                sarcoI_screening_result, sarcoI_advisory_result,
                sarcoII_screening_result, sarcoII_advisory_result
            )

        results = {}
        if sarcoI_screening_result:
            results['sarcoI'] = self._combine_risk(
                'sarcoI', sarcoI_screening_result, sarcoI_advisory_result)
        if sarcoII_screening_result:
            results['sarcoII'] = self._combine_risk(
                'sarcoII', sarcoII_screening_result, sarcoII_advisory_result)
        return results