File size: 5,911 Bytes
6105886
 
 
 
 
562c23b
6105886
d882bc5
39c14aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622685d
445065e
562c23b
ffd09d4
 
0a077cb
d882bc5
6105886
 
ffd09d4
 
f95ab6b
 
ffd09d4
0a077cb
562c23b
 
95626bd
6105886
ffd09d4
 
 
0a077cb
d37e6fe
 
f95ab6b
d95eecf
 
d37e6fe
0a077cb
 
6105886
0a077cb
 
 
 
 
 
2eb3982
39c14aa
ffd09d4
39c14aa
d95eecf
 
 
d882bc5
d95eecf
39c14aa
 
d95eecf
39c14aa
d95eecf
 
39c14aa
 
d95eecf
 
 
39c14aa
 
d95eecf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39c14aa
d95eecf
 
 
 
 
 
 
 
 
d564cef
6105886
f95ab6b
6105886
 
0a077cb
781a86c
d95eecf
d11d03d
 
 
d95eecf
d11d03d
d882bc5
d95eecf
 
 
 
d37e6fe
6105886
 
d95eecf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import torch
import pandas as pd
import numpy as np
import gradio as gr
import os
import joblib
from model import DualStreamTransformer, ArcMarginProduct

css = """
.scroll-box {
    height: 300px;
    overflow-y: auto !important;
    overflow-x: hidden !important;
    display: block !important;
    width: 100% !important;
    max-width: 100% !important;
}
.scroll-box * {
    max-width: 100% !important;
    box-sizing: border-box !important;
}
.vertical-radio {
    display: block !important;
    width: 100% !important;
}
.vertical-radio .wrap {
    display: flex !important;
    flex-direction: column !important;
    width: 100% !important;
    min-width: 0 !important;
}
.vertical-radio .gradio-radio-item {
    width: 100% !important;
    white-space: normal !important;
    word-break: break-all !important;
}
"""

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FOLD = 5
MODEL_PATH = f"best_model_fold_{FOLD}.pt"


model = DualStreamTransformer(feat_num_1=24, feat_num_2=10, d_model=32).to(DEVICE)
metric_fc = ArcMarginProduct(32, 2).to(DEVICE)

if os.path.exists(MODEL_PATH):
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model'], strict=False)
    metric_fc.load_state_dict(checkpoint['fc'], strict=False)
    model.eval()

scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl")
scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl")

def analyze_and_predict(*all_answers):
    ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
    osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}

    try:
        x1_vals = [ccmq_map.get(a, 1) for a in all_answers[:24]]
        x2_vals = [osdi_map.get(a, 0) for a in all_answers[24:34]]
        
        x1_scaled = scaler_ccmq.transform(np.array([x1_vals]))
        x2_scaled = scaler_osdi.transform(np.array([x2_vals]))
        
        sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE)
        sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE)

        with torch.no_grad():
            feats = model(sx1, sx2)
            logits = metric_fc.predict(feats)
            probs = torch.softmax(logits, dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()
            conf = probs[0, pred_idx].item()
            
        print(f"DEBUG: 推論成功! 索引: {pred_idx}, 信心度: {conf}")

        if pred_idx == 0:
            res_label = "乾眼風險 (DES)"
        else:
            res_label = "修格蘭氏症風險 (SJS)"

        prob_dict = {
            "乾眼 (DES)": probs[0, 0].item(),
            "修格蘭氏 (SJS)": probs[0, 1].item()
        }

        return (
            f"## 診斷結果:{res_label}",
            f"**分析完成**:AI 信心度為 **{conf:.2%}**。本系統已整合 CCMQ 體質與 OSDI 症狀進行二分類篩檢。",
            prob_dict
        )

    except Exception as e:
            print(f"計算出錯: {e}")
            return "### 計算出錯", f"錯誤原因: {str(e)}", {}

with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 400px; overflow-y: auto; border: 1px solid #ddd; padding: 15px; border-radius: 8px; }") as demo:
    gr.Markdown("# 中西醫 AI 診斷系統")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 第一步:填寫問卷")
            with gr.Tabs() as survey_tabs:
                with gr.Tab("CCMQ 體質評估", id=0):
                    with gr.Group(elem_classes="scroll-box"):
                        ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"]
                        all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}", value="沒有") for i, txt in enumerate(ccmq_labels)]
                    btn_next = gr.Button("下一步:填寫 OSDI")

                with gr.Tab("OSDI 症狀評估", id=1):
                    with gr.Group(elem_classes="scroll-box"):
                        osdi_labels = ["1. 對光敏感", "2. 眼睛疼痛", "3. 視線模糊", "4. 視力減退", "5. 閱讀限制", "6. 夜間駕駛", "7. 電腦操作", "8. 觀看電視", "9. 刮風不適", "10. 空調不適"]
                        all_osdi = [gr.Radio(["總是", "經常", "一半一半", "偶而", "完全不曾"], label=txt, value="完全不曾") for txt in osdi_labels]
                    
                    with gr.Row():
                        back_btn = gr.Button("返回")
                        submit_btn = gr.Button(" 生成分析報告", variant="primary")


        with gr.Column(scale=1):
            gr.Markdown("### 第二步:AI 診斷結果")
            res_title = gr.Markdown("### 點擊按鈕開始分析")
            res_desc = gr.Markdown("請先完成左側問卷並點擊「生成分析報告」。")
            res_prob = gr.Label(label="模型信心分佈")
            
            gr.Markdown("---")
            reset_btn = gr.Button(" 清除重新開始")


    all_inputs = all_ccmq + all_osdi
    btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
    back_btn.click(fn=lambda: gr.Tabs(selected=0), outputs=survey_tabs)
    

    submit_btn.click(
        fn=analyze_and_predict, 
        inputs=all_inputs, 
        outputs=[res_title, res_desc, res_prob]
    )
    

    reset_btn.click(
        fn=lambda: ["### 點擊按鈕開始分析", "請先完成左側問卷。", {}] + ["沒有"]*24 + ["完全不曾"]*10,
        outputs=[res_title, res_desc, res_prob] + all_inputs
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)