Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,6 @@ import seaborn as sns
|
|
| 6 |
import gradio as gr
|
| 7 |
import os
|
| 8 |
|
| 9 |
-
# 從 model.py 匯入架構
|
| 10 |
from model import DualStreamTransformer, ArcMarginProduct
|
| 11 |
|
| 12 |
css = """
|
|
@@ -55,9 +54,7 @@ if os.path.exists(MODEL_PATH):
|
|
| 55 |
model.eval()
|
| 56 |
print("模型載入成功!")
|
| 57 |
|
| 58 |
-
|
| 59 |
-
# 邏輯函式
|
| 60 |
-
# ==========================================
|
| 61 |
def analyze_and_predict(*all_answers):
|
| 62 |
if any(a is None for a in all_answers):
|
| 63 |
raise gr.Error("請完整填寫所有問卷題目!")
|
|
@@ -65,7 +62,6 @@ def analyze_and_predict(*all_answers):
|
|
| 65 |
ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
|
| 66 |
osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
|
| 67 |
|
| 68 |
-
# 資料處理
|
| 69 |
x1 = torch.tensor([[ccmq_map[a] for a in all_answers[:25]]], dtype=torch.float32).to(DEVICE)
|
| 70 |
x2 = torch.tensor([[osdi_map[a] for a in all_answers[25:]]], dtype=torch.float32).to(DEVICE)
|
| 71 |
|
|
@@ -76,14 +72,12 @@ def analyze_and_predict(*all_answers):
|
|
| 76 |
pred_idx = torch.argmax(probs, dim=1).item()
|
| 77 |
conf = probs[0, pred_idx].item()
|
| 78 |
|
| 79 |
-
# 繪圖展示 (研討會風格)
|
| 80 |
plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans']
|
| 81 |
fig, ax = plt.subplots(figsize=(6, 4))
|
| 82 |
sns.barplot(x=[conf, 1-conf], y=["預測類別", "其他"], palette="viridis", ax=ax)
|
| 83 |
-
ax.set_title(f"
|
| 84 |
|
| 85 |
-
|
| 86 |
-
table_data = [] # 此處可根據需求填充
|
| 87 |
|
| 88 |
res_label = "🔴 乾眼風險 (SJS/DES)" if pred_idx == 1 else "🟢 正常/健康"
|
| 89 |
return (
|
|
@@ -94,13 +88,13 @@ def analyze_and_predict(*all_answers):
|
|
| 94 |
{"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
|
| 95 |
table_data,
|
| 96 |
fig,
|
| 97 |
-
fig
|
| 98 |
)
|
| 99 |
|
| 100 |
def reset_system():
|
| 101 |
return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 37
|
| 102 |
with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflow-y: auto; }") as demo:
|
| 103 |
-
gr.Markdown("#
|
| 104 |
|
| 105 |
with gr.Column(visible=True) as stage_1:
|
| 106 |
with gr.Tabs() as survey_tabs:
|
|
@@ -150,7 +144,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflo
|
|
| 150 |
plot_2 = gr.Plot()
|
| 151 |
finish_btn = gr.Button("結束並重新開始", size="lg")
|
| 152 |
|
| 153 |
-
# 邏輯綁定
|
| 154 |
all_inputs = all_ccmq + all_osdi
|
| 155 |
btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
|
| 156 |
submit_btn.click(fn=analyze_and_predict, inputs=all_inputs, outputs=[stage_1, stage_2, res_title, res_desc, res_prob, res_table, plot_1, plot_2])
|
|
|
|
| 6 |
import gradio as gr
|
| 7 |
import os
|
| 8 |
|
|
|
|
| 9 |
from model import DualStreamTransformer, ArcMarginProduct
|
| 10 |
|
| 11 |
css = """
|
|
|
|
| 54 |
model.eval()
|
| 55 |
print("模型載入成功!")
|
| 56 |
|
| 57 |
+
|
|
|
|
|
|
|
| 58 |
def analyze_and_predict(*all_answers):
|
| 59 |
if any(a is None for a in all_answers):
|
| 60 |
raise gr.Error("請完整填寫所有問卷題目!")
|
|
|
|
| 62 |
ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
|
| 63 |
osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
|
| 64 |
|
|
|
|
| 65 |
x1 = torch.tensor([[ccmq_map[a] for a in all_answers[:25]]], dtype=torch.float32).to(DEVICE)
|
| 66 |
x2 = torch.tensor([[osdi_map[a] for a in all_answers[25:]]], dtype=torch.float32).to(DEVICE)
|
| 67 |
|
|
|
|
| 72 |
pred_idx = torch.argmax(probs, dim=1).item()
|
| 73 |
conf = probs[0, pred_idx].item()
|
| 74 |
|
|
|
|
| 75 |
plt.rcParams['font.sans-serif'] = ['Microsoft JhengHei', 'DejaVu Sans']
|
| 76 |
fig, ax = plt.subplots(figsize=(6, 4))
|
| 77 |
sns.barplot(x=[conf, 1-conf], y=["預測類別", "其他"], palette="viridis", ax=ax)
|
| 78 |
+
ax.set_title(f"診斷信心度: {conf:.2%}")
|
| 79 |
|
| 80 |
+
table_data = []
|
|
|
|
| 81 |
|
| 82 |
res_label = "🔴 乾眼風險 (SJS/DES)" if pred_idx == 1 else "🟢 正常/健康"
|
| 83 |
return (
|
|
|
|
| 88 |
{"風險機率": conf if pred_idx==1 else 1-conf, "健康程度": 1 - (conf if pred_idx==1 else 1-conf)},
|
| 89 |
table_data,
|
| 90 |
fig,
|
| 91 |
+
fig
|
| 92 |
)
|
| 93 |
|
| 94 |
def reset_system():
|
| 95 |
return [gr.update(visible=True), gr.update(visible=False), gr.update(selected=0)] + [None] * 37
|
| 96 |
with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 450px; overflow-y: auto; }") as demo:
|
| 97 |
+
gr.Markdown("# 中醫 AI 診斷系統")
|
| 98 |
|
| 99 |
with gr.Column(visible=True) as stage_1:
|
| 100 |
with gr.Tabs() as survey_tabs:
|
|
|
|
| 144 |
plot_2 = gr.Plot()
|
| 145 |
finish_btn = gr.Button("結束並重新開始", size="lg")
|
| 146 |
|
|
|
|
| 147 |
all_inputs = all_ccmq + all_osdi
|
| 148 |
btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
|
| 149 |
submit_btn.click(fn=analyze_and_predict, inputs=all_inputs, outputs=[stage_1, stage_2, res_title, res_desc, res_prob, res_table, plot_1, plot_2])
|