PinHsuan's picture
Update app.py
d564cef verified
import torch
import pandas as pd
import numpy as np
import gradio as gr
import os
import joblib
from model import DualStreamTransformer, ArcMarginProduct
css = """
.scroll-box {
height: 300px;
overflow-y: auto !important;
overflow-x: hidden !important;
display: block !important;
width: 100% !important;
max-width: 100% !important;
}
.scroll-box * {
max-width: 100% !important;
box-sizing: border-box !important;
}
.vertical-radio {
display: block !important;
width: 100% !important;
}
.vertical-radio .wrap {
display: flex !important;
flex-direction: column !important;
width: 100% !important;
min-width: 0 !important;
}
.vertical-radio .gradio-radio-item {
width: 100% !important;
white-space: normal !important;
word-break: break-all !important;
}
"""
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FOLD = 5
MODEL_PATH = f"best_model_fold_{FOLD}.pt"
model = DualStreamTransformer(feat_num_1=24, feat_num_2=10, d_model=32).to(DEVICE)
metric_fc = ArcMarginProduct(32, 2).to(DEVICE)
if os.path.exists(MODEL_PATH):
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model'], strict=False)
metric_fc.load_state_dict(checkpoint['fc'], strict=False)
model.eval()
scaler_ccmq = joblib.load(f"scaler_ccmq_fold_{FOLD}.pkl")
scaler_osdi = joblib.load(f"scaler_osdi_fold_{FOLD}.pkl")
def analyze_and_predict(*all_answers):
ccmq_map = {"總是": 5, "經常": 4, "有時": 3, "很少": 2, "沒有": 1}
osdi_map = {"總是": 4, "經常": 3, "一半一半": 2, "偶而": 1, "完全不曾": 0}
try:
x1_vals = [ccmq_map.get(a, 1) for a in all_answers[:24]]
x2_vals = [osdi_map.get(a, 0) for a in all_answers[24:34]]
x1_scaled = scaler_ccmq.transform(np.array([x1_vals]))
x2_scaled = scaler_osdi.transform(np.array([x2_vals]))
sx1 = torch.tensor(x1_scaled, dtype=torch.float32).to(DEVICE)
sx2 = torch.tensor(x2_scaled, dtype=torch.float32).to(DEVICE)
with torch.no_grad():
feats = model(sx1, sx2)
logits = metric_fc.predict(feats)
probs = torch.softmax(logits, dim=1)
pred_idx = torch.argmax(probs, dim=1).item()
conf = probs[0, pred_idx].item()
print(f"DEBUG: 推論成功! 索引: {pred_idx}, 信心度: {conf}")
if pred_idx == 0:
res_label = "乾眼風險 (DES)"
else:
res_label = "修格蘭氏症風險 (SJS)"
prob_dict = {
"乾眼 (DES)": probs[0, 0].item(),
"修格蘭氏 (SJS)": probs[0, 1].item()
}
return (
f"## 診斷結果:{res_label}",
f"**分析完成**:AI 信心度為 **{conf:.2%}**。本系統已整合 CCMQ 體質與 OSDI 症狀進行二分類篩檢。",
prob_dict
)
except Exception as e:
print(f"計算出錯: {e}")
return "### 計算出錯", f"錯誤原因: {str(e)}", {}
with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 400px; overflow-y: auto; border: 1px solid #ddd; padding: 15px; border-radius: 8px; }") as demo:
gr.Markdown("# 中西醫 AI 診斷系統")
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("### 第一步:填寫問卷")
with gr.Tabs() as survey_tabs:
with gr.Tab("CCMQ 體質評估", id=0):
with gr.Group(elem_classes="scroll-box"):
ccmq_labels = ["惡寒惡風", "自汗", "胸悶腹脹","咽喉痰梗感","多愁善感","易受驚","面部暗沉","黑眼圈","健忘","唇色暗","身熱、面熱","膚乾口乾","唇紅","便祕","兩顴紅","眼乾澀","四肢冷","惡寒、腰膝冷","飲冷腹瀉","口苦口臭","帶下色黃/下陰潮濕","鼻塞流涕","變天咳喘","過敏"]
all_ccmq = [gr.Radio(["總是", "經常", "有時", "很少", "沒有"], label=f"{i+1}. {txt}", value="沒有") for i, txt in enumerate(ccmq_labels)]
btn_next = gr.Button("下一步:填寫 OSDI")
with gr.Tab("OSDI 症狀評估", id=1):
with gr.Group(elem_classes="scroll-box"):
osdi_labels = ["1. 對光敏感", "2. 眼睛疼痛", "3. 視線模糊", "4. 視力減退", "5. 閱讀限制", "6. 夜間駕駛", "7. 電腦操作", "8. 觀看電視", "9. 刮風不適", "10. 空調不適"]
all_osdi = [gr.Radio(["總是", "經常", "一半一半", "偶而", "完全不曾"], label=txt, value="完全不曾") for txt in osdi_labels]
with gr.Row():
back_btn = gr.Button("返回")
submit_btn = gr.Button(" 生成分析報告", variant="primary")
with gr.Column(scale=1):
gr.Markdown("### 第二步:AI 診斷結果")
res_title = gr.Markdown("### 點擊按鈕開始分析")
res_desc = gr.Markdown("請先完成左側問卷並點擊「生成分析報告」。")
res_prob = gr.Label(label="模型信心分佈")
gr.Markdown("---")
reset_btn = gr.Button(" 清除重新開始")
all_inputs = all_ccmq + all_osdi
btn_next.click(fn=lambda: gr.Tabs(selected=1), outputs=survey_tabs)
back_btn.click(fn=lambda: gr.Tabs(selected=0), outputs=survey_tabs)
submit_btn.click(
fn=analyze_and_predict,
inputs=all_inputs,
outputs=[res_title, res_desc, res_prob]
)
reset_btn.click(
fn=lambda: ["### 點擊按鈕開始分析", "請先完成左側問卷。", {}] + ["沒有"]*24 + ["完全不曾"]*10,
outputs=[res_title, res_desc, res_prob] + all_inputs
)
if __name__ == "__main__":
demo.launch(ssr_mode=False)