File size: 6,999 Bytes
6105886
 
 
 
 
 
 
562c23b
6105886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445065e
562c23b
ffd09d4
 
 
6105886
 
ffd09d4
 
 
 
 
 
6105886
562c23b
 
95626bd
6105886
ffd09d4
 
 
 
 
 
 
 
 
 
 
6105886
562c23b
 
 
 
 
ffd09d4
6105886
562c23b
6105886
 
 
 
 
 
ffd09d4
 
6105886
ffd09d4
 
 
6105886
ffd09d4
 
 
6105886
 
 
 
 
 
 
 
95626bd
6105886
 
 
ffd09d4
 
6105886
ffd09d4
6105886
 
 
 
 
4bd40e6
6105886
ffd09d4
6105886
 
 
ffd09d4
6105886
ffd09d4
 
 
 
 
 
 
 
 
 
 
 
6105886
 
ffd09d4
6105886
ffd09d4
6105886
 
 
ffd09d4
6105886
 
ffd09d4
6105886
ffd09d4
6105886
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
import os
import joblib
from model import DualStreamTransformer, ArcMarginProduct
css = """
.scroll-box {
    height: 300px;
    overflow-y: auto !important;
    overflow-x: hidden !important;
    display: block !important;
    width: 100% !important;
    max-width: 100% !important;
}
.scroll-box * {
    max-width: 100% !important;
    box-sizing: border-box !important;
}
.vertical-radio {
    display: block !important;
    width: 100% !important;
}
.vertical-radio .wrap {
    display: flex !important;
    flex-direction: column !important;
    width: 100% !important;
    min-width: 0 !important;
}
.vertical-radio .gradio-radio-item {
    width: 100% !important;
    white-space: normal !important;
    word-break: break-all !important;
}
"""
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FOLD = 5
MODEL_PATH = f"best_model_fold_{FOLD}.pt"

model = DualStreamTransformer(n_feat1=24, n_feat2=10, d_model=32).to(DEVICE)
metric_fc = ArcMarginProduct(32, 2).to(DEVICE)

if os.path.exists(MODEL_PATH):
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model'])
    metric_fc.load_state_dict(checkpoint['fc'])
    model.eval()
    metric_fc.eval()

scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl")
scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl")

def analyze_and_predict(*all_answers):
    ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
    osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}

    ccmq_ans = all_answers[:24]
    osdi_ans_raw = all_answers[24:] 
    
    if any(a is None for a in all_answers):
        raise gr.Error("請完整填寫所有問卷題目!")

    x1_raw = np.array([[ccmq_map[a] for a in ccmq_ans]])
    x2_raw = np.array([[osdi_map[a] for a in osdi_ans_raw[:10]]])

    x1_scaled = scaler_ccmq.transform(x1_raw)
    x2_scaled = scaler_osdi.transform(x2_raw)

    sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE)
    sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE)

    with torch.no_grad():
        feats = model(sx1, sx2)
        logits = metric_fc.predict(feats)
        probs = torch.softmax(logits, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        conf = probs[0, pred_idx].item()

    plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False
    
    fig, ax = plt.subplots(figsize=(6, 4))
    labels = ["健康/正常", "乾眼風險"]
    sns.barplot(x=[conf if i == pred_idx else 1-conf for i in range(2)], y=labels, palette="viridis", ax=ax)
    ax.set_title(f"AI 診斷信心度: {conf:.2%}")

    table_data = [[f"題目 {i+1}", all_answers[i], "已記錄"] for i in range(10)]
    
    res_label = "乾眼風險 (SJS/DES)" if pred_idx == 1 else "正常/健康"
    
    return (
        gr.update(visible=False), 
        gr.update(visible=True),
        f"### 診斷結果:{res_label}",
        {"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
        table_data,
        fig,
        fig 
    )

def reset_system():
    return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 36

with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflow-y: auto; }") as demo:
    gr.Markdown("# 中醫AI診斷系統")

    with gr.Column(visible=True) as stage_1:
        with gr.Tabs() as survey_tabs:
            with gr.Tab("CCMQ 體質評估", id=0):
                with gr.Group(elem_classes="scroll-box"):
                    ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"]
                    all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}") for i, txt in enumerate(ccmq_labels)]
                btn_next = gr.Button("下一步:填寫 OSDI", variant="primary")

            with gr.Tab("OSDI 症狀評估", id=1):
                with gr.Group(elem_classes="scroll-box"):
                    gr.Markdown("#### 在過去一週中,您是否出現下列症狀?")
                    o1 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="1. 眼睛對光敏感?")
                    o2 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="3. 眼睛疼痛?")
                    o3 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="4. 視線模糊?")
                    o4 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="5.視力減退?")
                    o5 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="6. 閱讀?")
                    o6 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="7. 夜間駕駛?")
                    o7 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="8. 操作電腦與提款機?")
                    o8 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="9. 觀看電視?")
                    o9 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="10. 刮風的狀況?")
                    o10 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="12. 區域使用空調?")
                    all_osdi = [o1, o2, o3, o4, o5, o6, o7, o8, o9, o10]
                
                submit_btn = gr.Button("🚀 生成診斷報告", variant="primary")

    with gr.Column(visible=False) as stage_2:
        gr.Markdown("## 📊 AI 診斷分析報告")
        with gr.Row():
            res_table = gr.Dataframe(headers=["項目", "回答", "狀態"], interactive=False)
            with gr.Column():
                res_prob = gr.Label(label="預測機率")
                res_title = gr.Markdown("### 診斷結果")
                res_desc = gr.Markdown("分析中...")
                plot_1 = gr.Plot()
                plot_2 = gr.Plot()
        finish_btn = gr.Button("結束並重新開始", size="lg", variant="secondary")

    # 互動邏輯
    all_inputs = all_ccmq + all_osdi
    btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
    submit_btn.click(fn=analyze_and_predict, inputs=all_inputs, outputs=[stage_1, stage_2, res_title, res_desc, res_prob, res_table, plot_1, plot_2])
    finish_btn.click(fn=reset_system, outputs=[stage_1, stage_2, survey_tabs] + all_inputs)

if __name__ == "__main__":
    demo.launch()