Spaces:
Sleeping
Sleeping
| """ | |
| SarcoAdvisor FastAPI主应用 | |
| 肌少症风险评估和个性化建议系统 | |
| """ | |
| import logging | |
| import time | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, HTTPException, Request, BackgroundTasks | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| # 导入自定义模块 | |
| from schemas.user_input import ( | |
| UserInput, ScreeningRequest, AdvisoryRequest, | |
| ScreeningResponse, AdvisoryResponse, ErrorResponse | |
| ) | |
| from models.screening_models import screening_service | |
| from models.advisory_models import advisory_service | |
| from utils.model_loader import model_manager | |
| # 配置日志 | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | |
| ) | |
| logger = logging.getLogger(__name__) | |
| # 应用生命周期管理 | |
| async def lifespan(app: FastAPI): | |
| # 启动时加载模型 | |
| logger.info("🚀 启动SarcoAdvisor Web应用...") | |
| try: | |
| logger.info("📊 加载机器学习模型...") | |
| model_manager.load_all_models() | |
| logger.info("✅ 模型加载完成") | |
| logger.info("🚀 初始化DiCE解释器 (终极精度模式 - 追求最高质量,无任何时间限制)...") | |
| advisory_service.initialize_dice() | |
| logger.info("✅ DiCE解释器初始化完成 - 已启用终极精度模式 (500个候选,25倍多样性权重)") | |
| logger.info("🎯 SarcoAdvisor服务就绪!") | |
| except Exception as e: | |
| logger.error(f"❌ 应用启动失败: {str(e)}") | |
| raise | |
| yield | |
| # 关闭时清理资源 | |
| logger.info("🔄 关闭SarcoAdvisor服务...") | |
| # 创建FastAPI应用 | |
| app = FastAPI( | |
| title="SarcoAdvisor API", | |
| description="肌少症风险评估和个性化建议系统", | |
| version="1.0.0", | |
| docs_url="/docs", | |
| redoc_url="/redoc", | |
| lifespan=lifespan | |
| ) | |
| # 确保模型在应用启动时加载(备用方案) | |
| async def startup_event(): | |
| """应用启动事件 - 确保模型加载""" | |
| try: | |
| # 检查模型是否已加载 | |
| if not model_manager.advisory_models: | |
| logger.warning("⚠️ 检测到模型未加载,强制加载...") | |
| model_manager.load_all_models() | |
| logger.info("✅ 备用模型加载完成") | |
| # 检查DiCE是否已初始化 | |
| if not hasattr(advisory_service, 'dice_explainers') or not advisory_service.dice_explainers: | |
| logger.warning("⚠️ 检测到DiCE未初始化,强制初始化...") | |
| advisory_service.initialize_dice() | |
| logger.info("✅ 备用DiCE初始化完成") | |
| logger.info(f"🎯 模型状态检查: 筛查模型={list(model_manager.screening_models.keys())}, 建议模型={list(model_manager.advisory_models.keys())}") | |
| except Exception as e: | |
| logger.error(f"❌ 备用启动过程失败: {str(e)}") | |
| # 不抛出异常,让应用继续运行 | |
| # 添加CORS中间件 | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], # 生产环境中应该限制具体域名 | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # 挂载静态文件 | |
| app.mount("/static", StaticFiles(directory="static"), name="static") | |
| # 模板配置 | |
| templates = Jinja2Templates(directory="templates") | |
| # 全局异常处理 | |
| async def global_exception_handler(request: Request, exc: Exception): | |
| logger.error(f"全局异常: {str(exc)}") | |
| return JSONResponse( | |
| status_code=500, | |
| content=ErrorResponse( | |
| error="Internal Server Error", | |
| detail=str(exc), | |
| timestamp=str(time.time()) | |
| ).model_dump() | |
| ) | |
| # 根路径 - 返回快速评估页面 | |
| async def home(request: Request): | |
| """快速评估主页面""" | |
| return templates.TemplateResponse("quick_assessment.html", {"request": request}) | |
| # 原完整评估页面 | |
| async def full_assessment_page(request: Request): | |
| """完整评估页面""" | |
| return templates.TemplateResponse("index.html", {"request": request}) | |
| # 统一评估页面 (新的双模型评估) | |
| async def unified_assessment_page(request: Request): | |
| """统一评估页面 - 同时生成SarcoI和SarcoII建议""" | |
| return templates.TemplateResponse("unified_assessment.html", {"request": request}) | |
| # 字体测试页面 | |
| async def font_test_page(request: Request): | |
| """字体样式测试页面""" | |
| return templates.TemplateResponse("font_test.html", {"request": request}) | |
| # 健康检查接口 | |
| async def health_check(): | |
| """健康检查""" | |
| try: | |
| # 检查模型是否已加载 | |
| screening_ready = bool(model_manager.screening_models) | |
| advisory_ready = bool(model_manager.advisory_models) | |
| return { | |
| "status": "healthy" if (screening_ready and advisory_ready) else "partial", | |
| "timestamp": time.time(), | |
| "services": { | |
| "screening_models": screening_ready, | |
| "advisory_models": advisory_ready, | |
| "dice_explainers": bool(advisory_service.dice_explainers) | |
| }, | |
| "version": "1.0.0" | |
| } | |
| except Exception as e: | |
| logger.error(f"健康检查失败: {str(e)}") | |
| return JSONResponse( | |
| status_code=503, | |
| content={"status": "unhealthy", "error": str(e)} | |
| ) | |
| # 诊断接口 - 用于调试性能问题 | |
| async def diagnostic_check(request: dict): | |
| """诊断检查 - 帮助识别性能瓶颈""" | |
| start_time = time.time() | |
| try: | |
| logger.info("开始诊断检查...") | |
| # 1. 检查数据验证耗时 | |
| validation_start = time.time() | |
| try: | |
| user_input = UserInput(**request) | |
| validation_time = time.time() - validation_start | |
| except Exception as e: | |
| return {"error": f"数据验证失败: {str(e)}", "step": "validation"} | |
| # 2. 检查筛查模型耗时 | |
| screening_start = time.time() | |
| try: | |
| screening_result = await screening_service.screening_assessment(user_input, ['sarcoI', 'sarcoII']) | |
| screening_time = time.time() - screening_start | |
| except Exception as e: | |
| return {"error": f"筛查模型失败: {str(e)}", "step": "screening"} | |
| # 3. 检查建议模型基础预测耗时(不包括DiCE) | |
| advisory_start = time.time() | |
| try: | |
| # 这里我们只做基础预测,不包括DiCE | |
| user_dict = user_input.model_dump() | |
| sarcoI_result = model_manager.predict_advisory(user_dict, 'sarcoI') | |
| sarcoII_result = model_manager.predict_advisory(user_dict, 'sarcoII') | |
| advisory_time = time.time() - advisory_start | |
| except Exception as e: | |
| return {"error": f"建议模型预测失败: {str(e)}", "step": "advisory_prediction"} | |
| total_time = time.time() - start_time | |
| return { | |
| "status": "success", | |
| "timings": { | |
| "validation": f"{validation_time:.3f}s", | |
| "screening": f"{screening_time:.3f}s", | |
| "advisory_prediction": f"{advisory_time:.3f}s", | |
| "total": f"{total_time:.3f}s" | |
| }, | |
| "note": "DiCE analysis not included in this diagnostic" | |
| } | |
| except Exception as e: | |
| total_time = time.time() - start_time | |
| logger.error(f"诊断检查失败: {str(e)}") | |
| return { | |
| "error": str(e), | |
| "total_time": f"{total_time:.3f}s" | |
| } | |
| # 筛查接口 | |
| async def screening_assessment(request: ScreeningRequest): | |
| """ | |
| 风险筛查评估 | |
| 使用高召回率模型进行快速风险筛查 | |
| """ | |
| try: | |
| logger.info(f"收到筛查请求: 模型={request.models}") | |
| # 验证用户数据 | |
| user_data = request.user_data | |
| # 执行筛查 | |
| result = await screening_service.screening_assessment( | |
| user_data=user_data, | |
| models=request.models | |
| ) | |
| logger.info(f"筛查完成: 综合风险={result.overall_risk}, 耗时={result.processing_time:.2f}s") | |
| return result | |
| except Exception as e: | |
| logger.error(f"筛查评估失败: {str(e)}") | |
| raise HTTPException( | |
| status_code=500, | |
| detail=f"筛查评估失败: {str(e)}" | |
| ) | |
| # 建议生成接口 | |
| async def advisory_recommendations(request: AdvisoryRequest): | |
| """ | |
| 🚀 终极精度个性化建议生成 | |
| 使用完整数据集 + 高精确率模型 + DiCE反事实解释生成最精准的个性化建议 | |
| 配置:500个反事实候选,25倍多样性权重,无任何时间限制 | |
| 注意:追求终极精度,响应时间可能很长,请耐心等待最佳结果 | |
| """ | |
| try: | |
| logger.info(f"收到建议请求: 风险类型={request.risk_types}, 建议数量={request.num_recommendations}") | |
| # 验证用户数据 | |
| user_data = request.user_data | |
| # 生成建议 | |
| result = await advisory_service.generate_recommendations( | |
| user_data=user_data, | |
| risk_types=request.risk_types, | |
| num_recommendations=request.num_recommendations, | |
| language=request.language | |
| ) | |
| logger.info(f"建议生成完成: SarcoI={len(result.sarcoI_recommendations)}, " | |
| f"SarcoII={len(result.sarcoII_recommendations)}, " | |
| f"fallback={result.fallback_used}, 耗时={result.processing_time:.2f}s") | |
| return result | |
| except Exception as e: | |
| logger.error(f"建议生成失败: {str(e)}") | |
| raise HTTPException( | |
| status_code=500, | |
| detail=f"建议生成失败: {str(e)}" | |
| ) | |
| # 完整评估接口 (筛查 + 建议) | |
| async def full_assessment(request: dict, background_tasks: BackgroundTasks): | |
| """ | |
| 完整评估流程 | |
| 先进行筛查,对高风险用户生成个性化建议 | |
| """ | |
| try: | |
| start_time = time.time() | |
| logger.info("开始完整评估流程") | |
| # 解析请求参数 | |
| user_data = UserInput(**request) | |
| language = request.get('language', 'zh') # 获取语言参数,默认中文 | |
| logger.info(f"完整评估语言设置: {language}") | |
| # 第一步: 筛查评估 | |
| screening_request = ScreeningRequest( | |
| user_data=user_data, | |
| models=["sarcoI", "sarcoII"] | |
| ) | |
| screening_result = await screening_assessment(screening_request) | |
| # 第二步: 建议模型结果已包含在筛查结果中 | |
| advisory_result = None | |
| # 生成个性化建议(DiCE)- 改进逻辑 | |
| # 1. 优先检查建议模型的风险评估结果 | |
| # 2. 如果建议模型有结果,则生成建议(包括低风险的维持性建议) | |
| # 3. 否则基于筛查模型结果决定 | |
| # 智能建议生成:SarcoII是SarcoI的严重版,SarcoII高危则SarcoI必然高危 | |
| risk_types = [] | |
| should_generate_advice = False | |
| primary_risk_type = None | |
| # 检查是否有完整的活动数据用于建议生成 | |
| activity_fields = ['PAQ605', 'PAQ620', 'PAQ635', 'PAQ650', 'PAQ665', 'PAD680'] | |
| has_complete_data = all(field in user_data for field in activity_fields) | |
| # 1. 首先检查SarcoII(更严重的肌少症) | |
| logger.info(f"🔍 调试风险判断 - SarcoII筛查风险: {screening_result.sarcoII_risk}") | |
| logger.info(f"🔍 调试风险判断 - SarcoII建议风险: {getattr(screening_result, 'sarcoII_advisory_risk', 'None')}") | |
| logger.info(f"🔍 调试风险判断 - SarcoI筛查风险: {getattr(screening_result, 'sarcoI_risk', 'None')}") | |
| logger.info(f"🔍 调试风险判断 - SarcoI建议风险: {getattr(screening_result, 'sarcoI_advisory_risk', 'None')}") | |
| sarcoII_has_risk = False | |
| if hasattr(screening_result, 'sarcoII_advisory_risk') and screening_result.sarcoII_advisory_risk: | |
| if screening_result.sarcoII_advisory_risk in ["medium", "high"]: | |
| sarcoII_has_risk = True | |
| logger.info(f"SarcoII建议模型显示风险: {screening_result.sarcoII_advisory_risk}") | |
| elif screening_result.sarcoII_risk in ["medium", "high"]: | |
| sarcoII_has_risk = True | |
| logger.info(f"SarcoII筛查模型显示风险: {screening_result.sarcoII_risk}") | |
| # 2. 检查SarcoI | |
| sarcoI_has_risk = False | |
| if hasattr(screening_result, 'sarcoI_advisory_risk') and screening_result.sarcoI_advisory_risk: | |
| if screening_result.sarcoI_advisory_risk in ["medium", "high"]: | |
| sarcoI_has_risk = True | |
| logger.info(f"SarcoI建议模型显示风险: {screening_result.sarcoI_advisory_risk}") | |
| elif hasattr(screening_result, 'sarcoI_risk') and screening_result.sarcoI_risk and screening_result.sarcoI_risk in ["medium", "high"]: | |
| sarcoI_has_risk = True | |
| logger.info(f"SarcoI筛查模型显示风险: {screening_result.sarcoI_risk}") | |
| logger.info(f"🔍 风险判断结果 - SarcoII有风险: {sarcoII_has_risk}, SarcoI有风险: {sarcoI_has_risk}") | |
| # 3. 根据风险情况决定建议生成策略 | |
| # 修正逻辑:只要筛查模型显示风险,就应该生成建议 | |
| if sarcoII_has_risk: | |
| # SarcoII高危:生成SarcoII建议(包含SarcoI改善) | |
| risk_types = ["sarcoII"] | |
| should_generate_advice = True | |
| primary_risk_type = "sarcoII" | |
| logger.info("🎯 检测到SarcoII风险,生成SarcoII建议(SarcoII高危意味着SarcoI也高危)") | |
| elif sarcoI_has_risk: | |
| # 只有SarcoI高危:生成SarcoI建议 | |
| risk_types = ["sarcoI"] | |
| should_generate_advice = True | |
| primary_risk_type = "sarcoI" | |
| logger.info("🎯 检测到SarcoI风险,生成SarcoI建议") | |
| elif screening_result.sarcoII_risk in ["medium", "high"]: | |
| # 修正:即使建议模型显示低风险,但筛查模型显示风险时,仍应生成建议 | |
| risk_types = ["sarcoII"] | |
| should_generate_advice = True | |
| primary_risk_type = "sarcoII" | |
| logger.info(f"🎯 SarcoII筛查模型显示{screening_result.sarcoII_risk}风险,生成建议") | |
| elif hasattr(screening_result, 'sarcoI_risk') and screening_result.sarcoI_risk in ["medium", "high"]: | |
| # 修正:即使建议模型显示低风险,但筛查模型显示风险时,仍应生成建议 | |
| risk_types = ["sarcoI"] | |
| should_generate_advice = True | |
| primary_risk_type = "sarcoI" | |
| logger.info(f"🎯 SarcoI筛查模型显示{screening_result.sarcoI_risk}风险,生成建议") | |
| elif has_complete_data: | |
| # 都是低风险但数据完整:生成维持性建议 | |
| # 选择风险概率更高的模型生成维持性建议 | |
| if screening_result.sarcoII_probability > getattr(screening_result, 'sarcoI_probability', 0): | |
| risk_types = ["sarcoII"] | |
| primary_risk_type = "sarcoII" | |
| logger.info(f"🔄 低风险但数据完整,为SarcoII生成维持性建议(概率: {screening_result.sarcoII_probability:.3f})") | |
| else: | |
| risk_types = ["sarcoI"] | |
| primary_risk_type = "sarcoI" | |
| sarcoI_prob = getattr(screening_result, 'sarcoI_probability', 0) | |
| logger.info(f"🔄 低风险但数据完整,为SarcoI生成维持性建议(概率: {sarcoI_prob:.3f})") | |
| should_generate_advice = True | |
| else: | |
| logger.info("⚠️ 无风险且数据不完整,不生成建议") | |
| if should_generate_advice and risk_types: | |
| logger.info(f"为以下模型生成建议: {risk_types}") | |
| advisory_request = AdvisoryRequest( | |
| user_data=user_data, | |
| risk_types=risk_types, | |
| num_recommendations=5, # 🚀 增加到5个建议,获得更多样化的建议 | |
| language=language # 🌍 传递语言参数 | |
| ) | |
| advisory_result = await advisory_recommendations(advisory_request) | |
| else: | |
| logger.info("不满足建议生成条件,跳过DiCE建议") | |
| # 计算新的综合风险评估 | |
| comprehensive_risk = None | |
| try: | |
| # 准备筛查和建议模型结果 | |
| sarcoI_screening = { | |
| 'probability': screening_result.sarcoI_probability, | |
| 'risk_level': screening_result.sarcoI_risk.value | |
| } | |
| sarcoI_advisory = None | |
| if screening_result.sarcoI_advisory_probability is not None: | |
| sarcoI_advisory = { | |
| 'probability': screening_result.sarcoI_advisory_probability, | |
| 'risk_level': screening_result.sarcoI_advisory_risk.value | |
| } | |
| sarcoII_screening = { | |
| 'probability': screening_result.sarcoII_probability, | |
| 'risk_level': screening_result.sarcoII_risk.value | |
| } | |
| sarcoII_advisory = None | |
| if screening_result.sarcoII_advisory_probability is not None: | |
| sarcoII_advisory = { | |
| 'probability': screening_result.sarcoII_advisory_probability, | |
| 'risk_level': screening_result.sarcoII_advisory_risk.value | |
| } | |
| # 计算综合风险 | |
| comprehensive_risk = model_manager.get_comprehensive_risk( | |
| sarcoI_screening_result=sarcoI_screening, | |
| sarcoI_advisory_result=sarcoI_advisory, | |
| sarcoII_screening_result=sarcoII_screening, | |
| sarcoII_advisory_result=sarcoII_advisory | |
| ) | |
| logger.info(f"综合风险评估完成: {comprehensive_risk}") | |
| except Exception as e: | |
| logger.error(f"综合风险计算失败: {str(e)}") | |
| # 获取风险解释 | |
| risk_explanation = screening_service.get_risk_explanation(screening_result) | |
| total_time = time.time() - start_time | |
| response = { | |
| "screening": screening_result.model_dump(), | |
| "advisory": advisory_result.model_dump() if advisory_result else None, | |
| "comprehensive_risk": comprehensive_risk, # 新增综合风险评估 | |
| "risk_explanation": risk_explanation, | |
| "needs_advisory": bool(advisory_result), | |
| "total_processing_time": total_time, | |
| "timestamp": time.time() | |
| } | |
| logger.info(f"完整评估完成: 风险={screening_result.overall_risk}, " | |
| f"生成建议={'是' if advisory_result else '否'}, 总耗时={total_time:.2f}s") | |
| return response | |
| except Exception as e: | |
| logger.error(f"完整评估失败: {str(e)}") | |
| raise HTTPException( | |
| status_code=500, | |
| detail=f"完整评估失败: {str(e)}" | |
| ) | |
| # 模型信息接口 | |
| async def get_model_info(): | |
| """获取模型信息""" | |
| try: | |
| return { | |
| "screening_models": { | |
| "sarcoI": { | |
| "type": "RandomForest", | |
| "purpose": "高召回率筛查", | |
| "threshold": model_manager.thresholds.get('sarcoI', {}).get('screening', 0.5), | |
| "performance": { | |
| "recall": 0.9114, | |
| "precision": 0.4305, | |
| "model_path": "/Users/ning/Desktop/idea/代码forSarcoAdvisor/3.建模/SarcoI_results/randomforest_model.pkl", | |
| "features": 3 | |
| } | |
| }, | |
| "sarcoII": { | |
| "type": "CatBoost", | |
| "purpose": "高召回率筛查", | |
| "threshold": model_manager.thresholds.get('sarcoII', {}).get('screening', 0.5), | |
| "performance": { | |
| "precision": 0.2548, | |
| "recall": 0.8983, | |
| "model_path": "/Users/ning/Desktop/idea/代码forSarcoAdvisor/3.建模/SarcoII_results/catboost_model.cbm", | |
| "features": 4 | |
| } | |
| } | |
| }, | |
| "advisory_models": { | |
| "sarcoI": { | |
| "type": "CatBoost", | |
| "purpose": "高精确率建议 + DiCE", | |
| "threshold": model_manager.thresholds.get('sarcoI', {}).get('advisory', 0.36), | |
| "dice_features": 5 | |
| }, | |
| "sarcoII": { | |
| "type": "RandomForest", | |
| "purpose": "高精确率建议 + DiCE", | |
| "threshold": model_manager.thresholds.get('sarcoII', {}).get('advisory', 0.52), | |
| "dice_features": 6 | |
| } | |
| }, | |
| "features": { | |
| "sarcoI": ["body_mass_index", "race_ethnicity", "WWI", "age_years", | |
| "Activity_Sedentary_Ratio", "Total_Moderate_Minutes_week", "Vigorous_MET_Ratio"], | |
| "sarcoII": ["body_mass_index", "race_ethnicity", "age_years", | |
| "Activity_Sedentary_Ratio", "Activity_Diversity_Index", "WWI", | |
| "Vigorous_MET_Ratio", "sedentary_minutes"] | |
| } | |
| } | |
| except Exception as e: | |
| logger.error(f"获取模型信息失败: {str(e)}") | |
| raise HTTPException(status_code=500, detail="获取模型信息失败") | |
| # 快速评估接口 (基于4个共同特征) | |
| async def quick_assessment(user_data: UserInput): | |
| """ | |
| 快速评估接口 | |
| 基于4个模型的共同特征 (age_years, race_ethnicity, body_mass_index, WWI) | |
| 提供所有4个模型的初步评估结果 | |
| """ | |
| try: | |
| start_time = time.time() | |
| logger.info("开始快速评估流程") | |
| # 纯筛查模式:只运行筛查模型,不运行建议模型 | |
| screening_request = ScreeningRequest( | |
| user_data=user_data, | |
| models=["sarcoI", "sarcoII"] | |
| ) | |
| # 调用筛查服务,明确指定不包含建议模型 | |
| screening_result = await screening_service.screening_assessment( | |
| user_data=user_data, | |
| models=["sarcoI", "sarcoII"], | |
| include_advisory=False # 关键:快速评估不运行建议模型 | |
| ) | |
| # 获取风险解释 | |
| risk_explanation = screening_service.get_risk_explanation(screening_result) | |
| # 构建快速评估结果 | |
| result = { | |
| "screening": screening_result.model_dump(), | |
| "advisory": None, # 快速评估不提供建议模型结果 | |
| "risk_explanation": risk_explanation, | |
| "needs_advisory": False, # 快速评估阶段不需要DiCE建议 | |
| "assessment_type": "quick", | |
| "common_features_used": ["age_years", "race_ethnicity", "body_mass_index", "WWI"], | |
| "total_processing_time": 0, # 将在下面计算 | |
| "timestamp": time.time() | |
| } | |
| total_time = time.time() - start_time | |
| result["total_processing_time"] = total_time | |
| logger.info(f"快速评估完成,总耗时={total_time:.2f}s") | |
| return result | |
| except Exception as e: | |
| logger.error(f"快速评估失败: {str(e)}") | |
| raise HTTPException( | |
| status_code=500, | |
| detail=f"快速评估失败: {str(e)}" | |
| ) | |
| # 应用启动函数 | |
| def start_server(): | |
| """启动服务器""" | |
| import os | |
| port = int(os.environ.get("PORT", 8001)) # 默认使用8001端口 | |
| uvicorn.run( | |
| "main:app", | |
| host="0.0.0.0", | |
| port=port, | |
| reload=False, # 云端部署时关闭reload | |
| log_level="info" | |
| ) | |
| if __name__ == "__main__": | |
| start_server() |