Spaces:
Sleeping
Sleeping
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| import gradio as gr | |
| import os | |
| import joblib | |
| from model import DualStreamTransformer, ArcMarginProduct | |
| css = """ | |
| .scroll-box { | |
| height: 300px; | |
| overflow-y: auto !important; | |
| overflow-x: hidden !important; | |
| display: block !important; | |
| width: 100% !important; | |
| max-width: 100% !important; | |
| } | |
| .scroll-box * { | |
| max-width: 100% !important; | |
| box-sizing: border-box !important; | |
| } | |
| .vertical-radio { | |
| display: block !important; | |
| width: 100% !important; | |
| } | |
| .vertical-radio .wrap { | |
| display: flex !important; | |
| flex-direction: column !important; | |
| width: 100% !important; | |
| min-width: 0 !important; | |
| } | |
| .vertical-radio .gradio-radio-item { | |
| width: 100% !important; | |
| white-space: normal !important; | |
| word-break: break-all !important; | |
| } | |
| """ | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| FOLD = 5 | |
| MODEL_PATH = f"best_model_fold_{FOLD}.pt" | |
| model = DualStreamTransformer(feat_num_1=24, feat_num_2=10, d_model=32).to(DEVICE) | |
| metric_fc = ArcMarginProduct(32, 2).to(DEVICE) | |
| if os.path.exists(MODEL_PATH): | |
| checkpoint = torch.load(MODEL_PATH, map_location=DEVICE) | |
| model.load_state_dict(checkpoint['model'], strict=False) | |
| metric_fc.load_state_dict(checkpoint['fc'], strict=False) | |
| model.eval() | |
| scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl") | |
| scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl") | |
| def analyze_and_predict(*all_answers): | |
| ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1} | |
| osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0} | |
| try: | |
| x1_vals = [ccmq_map.get(a, 1) for a in all_answers[:24]] | |
| x2_vals = [osdi_map.get(a, 0) for a in all_answers[24:34]] | |
| x1_scaled = scaler_ccmq.transform(np.array([x1_vals])) | |
| x2_scaled = scaler_osdi.transform(np.array([x2_vals])) | |
| sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE) | |
| sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE) | |
| with torch.no_grad(): | |
| feats = model(sx1, sx2) | |
| logits = metric_fc.predict(feats) | |
| probs = torch.softmax(logits, dim=1) | |
| pred_idx = torch.argmax(probs, dim=1).item() | |
| conf = probs[0, pred_idx].item() | |
| print(f"DEBUG: 推論成功! 索引: {pred_idx}, 信心度: {conf}") | |
| if pred_idx == 0: | |
| res_label = "乾眼風險 (DES)" | |
| else: | |
| res_label = "修格蘭氏症風險 (SJS)" | |
| prob_dict = { | |
| "乾眼 (DES)": probs[0, 0].item(), | |
| "修格蘭氏 (SJS)": probs[0, 1].item() | |
| } | |
| return ( | |
| f"## 診斷結果:{res_label}", | |
| f"**分析完成**:AI 信心度為 **{conf:.2%}**。本系統已整合 CCMQ 體質與 OSDI 症狀進行二分類篩檢。", | |
| prob_dict | |
| ) | |
| except Exception as e: | |
| print(f"計算出錯: {e}") | |
| return "### 計算出錯", f"錯誤原因: {str(e)}", {} | |
| with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 400px; overflow-y: auto; border: 1px solid #ddd; padding: 15px; border-radius: 8px; }") as demo: | |
| gr.Markdown("# 中西醫 AI 診斷系統") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| gr.Markdown("### 第一步:填寫問卷") | |
| with gr.Tabs() as survey_tabs: | |
| with gr.Tab("CCMQ 體質評估", id=0): | |
| with gr.Group(elem_classes="scroll-box"): | |
| ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"] | |
| all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}", value="沒有") for i, txt in enumerate(ccmq_labels)] | |
| btn_next = gr.Button("下一步:填寫 OSDI") | |
| with gr.Tab("OSDI 症狀評估", id=1): | |
| with gr.Group(elem_classes="scroll-box"): | |
| osdi_labels = ["1. 對光敏感", "2. 眼睛疼痛", "3. 視線模糊", "4. 視力減退", "5. 閱讀限制", "6. 夜間駕駛", "7. 電腦操作", "8. 觀看電視", "9. 刮風不適", "10. 空調不適"] | |
| all_osdi = [gr.Radio(["總是", "經常", "一半一半", "偶而", "完全不曾"], label=txt, value="完全不曾") for txt in osdi_labels] | |
| with gr.Row(): | |
| back_btn = gr.Button("返回") | |
| submit_btn = gr.Button(" 生成分析報告", variant="primary") | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 第二步:AI 診斷結果") | |
| res_title = gr.Markdown("### 點擊按鈕開始分析") | |
| res_desc = gr.Markdown("請先完成左側問卷並點擊「生成分析報告」。") | |
| res_prob = gr.Label(label="模型信心分佈") | |
| gr.Markdown("---") | |
| reset_btn = gr.Button(" 清除重新開始") | |
| all_inputs = all_ccmq + all_osdi | |
| btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs) | |
| back_btn.click(fn=lambda: gr.Tabs(selected=0), outputs=survey_tabs) | |
| submit_btn.click( | |
| fn=analyze_and_predict, | |
| inputs=all_inputs, | |
| outputs=[res_title, res_desc, res_prob] | |
| ) | |
| reset_btn.click( | |
| fn=lambda: ["### 點擊按鈕開始分析", "請先完成左側問卷。", {}] + ["沒有"]*24 + ["完全不曾"]*10, | |
| outputs=[res_title, res_desc, res_prob] + all_inputs | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(ssr_mode=False) |