PinHsuan commited on
Commit
95626bd
·
verified ·
1 Parent(s): c39c656

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -12
app.py CHANGED
@@ -6,7 +6,6 @@ import seaborn as sns
6
  import gradio as gr
7
  import os
8
 
9
- # 從 model.py 匯入架構
10
  from model import DualStreamTransformer, ArcMarginProduct
11
 
12
  css = """
@@ -55,9 +54,7 @@ if os.path.exists(MODEL_PATH):
55
  model.eval()
56
  print("模型載入成功!")
57
 
58
- # ==========================================
59
- # 邏輯函式
60
- # ==========================================
61
  def analyze_and_predict(*all_answers):
62
  if any(a is None for a in all_answers):
63
  raise gr.Error("請完整填寫所有問卷題目!")
@@ -65,7 +62,6 @@ def analyze_and_predict(*all_answers):
65
  ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
66
  osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
67
 
68
- # 資料處理
69
  x1 = torch.tensor([[ccmq_map[a] for a in all_answers[:25]]], dtype=torch.float32).to(DEVICE)
70
  x2 = torch.tensor([[osdi_map[a] for a in all_answers[25:]]], dtype=torch.float32).to(DEVICE)
71
 
@@ -76,14 +72,12 @@ def analyze_and_predict(*all_answers):
76
  pred_idx = torch.argmax(probs, dim=1).item()
77
  conf = probs[0, pred_idx].item()
78
 
79
- # 繪圖展示 (研討會風格)
80
  plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans']
81
  fig, ax = plt.subplots(figsize=(6, 4))
82
  sns.barplot(x=[conf, 1-conf], y=["預測類別", "其他"], palette="viridis", ax=ax)
83
- ax.set_title(f"AI 診斷信心度: {conf:.2%}")
84
 
85
- # 表格數據
86
- table_data = [] # 此處可根據需求填充
87
 
88
  res_label = "🔴 乾眼風險 (SJS/DES)" if pred_idx == 1 else "🟢 正常/健康"
89
  return (
@@ -94,13 +88,13 @@ def analyze_and_predict(*all_answers):
94
  {"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
95
  table_data,
96
  fig,
97
- fig # Demo 用,可替換為關聯圖
98
  )
99
 
100
  def reset_system():
101
  return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 37
102
  with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflow-y: auto; }") as demo:
103
- gr.Markdown("# 舌象與眼疾西醫 AI 診斷系統")
104
 
105
  with gr.Column(visible=True) as stage_1:
106
  with gr.Tabs() as survey_tabs:
@@ -150,7 +144,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflo
150
  plot_2 = gr.Plot()
151
  finish_btn = gr.Button("結束並重新開始", size="lg")
152
 
153
- # 邏輯綁定
154
  all_inputs = all_ccmq + all_osdi
155
  btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
156
  submit_btn.click(fn=analyze_and_predict, inputs=all_inputs, outputs=[stage_1, stage_2, res_title, res_desc, res_prob, res_table, plot_1, plot_2])
 
6
  import gradio as gr
7
  import os
8
 
 
9
  from model import DualStreamTransformer, ArcMarginProduct
10
 
11
  css = """
 
54
  model.eval()
55
  print("模型載入成功!")
56
 
57
+
 
 
58
  def analyze_and_predict(*all_answers):
59
  if any(a is None for a in all_answers):
60
  raise gr.Error("請完整填寫所有問卷題目!")
 
62
  ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
63
  osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
64
 
 
65
  x1 = torch.tensor([[ccmq_map[a] for a in all_answers[:25]]], dtype=torch.float32).to(DEVICE)
66
  x2 = torch.tensor([[osdi_map[a] for a in all_answers[25:]]], dtype=torch.float32).to(DEVICE)
67
 
 
72
  pred_idx = torch.argmax(probs, dim=1).item()
73
  conf = probs[0, pred_idx].item()
74
 
 
75
  plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans']
76
  fig, ax = plt.subplots(figsize=(6, 4))
77
  sns.barplot(x=[conf, 1-conf], y=["預測類別", "其他"], palette="viridis", ax=ax)
78
+ ax.set_title(f"診斷信心度: {conf:.2%}")
79
 
80
+ table_data = []
 
81
 
82
  res_label = "🔴 乾眼風險 (SJS/DES)" if pred_idx == 1 else "🟢 正常/健康"
83
  return (
 
88
  {"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
89
  table_data,
90
  fig,
91
+ fig
92
  )
93
 
94
  def reset_system():
95
  return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 37
96
  with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflow-y: auto; }") as demo:
97
+ gr.Markdown("# 中醫 AI 診斷系統")
98
 
99
  with gr.Column(visible=True) as stage_1:
100
  with gr.Tabs() as survey_tabs:
 
144
  plot_2 = gr.Plot()
145
  finish_btn = gr.Button("結束並重新開始", size="lg")
146
 
 
147
  all_inputs = all_ccmq + all_osdi
148
  btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
149
  submit_btn.click(fn=analyze_and_predict, inputs=all_inputs, outputs=[stage_1, stage_2, res_title, res_desc, res_prob, res_table, plot_1, plot_2])