Spaces:
Sleeping
Sleeping
| import os | |
| import networkx as nx | |
| from fastapi import FastAPI, Request | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.templating import Jinja2Templates | |
| from pydantic import BaseModel | |
| # --- 关键修改 1: 导入 DiscreteBayesianNetwork 而不是 BayesianNetwork --- | |
| from pgmpy.models import DiscreteBayesianNetwork | |
| from pgmpy.factors.discrete import TabularCPD | |
| from pgmpy.inference import VariableElimination | |
| import google.generativeai as genai | |
| # --- Gemini API 配置 --- | |
| try: | |
| gemini_api_key = os.environ.get("GEMINI_API_KEY") | |
| if not gemini_api_key: | |
| raise ValueError("GEMINI_API_KEY secret not found in Hugging Face Space settings.") | |
| genai.configure(api_key=gemini_api_key) | |
| GEMINI_AVAILABLE = True | |
| except (ValueError, ImportError) as e: | |
| print(e) | |
| GEMINI_AVAILABLE = False | |
| # --- FastAPI 应用设置 --- | |
| app = FastAPI() | |
| templates = Jinja2Templates(directory="templates") | |
| # --- 因果推理模型定义 --- | |
| # --- 关键修改 2: 使用 DiscreteBayesianNetwork 类进行实例化 --- | |
| model = DiscreteBayesianNetwork([ | |
| # 症状 -> 证候 | |
| ('头痛', '风热'), ('发热', '湿热'), ('乏力', '气虚'), ('恶心', '胃气虚'), | |
| ('咳嗽', '风寒'), ('腹痛', '脾虚'), ('胸闷', '心脉不通'), ('气短', '肺气虚'), | |
| ('便秘', '肝气郁结'), ('食欲不振', '脾胃虚弱'), | |
| # 证候之间的关联 | |
| ('肺气虚', '气虚'), ('脾虚', '气虚'), | |
| ('脾胃虚弱', '脾虚') | |
| ]) | |
| # 定义条件概率分布 (CPD) | |
| # 症状 (根节点) | |
| cpd_symptoms = { | |
| '头痛': TabularCPD('头痛', 2, [[0.7], [0.3]]), '发热': TabularCPD('发热', 2, [[0.7], [0.3]]), | |
| '乏力': TabularCPD('乏力', 2, [[0.5], [0.5]]), '恶心': TabularCPD('恶心', 2, [[0.8], [0.2]]), | |
| '咳嗽': TabularCPD('咳嗽', 2, [[0.7], [0.3]]), '腹痛': TabularCPD('腹痛', 2, [[0.8], [0.2]]), | |
| '胸闷': TabularCPD('胸闷', 2, [[0.7], [0.3]]), '气短': TabularCPD('气短', 2, [[0.6], [0.4]]), | |
| '便秘': TabularCPD('便秘', 2, [[0.7], [0.3]]), '食欲不振': TabularCPD('食欲不振', 2, [[0.6], [0.4]]) | |
| } | |
| # 证候 (中间节点) | |
| cpd_syndromes = { | |
| '风热': TabularCPD('风热', 2, [[0.9, 0.2], [0.1, 0.8]], evidence=['头痛'], evidence_card=[2]), | |
| '湿热': TabularCPD('湿热', 2, [[0.9, 0.1], [0.1, 0.9]], evidence=['发热'], evidence_card=[2]), | |
| '胃气虚': TabularCPD('胃气虚', 2, [[0.8, 0.3], [0.2, 0.7]], evidence=['恶心'], evidence_card=[2]), | |
| '风寒': TabularCPD('风寒', 2, [[0.9, 0.2], [0.1, 0.8]], evidence=['咳嗽'], evidence_card=[2]), | |
| '心脉不通': TabularCPD('心脉不通', 2, [[0.9, 0.1], [0.1, 0.9]], evidence=['胸闷'], evidence_card=[2]), | |
| '肺气虚': TabularCPD('肺气虚', 2, [[0.8, 0.2], [0.2, 0.8]], evidence=['气短'], evidence_card=[2]), | |
| '肝气郁结': TabularCPD('肝气郁结', 2, [[0.8, 0.3], [0.2, 0.7]], evidence=['便秘'], evidence_card=[2]), | |
| '脾胃虚弱': TabularCPD('脾胃虚弱', 2, [[0.8, 0.2], [0.2, 0.8]], evidence=['食欲不振'], evidence_card=[2]), | |
| '脾虚': TabularCPD('脾虚', 2, [[0.9, 0.6, 0.7, 0.1], [0.1, 0.4, 0.3, 0.9]], evidence=['腹痛', '脾胃虚弱'], evidence_card=[2, 2]), | |
| '气虚': TabularCPD('气虚', 2, [[0.9, 0.8, 0.7, 0.6, 0.5, 0.3, 0.4, 0.1], [0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.6, 0.9]], | |
| evidence=['乏力', '肺气虚', '脾虚'], evidence_card=[2, 2, 2]) | |
| } | |
| # 将所有定义的CPD添加到模型中 | |
| model.add_cpds(*cpd_symptoms.values(), *cpd_syndromes.values()) | |
| model.check_model() | |
| # 初始化推理引擎 | |
| inference = VariableElimination(model) | |
| class SymptomsInput(BaseModel): | |
| symptoms: str | |
| async def read_root(request: Request): | |
| return templates.TemplateResponse("index.html", {"request": request}) | |
| async def read_tcm_agent(request: Request): | |
| return templates.TemplateResponse("tcm_agent.html", {"request": request}) | |
| async def analyze_symptoms(symptoms_input: SymptomsInput): | |
| try: | |
| symptoms_list = [s.strip() for s in symptoms_input.symptoms.replace(',', ',').split(',') if s.strip()] | |
| valid_symptoms = [s for s in symptoms_list if s in model.nodes()] | |
| if not valid_symptoms: | |
| return JSONResponse(status_code=400, content={"error": "输入的症状无法识别,请输入模型支持的症状。"}) | |
| evidence = {s: 1 for s in valid_symptoms} | |
| syndrome_nodes = list(cpd_syndromes.keys()) | |
| # 找出所有可从证据节点到达的证候节点进行查询 | |
| query_nodes = [] | |
| for symptom in valid_symptoms: | |
| for syndrome in syndrome_nodes: | |
| if syndrome not in query_nodes and nx.has_path(model, symptom, syndrome): | |
| query_nodes.append(syndrome) | |
| results = {} | |
| for node in query_nodes: | |
| try: | |
| prob = inference.query(variables=[node], evidence=evidence, show_progress=False) | |
| results[node] = prob.values[1] | |
| except Exception as e: | |
| print(f"Could not infer for node {node}: {e}") | |
| continue | |
| if not results: | |
| top_syndrome = "未知证候" | |
| syndrome_result = "根据输入症状,无法推断出明确证候。请尝试其他症状组合。" | |
| else: | |
| top_syndrome = max(results, key=results.get) | |
| syndrome_result = f"根据您的症状,最可能的证候是:**{top_syndrome}** (置信度: {results[top_syndrome]:.2%})" | |
| # 病机和治法分析 | |
| pathogenesis_result = await get_gemini_response( | |
| f"作为一位资深中医师,请根据以下中医证候 “{top_syndrome}” 分析其核心病机是什么。请用简洁、专业的语言回答,不超过50字。", | |
| "无法生成病机分析。可能是由于API密钥未设置或网络问题。" | |
| ) | |
| treatment_result = await get_gemini_response( | |
| f"作为一位资深中医师,针对 “{top_syndrome}” 证候以及 “{pathogenesis_result}” 的病机,请提出对应的治疗法则。请用简洁、专业的语言回答,不超过50字。", | |
| "无法生成治法建议。可能是由于API密钥未设置或网络问题。" | |
| ) | |
| return { | |
| "syndrome": syndrome_result, | |
| "pathogenesis": pathogenesis_result, | |
| "treatment": treatment_result | |
| } | |
| except Exception as e: | |
| return JSONResponse(status_code=500, content={"error": f"服务器内部错误: {str(e)}"}) | |
| async def get_gemini_response(prompt: str, error_message: str) -> str: | |
| if not GEMINI_AVAILABLE: | |
| return error_message | |
| try: | |
| model = genai.GenerativeModel('gemini-2.5-pro') | |
| response = await model.generate_content_async(prompt) | |
| return response.text | |
| except Exception as e: | |
| return f"{error_message} 错误详情: {str(e)}" |