TCM_Casual_Agent / app_bak.py
leonsimon23's picture
Rename app.py to app_bak.py
47a8b56 verified
import os
import networkx as nx
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
# --- 关键修改 1: 导入 DiscreteBayesianNetwork 而不是 BayesianNetwork ---
from pgmpy.models import DiscreteBayesianNetwork
from pgmpy.factors.discrete import TabularCPD
from pgmpy.inference import VariableElimination
import google.generativeai as genai
# --- Gemini API 配置 ---
try:
gemini_api_key = os.environ.get("GEMINI_API_KEY")
if not gemini_api_key:
raise ValueError("GEMINI_API_KEY secret not found in Hugging Face Space settings.")
genai.configure(api_key=gemini_api_key)
GEMINI_AVAILABLE = True
except (ValueError, ImportError) as e:
print(e)
GEMINI_AVAILABLE = False
# --- FastAPI 应用设置 ---
app = FastAPI()
templates = Jinja2Templates(directory="templates")
# --- 因果推理模型定义 ---
# --- 关键修改 2: 使用 DiscreteBayesianNetwork 类进行实例化 ---
model = DiscreteBayesianNetwork([
# 症状 -> 证候
('头痛', '风热'), ('发热', '湿热'), ('乏力', '气虚'), ('恶心', '胃气虚'),
('咳嗽', '风寒'), ('腹痛', '脾虚'), ('胸闷', '心脉不通'), ('气短', '肺气虚'),
('便秘', '肝气郁结'), ('食欲不振', '脾胃虚弱'),
# 证候之间的关联
('肺气虚', '气虚'), ('脾虚', '气虚'),
('脾胃虚弱', '脾虚')
])
# 定义条件概率分布 (CPD)
# 症状 (根节点)
cpd_symptoms = {
'头痛': TabularCPD('头痛', 2, [[0.7], [0.3]]), '发热': TabularCPD('发热', 2, [[0.7], [0.3]]),
'乏力': TabularCPD('乏力', 2, [[0.5], [0.5]]), '恶心': TabularCPD('恶心', 2, [[0.8], [0.2]]),
'咳嗽': TabularCPD('咳嗽', 2, [[0.7], [0.3]]), '腹痛': TabularCPD('腹痛', 2, [[0.8], [0.2]]),
'胸闷': TabularCPD('胸闷', 2, [[0.7], [0.3]]), '气短': TabularCPD('气短', 2, [[0.6], [0.4]]),
'便秘': TabularCPD('便秘', 2, [[0.7], [0.3]]), '食欲不振': TabularCPD('食欲不振', 2, [[0.6], [0.4]])
}
# 证候 (中间节点)
cpd_syndromes = {
'风热': TabularCPD('风热', 2, [[0.9, 0.2], [0.1, 0.8]], evidence=['头痛'], evidence_card=[2]),
'湿热': TabularCPD('湿热', 2, [[0.9, 0.1], [0.1, 0.9]], evidence=['发热'], evidence_card=[2]),
'胃气虚': TabularCPD('胃气虚', 2, [[0.8, 0.3], [0.2, 0.7]], evidence=['恶心'], evidence_card=[2]),
'风寒': TabularCPD('风寒', 2, [[0.9, 0.2], [0.1, 0.8]], evidence=['咳嗽'], evidence_card=[2]),
'心脉不通': TabularCPD('心脉不通', 2, [[0.9, 0.1], [0.1, 0.9]], evidence=['胸闷'], evidence_card=[2]),
'肺气虚': TabularCPD('肺气虚', 2, [[0.8, 0.2], [0.2, 0.8]], evidence=['气短'], evidence_card=[2]),
'肝气郁结': TabularCPD('肝气郁结', 2, [[0.8, 0.3], [0.2, 0.7]], evidence=['便秘'], evidence_card=[2]),
'脾胃虚弱': TabularCPD('脾胃虚弱', 2, [[0.8, 0.2], [0.2, 0.8]], evidence=['食欲不振'], evidence_card=[2]),
'脾虚': TabularCPD('脾虚', 2, [[0.9, 0.6, 0.7, 0.1], [0.1, 0.4, 0.3, 0.9]], evidence=['腹痛', '脾胃虚弱'], evidence_card=[2, 2]),
'气虚': TabularCPD('气虚', 2, [[0.9, 0.8, 0.7, 0.6, 0.5, 0.3, 0.4, 0.1], [0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.6, 0.9]],
evidence=['乏力', '肺气虚', '脾虚'], evidence_card=[2, 2, 2])
}
# 将所有定义的CPD添加到模型中
model.add_cpds(*cpd_symptoms.values(), *cpd_syndromes.values())
model.check_model()
# 初始化推理引擎
inference = VariableElimination(model)
class SymptomsInput(BaseModel):
symptoms: str
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.get("/tcm_agent", response_class=HTMLResponse)
async def read_tcm_agent(request: Request):
return templates.TemplateResponse("tcm_agent.html", {"request": request})
@app.post("/analyze")
async def analyze_symptoms(symptoms_input: SymptomsInput):
try:
symptoms_list = [s.strip() for s in symptoms_input.symptoms.replace(',', ',').split(',') if s.strip()]
valid_symptoms = [s for s in symptoms_list if s in model.nodes()]
if not valid_symptoms:
return JSONResponse(status_code=400, content={"error": "输入的症状无法识别,请输入模型支持的症状。"})
evidence = {s: 1 for s in valid_symptoms}
syndrome_nodes = list(cpd_syndromes.keys())
# 找出所有可从证据节点到达的证候节点进行查询
query_nodes = []
for symptom in valid_symptoms:
for syndrome in syndrome_nodes:
if syndrome not in query_nodes and nx.has_path(model, symptom, syndrome):
query_nodes.append(syndrome)
results = {}
for node in query_nodes:
try:
prob = inference.query(variables=[node], evidence=evidence, show_progress=False)
results[node] = prob.values[1]
except Exception as e:
print(f"Could not infer for node {node}: {e}")
continue
if not results:
top_syndrome = "未知证候"
syndrome_result = "根据输入症状,无法推断出明确证候。请尝试其他症状组合。"
else:
top_syndrome = max(results, key=results.get)
syndrome_result = f"根据您的症状,最可能的证候是:**{top_syndrome}** (置信度: {results[top_syndrome]:.2%})"
# 病机和治法分析
pathogenesis_result = await get_gemini_response(
f"作为一位资深中医师,请根据以下中医证候 “{top_syndrome}” 分析其核心病机是什么。请用简洁、专业的语言回答,不超过50字。",
"无法生成病机分析。可能是由于API密钥未设置或网络问题。"
)
treatment_result = await get_gemini_response(
f"作为一位资深中医师,针对 “{top_syndrome}” 证候以及 “{pathogenesis_result}” 的病机,请提出对应的治疗法则。请用简洁、专业的语言回答,不超过50字。",
"无法生成治法建议。可能是由于API密钥未设置或网络问题。"
)
return {
"syndrome": syndrome_result,
"pathogenesis": pathogenesis_result,
"treatment": treatment_result
}
except Exception as e:
return JSONResponse(status_code=500, content={"error": f"服务器内部错误: {str(e)}"})
async def get_gemini_response(prompt: str, error_message: str) -> str:
if not GEMINI_AVAILABLE:
return error_message
try:
model = genai.GenerativeModel('gemini-2.5-pro')
response = await model.generate_content_async(prompt)
return response.text
except Exception as e:
return f"{error_message} 错误详情: {str(e)}"