Spaces:
Sleeping
Sleeping
File size: 6,999 Bytes
6105886 562c23b 6105886 445065e 562c23b ffd09d4 6105886 ffd09d4 6105886 562c23b 95626bd 6105886 ffd09d4 6105886 562c23b ffd09d4 6105886 562c23b 6105886 ffd09d4 6105886 ffd09d4 6105886 ffd09d4 6105886 95626bd 6105886 ffd09d4 6105886 ffd09d4 6105886 4bd40e6 6105886 ffd09d4 6105886 ffd09d4 6105886 ffd09d4 6105886 ffd09d4 6105886 ffd09d4 6105886 ffd09d4 6105886 ffd09d4 6105886 ffd09d4 6105886 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
import os
import joblib
from model import DualStreamTransformer, ArcMarginProduct
css = """
.scroll-box {
height: 300px;
overflow-y: auto !important;
overflow-x: hidden !important;
display: block !important;
width: 100% !important;
max-width: 100% !important;
}
.scroll-box * {
max-width: 100% !important;
box-sizing: border-box !important;
}
.vertical-radio {
display: block !important;
width: 100% !important;
}
.vertical-radio .wrap {
display: flex !important;
flex-direction: column !important;
width: 100% !important;
min-width: 0 !important;
}
.vertical-radio .gradio-radio-item {
width: 100% !important;
white-space: normal !important;
word-break: break-all !important;
}
"""
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FOLD = 5
MODEL_PATH = f"best_model_fold_{FOLD}.pt"
model = DualStreamTransformer(n_feat1=24, n_feat2=10, d_model=32).to(DEVICE)
metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
if os.path.exists(MODEL_PATH):
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model'])
metric_fc.load_state_dict(checkpoint['fc'])
model.eval()
metric_fc.eval()
scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl")
scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl")
def analyze_and_predict(*all_answers):
ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
ccmq_ans = all_answers[:24]
osdi_ans_raw = all_answers[24:]
if any(a is None for a in all_answers):
raise gr.Error("請完整填寫所有問卷題目!")
x1_raw = np.array([[ccmq_map[a] for a in ccmq_ans]])
x2_raw = np.array([[osdi_map[a] for a in osdi_ans_raw[:10]]])
x1_scaled = scaler_ccmq.transform(x1_raw)
x2_scaled = scaler_osdi.transform(x2_raw)
sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE)
sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE)
with torch.no_grad():
feats = model(sx1, sx2)
logits = metric_fc.predict(feats)
probs = torch.softmax(logits, dim=1)
pred_idx = torch.argmax(probs, dim=1).item()
conf = probs[0, pred_idx].item()
plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
fig, ax = plt.subplots(figsize=(6, 4))
labels = ["健康/正常", "乾眼風險"]
sns.barplot(x=[conf if i == pred_idx else 1-conf for i in range(2)], y=labels, palette="viridis", ax=ax)
ax.set_title(f"AI 診斷信心度: {conf:.2%}")
table_data = [[f"題目 {i+1}", all_answers[i], "已記錄"] for i in range(10)]
res_label = "乾眼風險 (SJS/DES)" if pred_idx == 1 else "正常/健康"
return (
gr.update(visible=False),
gr.update(visible=True),
f"### 診斷結果:{res_label}",
{"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
table_data,
fig,
fig
)
def reset_system():
return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 36
with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflow-y: auto; }") as demo:
gr.Markdown("# 中醫AI診斷系統")
with gr.Column(visible=True) as stage_1:
with gr.Tabs() as survey_tabs:
with gr.Tab("CCMQ 體質評估", id=0):
with gr.Group(elem_classes="scroll-box"):
ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"]
all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}") for i, txt in enumerate(ccmq_labels)]
btn_next = gr.Button("下一步:填寫 OSDI", variant="primary")
with gr.Tab("OSDI 症狀評估", id=1):
with gr.Group(elem_classes="scroll-box"):
gr.Markdown("#### 在過去一週中,您是否出現下列症狀?")
o1 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="1. 眼睛對光敏感?")
o2 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="3. 眼睛疼痛?")
o3 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="4. 視線模糊?")
o4 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="5.視力減退?")
o5 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="6. 閱讀?")
o6 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="7. 夜間駕駛?")
o7 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="8. 操作電腦與提款機?")
o8 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="9. 觀看電視?")
o9 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="10. 刮風的狀況?")
o10 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="12. 區域使用空調?")
all_osdi = [o1, o2, o3, o4, o5, o6, o7, o8, o9, o10]
submit_btn = gr.Button("🚀 生成診斷報告", variant="primary")
with gr.Column(visible=False) as stage_2:
gr.Markdown("## 📊 AI 診斷分析報告")
with gr.Row():
res_table = gr.Dataframe(headers=["項目", "回答", "狀態"], interactive=False)
with gr.Column():
res_prob = gr.Label(label="預測機率")
res_title = gr.Markdown("### 診斷結果")
res_desc = gr.Markdown("分析中...")
plot_1 = gr.Plot()
plot_2 = gr.Plot()
finish_btn = gr.Button("結束並重新開始", size="lg", variant="secondary")
# 互動邏輯
all_inputs = all_ccmq + all_osdi
btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
submit_btn.click(fn=analyze_and_predict, inputs=all_inputs, outputs=[stage_1, stage_2, res_title, res_desc, res_prob, res_table, plot_1, plot_2])
finish_btn.click(fn=reset_system, outputs=[stage_1, stage_2, survey_tabs] + all_inputs)
if __name__ == "__main__":
demo.launch() |