PinHsuan's picture
Update app.py
ffd09d4 verified
raw
history blame
7 kB
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
import os
import joblib
from model import DualStreamTransformer, ArcMarginProduct
css = """
.scroll-box {
height: 300px;
overflow-y: auto !important;
overflow-x: hidden !important;
display: block !important;
width: 100% !important;
max-width: 100% !important;
}
.scroll-box * {
max-width: 100% !important;
box-sizing: border-box !important;
}
.vertical-radio {
display: block !important;
width: 100% !important;
}
.vertical-radio .wrap {
display: flex !important;
flex-direction: column !important;
width: 100% !important;
min-width: 0 !important;
}
.vertical-radio .gradio-radio-item {
width: 100% !important;
white-space: normal !important;
word-break: break-all !important;
}
"""
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FOLD = 5
MODEL_PATH = f"best_model_fold_{FOLD}.pt"
model = DualStreamTransformer(n_feat1=24, n_feat2=10, d_model=32).to(DEVICE)
metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
if os.path.exists(MODEL_PATH):
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model'])
metric_fc.load_state_dict(checkpoint['fc'])
model.eval()
metric_fc.eval()
scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl")
scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl")
def analyze_and_predict(*all_answers):
ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
ccmq_ans = all_answers[:24]
osdi_ans_raw = all_answers[24:]
if any(a is None for a in all_answers):
raise gr.Error("請完整填寫所有問卷題目!")
x1_raw = np.array([[ccmq_map[a] for a in ccmq_ans]])
x2_raw = np.array([[osdi_map[a] for a in osdi_ans_raw[:10]]])
x1_scaled = scaler_ccmq.transform(x1_raw)
x2_scaled = scaler_osdi.transform(x2_raw)
sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE)
sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE)
with torch.no_grad():
feats = model(sx1, sx2)
logits = metric_fc.predict(feats)
probs = torch.softmax(logits, dim=1)
pred_idx = torch.argmax(probs, dim=1).item()
conf = probs[0, pred_idx].item()
plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
fig, ax = plt.subplots(figsize=(6, 4))
labels = ["健康/正常", "乾眼風險"]
sns.barplot(x=[conf if i == pred_idx else 1-conf for i in range(2)], y=labels, palette="viridis", ax=ax)
ax.set_title(f"AI 診斷信心度: {conf:.2%}")
table_data = [[f"題目 {i+1}", all_answers[i], "已記錄"] for i in range(10)]
res_label = "乾眼風險 (SJS/DES)" if pred_idx == 1 else "正常/健康"
return (
gr.update(visible=False),
gr.update(visible=True),
f"### 診斷結果:{res_label}",
{"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
table_data,
fig,
fig
)
def reset_system():
return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 36
with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflow-y: auto; }") as demo:
gr.Markdown("# 中醫AI診斷系統")
with gr.Column(visible=True) as stage_1:
with gr.Tabs() as survey_tabs:
with gr.Tab("CCMQ 體質評估", id=0):
with gr.Group(elem_classes="scroll-box"):
ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"]
all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}") for i, txt in enumerate(ccmq_labels)]
btn_next = gr.Button("下一步:填寫 OSDI", variant="primary")
with gr.Tab("OSDI 症狀評估", id=1):
with gr.Group(elem_classes="scroll-box"):
gr.Markdown("#### 在過去一週中,您是否出現下列症狀?")
o1 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="1. 眼睛對光敏感?")
o2 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="3. 眼睛疼痛?")
o3 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="4. 視線模糊?")
o4 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="5.視力減退?")
o5 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="6. 閱讀?")
o6 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="7. 夜間駕駛?")
o7 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="8. 操作電腦與提款機?")
o8 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="9. 觀看電視?")
o9 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="10. 刮風的狀況?")
o10 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="12. 區域使用空調?")
all_osdi = [o1, o2, o3, o4, o5, o6, o7, o8, o9, o10]
submit_btn = gr.Button("🚀 生成診斷報告", variant="primary")
with gr.Column(visible=False) as stage_2:
gr.Markdown("## 📊 AI 診斷分析報告")
with gr.Row():
res_table = gr.Dataframe(headers=["項目", "回答", "狀態"], interactive=False)
with gr.Column():
res_prob = gr.Label(label="預測機率")
res_title = gr.Markdown("### 診斷結果")
res_desc = gr.Markdown("分析中...")
plot_1 = gr.Plot()
plot_2 = gr.Plot()
finish_btn = gr.Button("結束並重新開始", size="lg", variant="secondary")
# 互動邏輯
all_inputs = all_ccmq + all_osdi
btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
submit_btn.click(fn=analyze_and_predict, inputs=all_inputs, outputs=[stage_1, stage_2, res_title, res_desc, res_prob, res_table, plot_1, plot_2])
finish_btn.click(fn=reset_system, outputs=[stage_1, stage_2, survey_tabs] + all_inputs)
if __name__ == "__main__":
demo.launch()