PinHsuan commited on
Commit
ffd09d4
·
verified ·
1 Parent(s): 14468ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -44
app.py CHANGED
@@ -7,7 +7,6 @@ import gradio as gr
7
  import os
8
  import joblib
9
  from model import DualStreamTransformer, ArcMarginProduct
10
-
11
  css = """
12
  .scroll-box {
13
  height: 300px;
@@ -37,30 +36,42 @@ css = """
37
  word-break: break-all !important;
38
  }
39
  """
40
-
41
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
  FOLD = 5
43
- model = DualStreamTransformer(n_feat1=25, n_feat2=12, d_model=32).to(DEVICE)
 
 
44
  metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
45
- checkpoint = torch.load(f"best_model_fold_{FOLD}.pt", map_location=DEVICE)
46
- model.load_state_dict(checkpoint['model'])
47
- metric_fc.load_state_dict(checkpoint['fc'])
48
 
 
 
 
 
 
 
49
 
50
  scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl")
51
  scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl")
52
 
53
- model = DualStreamTransformer(n_feat1=24, n_feat2=10, d_model=32).to(DEVICE)
54
  def analyze_and_predict(*all_answers):
55
- ccmq_ans = all_answers[:24]
56
- osdi_ans = all_answers[25:35]
 
 
 
 
 
 
 
 
 
57
 
58
  x1_scaled = scaler_ccmq.transform(x1_raw)
59
  x2_scaled = scaler_osdi.transform(x2_raw)
60
 
61
  sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE)
62
  sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE)
63
-
64
  with torch.no_grad():
65
  feats = model(sx1, sx2)
66
  logits = metric_fc.predict(feats)
@@ -69,18 +80,21 @@ def analyze_and_predict(*all_answers):
69
  conf = probs[0, pred_idx].item()
70
 
71
  plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans']
 
 
72
  fig, ax = plt.subplots(figsize=(6, 4))
73
- sns.barplot(x=[conf, 1-conf], y=["預測類別", "其他"], palette="viridis", ax=ax)
74
- ax.set_title(f"診斷信心度: {conf:.2%}")
 
75
 
76
- table_data = []
 
 
77
 
78
- res_label = "🔴 乾眼風險 (SJS/DES)" if pred_idx == 1 else "🟢 正常/健康"
79
  return (
80
  gr.update(visible=False),
81
  gr.update(visible=True),
82
  f"### 診斷結果:{res_label}",
83
- "根據 FT-Transformer 的注意力機制分析,您的特徵與臨床乾眼指標有顯著關連。",
84
  {"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
85
  table_data,
86
  fig,
@@ -88,9 +102,10 @@ def analyze_and_predict(*all_answers):
88
  )
89
 
90
  def reset_system():
91
- return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 37
 
92
  with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflow-y: auto; }") as demo:
93
- gr.Markdown("# 中醫 AI 診斷系統")
94
 
95
  with gr.Column(visible=True) as stage_1:
96
  with gr.Tabs() as survey_tabs:
@@ -98,46 +113,38 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflo
98
  with gr.Group(elem_classes="scroll-box"):
99
  ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"]
100
  all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}") for i, txt in enumerate(ccmq_labels)]
101
- btn_next = gr.Button("下一步", variant="primary")
102
 
103
  with gr.Tab("OSDI 症狀評估", id=1):
104
  with gr.Group(elem_classes="scroll-box"):
105
- gr.Markdown("#### A. 眼睛症狀")
106
- gr.Markdown("#### 在過去一週中,您是否出現下列任一症狀?")
107
  o1 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="1. 眼睛對光敏感?")
108
- o3 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="3. 眼睛疼痛?")
109
- o4 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="4. 視線模糊?")
110
- o5 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="5. 視力減退?")
111
-
112
- gr.Markdown("---")
113
- gr.Markdown("#### B. 日常活動限制")
114
- gr.Markdown("#### 下列活動,否因眼睛問題受到限制?")
115
- o6 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="6. 閱讀?")
116
- o7 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="7. 夜間駕駛?")
117
- o8 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="8. 操作電腦與提款機?")
118
- o9 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="9. 觀看電視?")
119
-
120
- gr.Markdown("---")
121
- gr.Markdown("#### C. 環境因素不適感")
122
- gr.Markdown("#### 在過去一週中遇到任一狀況時,您的眼睛是否曾感覺不適?")
123
- o10 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="10. 刮風的狀況?")
124
- o12 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="12. 區域使用空調?")
125
-
126
- all_osdi = [o1, o2, o3, o4, o5, o6, o7, o8, o9, o10, o11, o12]
127
- submit_btn = gr.Button("生成診斷報告", variant="primary")
128
 
129
  with gr.Column(visible=False) as stage_2:
130
- gr.Markdown("## 診斷分析報告")
131
  with gr.Row():
132
- res_table = gr.Dataframe(headers=["項目", "回答", "分值"], interactive=False)
133
  with gr.Column():
134
  res_prob = gr.Label(label="預測機率")
135
  res_title = gr.Markdown("### 診斷結果")
136
- res_desc = gr.Markdown("詳細說明...")
137
  plot_1 = gr.Plot()
138
  plot_2 = gr.Plot()
139
- finish_btn = gr.Button("結束並重新開始", size="lg")
140
 
 
141
  all_inputs = all_ccmq + all_osdi
142
  btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
143
  submit_btn.click(fn=analyze_and_predict, inputs=all_inputs, outputs=[stage_1, stage_2, res_title, res_desc, res_prob, res_table, plot_1, plot_2])
 
7
  import os
8
  import joblib
9
  from model import DualStreamTransformer, ArcMarginProduct
 
10
  css = """
11
  .scroll-box {
12
  height: 300px;
 
36
  word-break: break-all !important;
37
  }
38
  """
 
39
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
  FOLD = 5
41
+ MODEL_PATH = f"best_model_fold_{FOLD}.pt"
42
+
43
+ model = DualStreamTransformer(n_feat1=24, n_feat2=10, d_model=32).to(DEVICE)
44
  metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
 
 
 
45
 
46
+ if os.path.exists(MODEL_PATH):
47
+ checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
48
+ model.load_state_dict(checkpoint['model'])
49
+ metric_fc.load_state_dict(checkpoint['fc'])
50
+ model.eval()
51
+ metric_fc.eval()
52
 
53
  scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl")
54
  scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl")
55
 
 
56
  def analyze_and_predict(*all_answers):
57
+ ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
58
+ osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
59
+
60
+ ccmq_ans = all_answers[:24]
61
+ osdi_ans_raw = all_answers[24:]
62
+
63
+ if any(a is None for a in all_answers):
64
+ raise gr.Error("請完整填寫所有問卷題目!")
65
+
66
+ x1_raw = np.array([[ccmq_map[a] for a in ccmq_ans]])
67
+ x2_raw = np.array([[osdi_map[a] for a in osdi_ans_raw[:10]]])
68
 
69
  x1_scaled = scaler_ccmq.transform(x1_raw)
70
  x2_scaled = scaler_osdi.transform(x2_raw)
71
 
72
  sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE)
73
  sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE)
74
+
75
  with torch.no_grad():
76
  feats = model(sx1, sx2)
77
  logits = metric_fc.predict(feats)
 
80
  conf = probs[0, pred_idx].item()
81
 
82
  plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans']
83
+ plt.rcParams['axes.unicode_minus'] = False
84
+
85
  fig, ax = plt.subplots(figsize=(6, 4))
86
+ labels = ["健康/正常", "乾眼風險"]
87
+ sns.barplot(x=[conf if i == pred_idx else 1-conf for i in range(2)], y=labels, palette="viridis", ax=ax)
88
+ ax.set_title(f"AI 診斷信心度: {conf:.2%}")
89
 
90
+ table_data = [[f"題目 {i+1}", all_answers[i], "已記錄"] for i in range(10)]
91
+
92
+ res_label = "乾眼風險 (SJS/DES)" if pred_idx == 1 else "正常/健康"
93
 
 
94
  return (
95
  gr.update(visible=False),
96
  gr.update(visible=True),
97
  f"### 診斷結果:{res_label}",
 
98
  {"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
99
  table_data,
100
  fig,
 
102
  )
103
 
104
  def reset_system():
105
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 36
106
+
107
  with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflow-y: auto; }") as demo:
108
+ gr.Markdown("# 中醫AI診斷系統")
109
 
110
  with gr.Column(visible=True) as stage_1:
111
  with gr.Tabs() as survey_tabs:
 
113
  with gr.Group(elem_classes="scroll-box"):
114
  ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"]
115
  all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}") for i, txt in enumerate(ccmq_labels)]
116
+ btn_next = gr.Button("下一步:填寫 OSDI", variant="primary")
117
 
118
  with gr.Tab("OSDI 症狀評估", id=1):
119
  with gr.Group(elem_classes="scroll-box"):
120
+ gr.Markdown("#### 在過去一週中,您是否出現下列症狀")
 
121
  o1 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="1. 眼睛對光敏感?")
122
+ o2 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="3. 眼睛疼痛?")
123
+ o3 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="4. 視線模糊?")
124
+ o4 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="5.視力減退?")
125
+ o5 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="6. 閱讀?")
126
+ o6 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="7. 夜間駕駛?")
127
+ o7 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="8. 操作電腦與提款機?")
128
+ o8 = gr.Radio(["", "經常","一半一半","偶","完全不曾"], label="9. 觀看電視?")
129
+ o9 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="10. 刮風的狀況?")
130
+ o10 = gr.Radio(["總是", "經常","一半一半","偶而","完全不曾"], label="12. 區域使用空調?")
131
+ all_osdi = [o1, o2, o3, o4, o5, o6, o7, o8, o9, o10]
132
+
133
+ submit_btn = gr.Button("🚀 生成診斷報告", variant="primary")
 
 
 
 
 
 
 
 
134
 
135
  with gr.Column(visible=False) as stage_2:
136
+ gr.Markdown("## 📊 AI 診斷分析報告")
137
  with gr.Row():
138
+ res_table = gr.Dataframe(headers=["項目", "回答", "狀態"], interactive=False)
139
  with gr.Column():
140
  res_prob = gr.Label(label="預測機率")
141
  res_title = gr.Markdown("### 診斷結果")
142
+ res_desc = gr.Markdown("分析中...")
143
  plot_1 = gr.Plot()
144
  plot_2 = gr.Plot()
145
+ finish_btn = gr.Button("結束並重新開始", size="lg", variant="secondary")
146
 
147
+ # 互動邏輯
148
  all_inputs = all_ccmq + all_osdi
149
  btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
150
  submit_btn.click(fn=analyze_and_predict, inputs=all_inputs, outputs=[stage_1, stage_2, res_title, res_desc, res_prob, res_table, plot_1, plot_2])