Spaces:
Sleeping
Sleeping
File size: 5,911 Bytes
6105886 562c23b 6105886 d882bc5 39c14aa 622685d 445065e 562c23b ffd09d4 0a077cb d882bc5 6105886 ffd09d4 f95ab6b ffd09d4 0a077cb 562c23b 95626bd 6105886 ffd09d4 0a077cb d37e6fe f95ab6b d95eecf d37e6fe 0a077cb 6105886 0a077cb 2eb3982 39c14aa ffd09d4 39c14aa d95eecf d882bc5 d95eecf 39c14aa d95eecf 39c14aa d95eecf 39c14aa d95eecf 39c14aa d95eecf 39c14aa d95eecf d564cef 6105886 f95ab6b 6105886 0a077cb 781a86c d95eecf d11d03d d95eecf d11d03d d882bc5 d95eecf d37e6fe 6105886 d95eecf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 | import torch
import pandas as pd
import numpy as np
import gradio as gr
import os
import joblib
from model import DualStreamTransformer, ArcMarginProduct
css = """
.scroll-box {
height: 300px;
overflow-y: auto !important;
overflow-x: hidden !important;
display: block !important;
width: 100% !important;
max-width: 100% !important;
}
.scroll-box * {
max-width: 100% !important;
box-sizing: border-box !important;
}
.vertical-radio {
display: block !important;
width: 100% !important;
}
.vertical-radio .wrap {
display: flex !important;
flex-direction: column !important;
width: 100% !important;
min-width: 0 !important;
}
.vertical-radio .gradio-radio-item {
width: 100% !important;
white-space: normal !important;
word-break: break-all !important;
}
"""
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FOLD = 5
MODEL_PATH = f"best_model_fold_{FOLD}.pt"
model = DualStreamTransformer(feat_num_1=24, feat_num_2=10, d_model=32).to(DEVICE)
metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
if os.path.exists(MODEL_PATH):
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model'], strict=False)
metric_fc.load_state_dict(checkpoint['fc'], strict=False)
model.eval()
scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl")
scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl")
def analyze_and_predict(*all_answers):
ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
try:
x1_vals = [ccmq_map.get(a, 1) for a in all_answers[:24]]
x2_vals = [osdi_map.get(a, 0) for a in all_answers[24:34]]
x1_scaled = scaler_ccmq.transform(np.array([x1_vals]))
x2_scaled = scaler_osdi.transform(np.array([x2_vals]))
sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE)
sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE)
with torch.no_grad():
feats = model(sx1, sx2)
logits = metric_fc.predict(feats)
probs = torch.softmax(logits, dim=1)
pred_idx = torch.argmax(probs, dim=1).item()
conf = probs[0, pred_idx].item()
print(f"DEBUG: 推論成功! 索引: {pred_idx}, 信心度: {conf}")
if pred_idx == 0:
res_label = "乾眼風險 (DES)"
else:
res_label = "修格蘭氏症風險 (SJS)"
prob_dict = {
"乾眼 (DES)": probs[0, 0].item(),
"修格蘭氏 (SJS)": probs[0, 1].item()
}
return (
f"## 診斷結果:{res_label}",
f"**分析完成**:AI 信心度為 **{conf:.2%}**。本系統已整合 CCMQ 體質與 OSDI 症狀進行二分類篩檢。",
prob_dict
)
except Exception as e:
print(f"計算出錯: {e}")
return "### 計算出錯", f"錯誤原因: {str(e)}", {}
with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 400px; overflow-y: auto; border: 1px solid #ddd; padding: 15px; border-radius: 8px; }") as demo:
gr.Markdown("# 中西醫 AI 診斷系統")
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("### 第一步:填寫問卷")
with gr.Tabs() as survey_tabs:
with gr.Tab("CCMQ 體質評估", id=0):
with gr.Group(elem_classes="scroll-box"):
ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"]
all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}", value="沒有") for i, txt in enumerate(ccmq_labels)]
btn_next = gr.Button("下一步:填寫 OSDI")
with gr.Tab("OSDI 症狀評估", id=1):
with gr.Group(elem_classes="scroll-box"):
osdi_labels = ["1. 對光敏感", "2. 眼睛疼痛", "3. 視線模糊", "4. 視力減退", "5. 閱讀限制", "6. 夜間駕駛", "7. 電腦操作", "8. 觀看電視", "9. 刮風不適", "10. 空調不適"]
all_osdi = [gr.Radio(["總是", "經常", "一半一半", "偶而", "完全不曾"], label=txt, value="完全不曾") for txt in osdi_labels]
with gr.Row():
back_btn = gr.Button("返回")
submit_btn = gr.Button(" 生成分析報告", variant="primary")
with gr.Column(scale=1):
gr.Markdown("### 第二步:AI 診斷結果")
res_title = gr.Markdown("### 點擊按鈕開始分析")
res_desc = gr.Markdown("請先完成左側問卷並點擊「生成分析報告」。")
res_prob = gr.Label(label="模型信心分佈")
gr.Markdown("---")
reset_btn = gr.Button(" 清除重新開始")
all_inputs = all_ccmq + all_osdi
btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
back_btn.click(fn=lambda: gr.Tabs(selected=0), outputs=survey_tabs)
submit_btn.click(
fn=analyze_and_predict,
inputs=all_inputs,
outputs=[res_title, res_desc, res_prob]
)
reset_btn.click(
fn=lambda: ["### 點擊按鈕開始分析", "請先完成左側問卷。", {}] + ["沒有"]*24 + ["完全不曾"]*10,
outputs=[res_title, res_desc, res_prob] + all_inputs
)
if __name__ == "__main__":
demo.launch(ssr_mode=False) |