Spaces:
Sleeping
Sleeping
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import gradio as gr | |
| import os | |
| import joblib | |
| from model import DualStreamTransformer, ArcMarginProduct | |
| css = """ | |
| .scroll-box { | |
| height: 300px; | |
| overflow-y: auto !important; | |
| overflow-x: hidden !important; | |
| display: block !important; | |
| width: 100% !important; | |
| max-width: 100% !important; | |
| } | |
| .scroll-box * { | |
| max-width: 100% !important; | |
| box-sizing: border-box !important; | |
| } | |
| .vertical-radio { | |
| display: block !important; | |
| width: 100% !important; | |
| } | |
| .vertical-radio .wrap { | |
| display: flex !important; | |
| flex-direction: column !important; | |
| width: 100% !important; | |
| min-width: 0 !important; | |
| } | |
| .vertical-radio .gradio-radio-item { | |
| width: 100% !important; | |
| white-space: normal !important; | |
| word-break: break-all !important; | |
| } | |
| """ | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| FOLD = 5 | |
| MODEL_PATH = f"best_model_fold_{FOLD}.pt" | |
| model = DualStreamTransformer(n_feat1=24, n_feat2=10, d_model=32).to(DEVICE) | |
| metric_fc = ArcMarginProduct(32, 2).to(DEVICE) | |
| if os.path.exists(MODEL_PATH): | |
| checkpoint = torch.load(MODEL_PATH, map_location=DEVICE) | |
| model.load_state_dict(checkpoint['model']) | |
| metric_fc.load_state_dict(checkpoint['fc']) | |
| model.eval() | |
| metric_fc.eval() | |
| scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl") | |
| scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl") | |
| def analyze_and_predict(*all_answers): | |
| ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1} | |
| osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0} | |
| ccmq_ans = all_answers[:24] | |
| osdi_ans_raw = all_answers[24:] | |
| if any(a is None for a in all_answers): | |
| raise gr.Error("請完整填寫所有問卷題目!") | |
| x1_raw = np.array([[ccmq_map[a] for a in ccmq_ans]]) | |
| x2_raw = np.array([[osdi_map[a] for a in osdi_ans_raw[:10]]]) | |
| x1_scaled = scaler_ccmq.transform(x1_raw) | |
| x2_scaled = scaler_osdi.transform(x2_raw) | |
| sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE) | |
| sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE) | |
| with torch.no_grad(): | |
| feats = model(sx1, sx2) | |
| logits = metric_fc.predict(feats) | |
| probs = torch.softmax(logits, dim=1) | |
| pred_idx = torch.argmax(probs, dim=1).item() | |
| conf = probs[0, pred_idx].item() | |
| plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans'] | |
| plt.rcParams['axes.unicode_minus'] = False | |
| fig, ax = plt.subplots(figsize=(6, 4)) | |
| labels = ["健康/正常", "乾眼風險"] | |
| sns.barplot(x=[conf if i == pred_idx else 1-conf for i in range(2)], y=labels, palette="viridis", ax=ax) | |
| ax.set_title(f"AI 診斷信心度: {conf:.2%}") | |
| table_data = [[f"題目 {i+1}", all_answers[i], "已記錄"] for i in range(10)] | |
| res_label = "乾眼風險 (SJS/DES)" if pred_idx == 1 else "正常/健康" | |
| return ( | |
| gr.update(visible=False), | |
| gr.update(visible=True), | |
| f"### 診斷結果:{res_label}", | |
| {"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)}, | |
| table_data, | |
| fig, | |
| fig | |
| ) | |
| def reset_system(): | |
| return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 36 | |
| with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflow-y: auto; }") as demo: | |
| gr.Markdown("# 中醫AI診斷系統") | |
| with gr.Column(visible=True) as stage_1: | |
| with gr.Tabs() as survey_tabs: | |
| with gr.Tab("CCMQ 體質評估", id=0): | |
| with gr.Group(elem_classes="scroll-box"): | |
| ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"] | |
| all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}") for i, txt in enumerate(ccmq_labels)] | |
| btn_next = gr.Button("下一步:填寫 OSDI", variant="primary") | |
| with gr.Tab("OSDI 症狀評估", id=1): | |
| with gr.Group(elem_classes="scroll-box"): | |
| gr.Markdown("#### 在過去一週中,您是否出現下列症狀?") | |
| o1 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="1. 眼睛對光敏感?") | |
| o2 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="3. 眼睛疼痛?") | |
| o3 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="4. 視線模糊?") | |
| o4 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="5.視力減退?") | |
| o5 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="6. 閱讀?") | |
| o6 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="7. 夜間駕駛?") | |
| o7 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="8. 操作電腦與提款機?") | |
| o8 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="9. 觀看電視?") | |
| o9 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="10. 刮風的狀況?") | |
| o10 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="12. 區域使用空調?") | |
| all_osdi = [o1, o2, o3, o4, o5, o6, o7, o8, o9, o10] | |
| submit_btn = gr.Button("🚀 生成診斷報告", variant="primary") | |
| with gr.Column(visible=False) as stage_2: | |
| gr.Markdown("## 📊 AI 診斷分析報告") | |
| with gr.Row(): | |
| res_table = gr.Dataframe(headers=["項目", "回答", "狀態"], interactive=False) | |
| with gr.Column(): | |
| res_prob = gr.Label(label="預測機率") | |
| res_title = gr.Markdown("### 診斷結果") | |
| res_desc = gr.Markdown("分析中...") | |
| plot_1 = gr.Plot() | |
| plot_2 = gr.Plot() | |
| finish_btn = gr.Button("結束並重新開始", size="lg", variant="secondary") | |
| # 互動邏輯 | |
| all_inputs = all_ccmq + all_osdi | |
| btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs) | |
| submit_btn.click(fn=analyze_and_predict, inputs=all_inputs, outputs=[stage_1, stage_2, res_title, res_desc, res_prob, res_table, plot_1, plot_2]) | |
| finish_btn.click(fn=reset_system, outputs=[stage_1, stage_2, survey_tabs] + all_inputs) | |
| if __name__ == "__main__": | |
| demo.launch() |