import torch import pandas as pd import numpy as np import gradio as gr import os import joblib from model import DualStreamTransformer, ArcMarginProduct css = """ .scroll-box { height: 300px; overflow-y: auto !important; overflow-x: hidden !important; display: block !important; width: 100% !important; max-width: 100% !important; } .scroll-box * { max-width: 100% !important; box-sizing: border-box !important; } .vertical-radio { display: block !important; width: 100% !important; } .vertical-radio .wrap { display: flex !important; flex-direction: column !important; width: 100% !important; min-width: 0 !important; } .vertical-radio .gradio-radio-item { width: 100% !important; white-space: normal !important; word-break: break-all !important; } """ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") FOLD = 5 MODEL_PATH = f"best_model_fold_{FOLD}.pt" model = DualStreamTransformer(feat_num_1=24, feat_num_2=10, d_model=32).to(DEVICE) metric_fc = ArcMarginProduct(32, 2).to(DEVICE) if os.path.exists(MODEL_PATH): checkpoint = torch.load(MODEL_PATH, map_location=DEVICE) model.load_state_dict(checkpoint['model'], strict=False) metric_fc.load_state_dict(checkpoint['fc'], strict=False) model.eval() scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl") scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl") def analyze_and_predict(*all_answers): ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1} osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0} try: x1_vals = [ccmq_map.get(a, 1) for a in all_answers[:24]] x2_vals = [osdi_map.get(a, 0) for a in all_answers[24:34]] x1_scaled = scaler_ccmq.transform(np.array([x1_vals])) x2_scaled = scaler_osdi.transform(np.array([x2_vals])) sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE) sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE) with torch.no_grad(): feats = model(sx1, sx2) logits = metric_fc.predict(feats) probs = torch.softmax(logits, dim=1) pred_idx = torch.argmax(probs, dim=1).item() conf = probs[0, pred_idx].item() print(f"DEBUG: 推論成功! 索引: {pred_idx}, 信心度: {conf}") if pred_idx == 0: res_label = "乾眼風險 (DES)" else: res_label = "修格蘭氏症風險 (SJS)" prob_dict = { "乾眼 (DES)": probs[0, 0].item(), "修格蘭氏 (SJS)": probs[0, 1].item() } return ( f"## 診斷結果:{res_label}", f"**分析完成**:AI 信心度為 **{conf:.2%}**。本系統已整合 CCMQ 體質與 OSDI 症狀進行二分類篩檢。", prob_dict ) except Exception as e: print(f"計算出錯: {e}") return "### 計算出錯", f"錯誤原因: {str(e)}", {} with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 400px; overflow-y: auto; border: 1px solid #ddd; padding: 15px; border-radius: 8px; }") as demo: gr.Markdown("# 中西醫 AI 診斷系統") with gr.Row(): with gr.Column(scale=2): gr.Markdown("### 第一步:填寫問卷") with gr.Tabs() as survey_tabs: with gr.Tab("CCMQ 體質評估", id=0): with gr.Group(elem_classes="scroll-box"): ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"] all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}", value="沒有") for i, txt in enumerate(ccmq_labels)] btn_next = gr.Button("下一步:填寫 OSDI") with gr.Tab("OSDI 症狀評估", id=1): with gr.Group(elem_classes="scroll-box"): osdi_labels = ["1. 對光敏感", "2. 眼睛疼痛", "3. 視線模糊", "4. 視力減退", "5. 閱讀限制", "6. 夜間駕駛", "7. 電腦操作", "8. 觀看電視", "9. 刮風不適", "10. 空調不適"] all_osdi = [gr.Radio(["總是", "經常", "一半一半", "偶而", "完全不曾"], label=txt, value="完全不曾") for txt in osdi_labels] with gr.Row(): back_btn = gr.Button("返回") submit_btn = gr.Button(" 生成分析報告", variant="primary") with gr.Column(scale=1): gr.Markdown("### 第二步:AI 診斷結果") res_title = gr.Markdown("### 點擊按鈕開始分析") res_desc = gr.Markdown("請先完成左側問卷並點擊「生成分析報告」。") res_prob = gr.Label(label="模型信心分佈") gr.Markdown("---") reset_btn = gr.Button(" 清除重新開始") all_inputs = all_ccmq + all_osdi btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs) back_btn.click(fn=lambda: gr.Tabs(selected=0), outputs=survey_tabs) submit_btn.click( fn=analyze_and_predict, inputs=all_inputs, outputs=[res_title, res_desc, res_prob] ) reset_btn.click( fn=lambda: ["### 點擊按鈕開始分析", "請先完成左側問卷。", {}] + ["沒有"]*24 + ["完全不曾"]*10, outputs=[res_title, res_desc, res_prob] + all_inputs ) if __name__ == "__main__": demo.launch(ssr_mode=False)