PinHsuan commited on
Commit
6105886
·
verified ·
1 Parent(s): 34c5ca2

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +160 -0
  2. best_model_fold_5.pt +3 -0
  3. model.py +53 -0
  4. requirements.txt +18 -0
app.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import gradio as gr
7
+ import os
8
+
9
+ # 從 model.py 匯入架構
10
+ from model import DualStreamTransformer, ArcMarginProduct
11
+
12
+ css = """
13
+ .scroll-box {
14
+ height: 300px;
15
+ overflow-y: auto !important;
16
+ overflow-x: hidden !important;
17
+ display: block !important;
18
+ width: 100% !important;
19
+ max-width: 100% !important;
20
+ }
21
+ .scroll-box * {
22
+ max-width: 100% !important;
23
+ box-sizing: border-box !important;
24
+ }
25
+ .vertical-radio {
26
+ display: block !important;
27
+ width: 100% !important;
28
+ }
29
+ .vertical-radio .wrap {
30
+ display: flex !important;
31
+ flex-direction: column !important;
32
+ width: 100% !important;
33
+ min-width: 0 !important;
34
+ }
35
+ .vertical-radio .gradio-radio-item {
36
+ width: 100% !important;
37
+ white-space: normal !important;
38
+ word-break: break-all !important;
39
+ }
40
+ """
41
+
42
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+ MODEL_PATH = "./best_model_fold_5.pt"
44
+
45
+ model = DualStreamTransformer(n_feat1=25, n_feat2=12, d_model=32).to(DEVICE)
46
+ metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
47
+
48
+ if os.path.exists(MODEL_PATH):
49
+ checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
50
+ if isinstance(checkpoint, dict) and 'model' in checkpoint:
51
+ model.load_state_dict(checkpoint['model'])
52
+ metric_fc.load_state_dict(checkpoint['fc'])
53
+ else:
54
+ model.load_state_dict(checkpoint)
55
+ model.eval()
56
+ print("模型載入成功!")
57
+
58
+ # ==========================================
59
+ # 邏輯函式
60
+ # ==========================================
61
+ def analyze_and_predict(*all_answers):
62
+ if any(a is None for a in all_answers):
63
+ raise gr.Error("請完整填寫所有問卷題目!")
64
+
65
+ ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
66
+ osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
67
+
68
+ # 資料處理
69
+ x1 = torch.tensor([[ccmq_map[a] for a in all_answers[:25]]], dtype=torch.float32).to(DEVICE)
70
+ x2 = torch.tensor([[osdi_map[a] for a in all_answers[25:]]], dtype=torch.float32).to(DEVICE)
71
+
72
+ with torch.no_grad():
73
+ feats = model(x1, x2)
74
+ logits = metric_fc.predict(feats)
75
+ probs = torch.softmax(logits, dim=1)
76
+ pred_idx = torch.argmax(probs, dim=1).item()
77
+ conf = probs[0, pred_idx].item()
78
+
79
+ # 繪圖展示 (研討會風格)
80
+ plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans']
81
+ fig, ax = plt.subplots(figsize=(6, 4))
82
+ sns.barplot(x=[conf, 1-conf], y=["預測類別", "其他"], palette="viridis", ax=ax)
83
+ ax.set_title(f"AI 診斷信心度: {conf:.2%}")
84
+
85
+ # 表格數據
86
+ table_data = [] # 此處可根據需求填充
87
+
88
+ res_label = "🔴 乾眼風險 (SJS/DES)" if pred_idx == 1 else "🟢 正常/健康"
89
+ return (
90
+ gr.update(visible=False),
91
+ gr.update(visible=True),
92
+ f"### 診斷結果:{res_label}",
93
+ "根據 FT-Transformer 的注意力機制分析,您的特徵與臨床乾眼指標有顯著關連。",
94
+ {"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
95
+ table_data,
96
+ fig,
97
+ fig # Demo 用,可替換為關聯圖
98
+ )
99
+
100
+ def reset_system():
101
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 37
102
+ with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflow-y: auto; }") as demo:
103
+ gr.Markdown("# 舌象與眼疾中西醫 AI 診斷系統")
104
+
105
+ with gr.Column(visible=True) as stage_1:
106
+ with gr.Tabs() as survey_tabs:
107
+ with gr.Tab("CCMQ 體質評估", id=0):
108
+ with gr.Group(elem_classes="scroll-box"):
109
+ ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","褐班","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"]
110
+ all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}") for i, txt in enumerate(ccmq_labels)]
111
+ btn_next = gr.Button("下一步", variant="primary")
112
+
113
+ with gr.Tab("OSDI 症狀評估", id=1):
114
+ with gr.Group(elem_classes="scroll-box"):
115
+ gr.Markdown("#### A. 眼睛症狀")
116
+ gr.Markdown("#### 在過去一週中,您是否出現下列任一症狀?")
117
+ o1 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="1. 眼睛對光敏感?")
118
+ o2 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="2. 眼睛有異物感?")
119
+ o3 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="3. 眼睛疼痛?")
120
+ o4 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="4. 視線模糊?")
121
+ o5 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="5. 視力減退?")
122
+
123
+ gr.Markdown("---")
124
+ gr.Markdown("#### B. 日常活動限制")
125
+ gr.Markdown("#### 下列活動,是否因眼睛問題而受到限制?")
126
+ o6 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="6. 閱讀?")
127
+ o7 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="7. 夜間駕駛?")
128
+ o8 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="8. 操作電腦與提款機?")
129
+ o9 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="9. 觀看電視?")
130
+
131
+ gr.Markdown("---")
132
+ gr.Markdown("#### C. 環境因素不適感")
133
+ gr.Markdown("#### 在過去一週中遇到任一狀況時,您的眼睛是否曾感覺不適?")
134
+ o10 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="10. 刮風的狀況?")
135
+ o11 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="11. 濕度較低?")
136
+ o12 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="12. 區域使用空調?")
137
+
138
+ all_osdi = [o1, o2, o3, o4, o5, o6, o7, o8, o9, o10, o11, o12]
139
+ submit_btn = gr.Button("生成診斷報告", variant="primary")
140
+
141
+ with gr.Column(visible=False) as stage_2:
142
+ gr.Markdown("## 診斷分析報告")
143
+ with gr.Row():
144
+ res_table = gr.Dataframe(headers=["項目", "回答", "分值"], interactive=False)
145
+ with gr.Column():
146
+ res_prob = gr.Label(label="預測機率")
147
+ res_title = gr.Markdown("### 診斷結果")
148
+ res_desc = gr.Markdown("詳細說明...")
149
+ plot_1 = gr.Plot()
150
+ plot_2 = gr.Plot()
151
+ finish_btn = gr.Button("結束並重新開始", size="lg")
152
+
153
+ # 邏輯綁定
154
+ all_inputs = all_ccmq + all_osdi
155
+ btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
156
+ submit_btn.click(fn=analyze_and_predict, inputs=all_inputs, outputs=[stage_1, stage_2, res_title, res_desc, res_prob, res_table, plot_1, plot_2])
157
+ finish_btn.click(fn=reset_system, outputs=[stage_1, stage_2, survey_tabs] + all_inputs)
158
+
159
+ if __name__ == "__main__":
160
+ demo.launch()
best_model_fold_5.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1cdb8693e41015708044318fbe6b50bd046463d9aea7d5a4d4d104f65ad711a
3
+ size 265683
model.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class DualStreamTransformer(nn.Module):
5
+ def __init__(self, n_feat1=25, n_feat2=12, d_model=32, num_classes=2):
6
+ super(DualStreamTransformer, self).__init__()
7
+
8
+ # Stream 1: CCMQ Tokenizer & Encoder
9
+ self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat1)])
10
+ self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model))
11
+ encoder_layer_1 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
12
+ self.encoder_1 = nn.TransformerEncoder(encoder_layer_1, num_layers=3)
13
+
14
+ # Stream 2: OSDI Tokenizer & Encoder
15
+ self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat2)])
16
+ self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model))
17
+ encoder_layer_2 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
18
+ self.encoder_2 = nn.TransformerEncoder(encoder_layer_2, num_layers=3)
19
+
20
+ # Fusion 層
21
+ self.mlp_head = nn.Sequential(
22
+ nn.Linear(d_model * 2, d_model),
23
+ nn.ReLU(),
24
+ nn.Linear(d_model, d_model) # 輸出 Embedding 給 ArcMargin
25
+ )
26
+
27
+ def forward(self, x1, x2):
28
+ # Stream 1 推論
29
+ tokens1 = [layer(x1[:, i].unsqueeze(1)) for i, layer in enumerate(self.feat_tokenizers_1)]
30
+ x1_emb = torch.stack(tokens1, dim=1)
31
+ x1_emb = torch.cat((self.cls_token_1.expand(x1.size(0), -1, -1), x1_emb), dim=1)
32
+ feat1 = self.encoder_1(x1_emb)[:, 0, :] # 取 CLS token
33
+
34
+ # Stream 2 推論
35
+ tokens2 = [layer(x2[:, i].unsqueeze(1)) for i, layer in enumerate(self.feat_tokenizers_2)]
36
+ x2_emb = torch.stack(tokens2, dim=1)
37
+ x2_emb = torch.cat((self.cls_token_2.expand(x2.size(0), -1, -1), x2_emb), dim=1)
38
+ feat2 = self.encoder_2(x2_emb)[:, 0, :] # 取 CLS token
39
+
40
+ # 特徵融合
41
+ combined = torch.cat((feat1, feat2), dim=1)
42
+ return self.mlp_head(combined)
43
+
44
+ class ArcMarginProduct(nn.Module):
45
+ def __init__(self, in_features, out_features, s=30.0, m=0.5):
46
+ super(ArcMarginProduct, self).__init__()
47
+ self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
48
+ nn.init.xavier_uniform_(self.weight)
49
+
50
+ def predict(self, x):
51
+ # 推論時直接做線性映射或餘弦相似度
52
+ cosine = torch.matmul(nn.functional.normalize(x), nn.functional.normalize(self.weight).t())
53
+ return cosine
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 核心 Web 介面
2
+ gradio
3
+
4
+ # 深度學習框架
5
+ torch
6
+ torchvision
7
+
8
+ # 資料處理與數值運算
9
+ pandas
10
+ numpy
11
+ openpyxl
12
+
13
+ # 資料視覺化 (研討會展示圖表用)
14
+ matplotlib
15
+ seaborn
16
+
17
+ # 科學運算
18
+ scikit-learn