PinHsuan commited on
Commit
781a86c
·
verified ·
1 Parent(s): fb34f6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -34
app.py CHANGED
@@ -1,8 +1,6 @@
1
  import torch
2
  import pandas as pd
3
  import numpy as np
4
- import matplotlib.pyplot as plt
5
- import seaborn as sns
6
  import gradio as gr
7
  import os
8
  import joblib
@@ -37,15 +35,10 @@ css = """
37
  word-break: break-all !important;
38
  }
39
  """
40
-
41
-
42
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
  FOLD = 5
44
  MODEL_PATH = f"best_model_fold_{FOLD}.pt"
45
 
46
- plt.rcParams['font.sans-serif'] = ['Noto Sans CJK TC', 'Droid Sans Fallback', 'Arial Unicode MS']
47
- plt.rcParams['axes.unicode_minus'] = False
48
-
49
  model = DualStreamTransformer(feat_num_1=24, feat_num_2=10, d_model=32).to(DEVICE)
50
  metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
51
 
@@ -61,14 +54,11 @@ scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl")
61
 
62
  def analyze_and_predict(*all_answers):
63
 
64
- if any(a is None for a in all_answers):
65
- missing = [i+1 for i, a in enumerate(all_answers) if a is None]
66
- raise gr.Error(f"還有題目沒填完!索引:{missing}")
67
  ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
68
  osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
69
 
70
  ccmq_ans = all_answers[:24]
71
- osdi_ans = all_answers[24:34] # 確保只取 10 題
72
 
73
  x1_raw = np.array([[ccmq_map[a] for a in ccmq_ans]])
74
  x2_raw = np.array([[osdi_map[a] for a in osdi_ans]])
@@ -85,31 +75,34 @@ def analyze_and_predict(*all_answers):
85
  pred_idx = torch.argmax(probs, dim=1).item()
86
  conf = probs[0, pred_idx].item()
87
 
88
-
89
- plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans']
90
- plt.rcParams['axes.unicode_minus'] = False
91
- fig, ax = plt.subplots(figsize=(6, 4))
92
- sns.barplot(x=[conf if i == pred_idx else 1-conf for i in range(2)], y=["健康", "風險"], palette="viridis", ax=ax)
93
- ax.set_title(f"AI 診斷信心度: {conf:.2%}")
94
-
95
- table_data = [[f"問卷項目 {i+1}", all_answers[i], "OK"] for i in range(len(all_answers))]
96
- res_label = "🔴 乾眼風險 (SJS/DES)" if pred_idx == 1 else "🟢 正常/健康"
 
 
 
97
 
98
  return (
99
  gr.update(visible=False),
100
  gr.update(visible=True),
101
  f"### {res_label}",
102
- "分析報告:系統已根據您的中醫體質與西醫症狀完成多模態融合計算。",
103
  {"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
104
  table_data,
105
- fig,
106
- fig
107
  )
108
 
109
  def reset_system():
110
  return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 34
111
 
112
- with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
113
  gr.Markdown("# 中醫 AI 診斷系統")
114
 
115
  with gr.Column(visible=True) as stage_1:
@@ -126,29 +119,27 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
126
  all_osdi = [gr.Radio(["總是", "經常", "一半一半", "偶而", "完全不曾"], label=txt) for txt in osdi_labels]
127
 
128
  with gr.Row():
129
- back_to_ccmq = gr.Button("返回 CCMQ")
130
- submit_btn = gr.Button("🚀 生成分析報告", variant="primary")
131
 
132
  with gr.Column(visible=False) as stage_2:
133
  gr.Markdown("## 診斷報告結果")
134
  with gr.Row():
135
  with gr.Column(scale=1):
136
- res_table = gr.Dataframe(headers=["項目", "回答", "狀態"], interactive=False)
137
- back_to_edit = gr.Button("修改問卷")
138
  with gr.Column(scale=1):
139
  res_prob = gr.Label(label="預測機率")
140
  res_title = gr.Markdown("### 診斷結果")
141
- res_desc = gr.Markdown("報告說明")
142
- plot_1 = gr.Plot()
143
- plot_2 = gr.Plot()
 
144
  finish_btn = gr.Button("結束並重新開始", variant="secondary")
145
 
146
  all_inputs = all_ccmq + all_osdi
147
-
148
  btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
149
  back_to_ccmq.click(fn=lambda: gr.Tabs(selected=0), outputs=survey_tabs)
150
- back_to_edit.click(fn=lambda: (gr.update(visible=True), gr.update(visible=False)), outputs=[stage_1, stage_2])
151
-
152
  submit_btn.click(
153
  fn=analyze_and_predict,
154
  inputs=all_inputs,
 
1
  import torch
2
  import pandas as pd
3
  import numpy as np
 
 
4
  import gradio as gr
5
  import os
6
  import joblib
 
35
  word-break: break-all !important;
36
  }
37
  """
 
 
38
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  FOLD = 5
40
  MODEL_PATH = f"best_model_fold_{FOLD}.pt"
41
 
 
 
 
42
  model = DualStreamTransformer(feat_num_1=24, feat_num_2=10, d_model=32).to(DEVICE)
43
  metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
44
 
 
54
 
55
  def analyze_and_predict(*all_answers):
56
 
 
 
 
57
  ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
58
  osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
59
 
60
  ccmq_ans = all_answers[:24]
61
+ osdi_ans = all_answers[24:34]
62
 
63
  x1_raw = np.array([[ccmq_map[a] for a in ccmq_ans]])
64
  x2_raw = np.array([[osdi_map[a] for a in osdi_ans]])
 
75
  pred_idx = torch.argmax(probs, dim=1).item()
76
  conf = probs[0, pred_idx].item()
77
 
78
+ # 準備純文字回傳內容
79
+ table_data = [[f"問卷項目 {i+1}", all_answers[i], "已記錄"] for i in range(len(all_answers))]
80
+ res_label = " 乾眼風險 (SJS/DES)" if pred_idx == 1 else " 正常/健康"
81
+
82
+
83
+ detail_text = f"""
84
+ ### 🧬 AI 模型分析詳情
85
+ - **診斷信心度**:{conf:.2%}
86
+ - **預測類別**:{res_label}
87
+ - **核心演算法**:Dual-Stream FT-Transformer
88
+ - **數據來源**:中醫體質辨識量表 (24項) + OSDI 症狀量表 (10項)
89
+ """
90
 
91
  return (
92
  gr.update(visible=False),
93
  gr.update(visible=True),
94
  f"### {res_label}",
95
+ detail_text,
96
  {"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
97
  table_data,
98
+ gr.update(visible=False),
99
+ gr.update(visible=False)
100
  )
101
 
102
  def reset_system():
103
  return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 34
104
 
105
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
106
  gr.Markdown("# 中醫 AI 診斷系統")
107
 
108
  with gr.Column(visible=True) as stage_1:
 
119
  all_osdi = [gr.Radio(["總是", "經常", "一半一半", "偶而", "完全不曾"], label=txt) for txt in osdi_labels]
120
 
121
  with gr.Row():
122
+ back_to_ccmq = gr.Button("返回")
123
+ submit_btn = gr.Button("🚀 生成診斷報告", variant="primary")
124
 
125
  with gr.Column(visible=False) as stage_2:
126
  gr.Markdown("## 診斷報告結果")
127
  with gr.Row():
128
  with gr.Column(scale=1):
129
+ res_table = gr.Dataframe(headers=["項目", "回答", "狀態"], interactive=False,elem_classes="scroll-box")
 
130
  with gr.Column(scale=1):
131
  res_prob = gr.Label(label="預測機率")
132
  res_title = gr.Markdown("### 診斷結果")
133
+ res_desc = gr.Markdown("分析中...")
134
+ plot_1 = gr.Plot(visible=False)
135
+ plot_2 = gr.Plot(visible=False)
136
+
137
  finish_btn = gr.Button("結束並重新開始", variant="secondary")
138
 
139
  all_inputs = all_ccmq + all_osdi
 
140
  btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
141
  back_to_ccmq.click(fn=lambda: gr.Tabs(selected=0), outputs=survey_tabs)
142
+
 
143
  submit_btn.click(
144
  fn=analyze_and_predict,
145
  inputs=all_inputs,