Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,6 +6,35 @@ import os
|
|
| 6 |
import joblib
|
| 7 |
from model import DualStreamTransformer, ArcMarginProduct
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 11 |
FOLD = 5
|
|
@@ -40,36 +69,33 @@ def analyze_and_predict(*all_answers):
|
|
| 40 |
|
| 41 |
with torch.no_grad():
|
| 42 |
feats = model(sx1, sx2)
|
|
|
|
| 43 |
logits = metric_fc.predict(feats)
|
| 44 |
probs = torch.softmax(logits, dim=1)
|
| 45 |
pred_idx = torch.argmax(probs, dim=1).item()
|
| 46 |
conf = probs[0, pred_idx].item()
|
| 47 |
|
| 48 |
-
print(f"DEBUG:
|
| 49 |
|
| 50 |
-
|
| 51 |
-
if pred_idx == 1:
|
| 52 |
-
res_label = "正常 / 健康"
|
| 53 |
-
elif pred_idx == 0:
|
| 54 |
res_label = "乾眼風險 (DES)"
|
| 55 |
else:
|
| 56 |
res_label = "修格蘭氏症風險 (SJS)"
|
| 57 |
|
| 58 |
prob_dict = {
|
| 59 |
-
"
|
| 60 |
-
"
|
| 61 |
-
"修格蘭氏 (SJS)": probs[0, 2].item()
|
| 62 |
}
|
| 63 |
-
|
| 64 |
return (
|
| 65 |
f"## 診斷結果:{res_label}",
|
| 66 |
-
f"**分析完成**:AI 信心度為 **{conf:.2%}**。本系統整合
|
| 67 |
-
|
| 68 |
)
|
| 69 |
|
| 70 |
except Exception as e:
|
| 71 |
-
|
| 72 |
-
|
| 73 |
|
| 74 |
with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 400px; overflow-y: auto; border: 1px solid #ddd; padding: 15px; border-radius: 8px; }") as demo:
|
| 75 |
gr.Markdown("# 中西醫 AI 診斷系統")
|
|
@@ -91,7 +117,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 400px; overflo
|
|
| 91 |
|
| 92 |
with gr.Row():
|
| 93 |
back_btn = gr.Button("返回")
|
| 94 |
-
submit_btn = gr.Button("
|
| 95 |
|
| 96 |
|
| 97 |
with gr.Column(scale=1):
|
|
|
|
| 6 |
import joblib
|
| 7 |
from model import DualStreamTransformer, ArcMarginProduct
|
| 8 |
|
| 9 |
+
css = """
|
| 10 |
+
.scroll-box {
|
| 11 |
+
height: 300px;
|
| 12 |
+
overflow-y: auto !important;
|
| 13 |
+
overflow-x: hidden !important;
|
| 14 |
+
display: block !important;
|
| 15 |
+
width: 100% !important;
|
| 16 |
+
max-width: 100% !important;
|
| 17 |
+
}
|
| 18 |
+
.scroll-box * {
|
| 19 |
+
max-width: 100% !important;
|
| 20 |
+
box-sizing: border-box !important;
|
| 21 |
+
}
|
| 22 |
+
.vertical-radio {
|
| 23 |
+
display: block !important;
|
| 24 |
+
width: 100% !important;
|
| 25 |
+
}
|
| 26 |
+
.vertical-radio .wrap {
|
| 27 |
+
display: flex !important;
|
| 28 |
+
flex-direction: column !important;
|
| 29 |
+
width: 100% !important;
|
| 30 |
+
min-width: 0 !important;
|
| 31 |
+
}
|
| 32 |
+
.vertical-radio .gradio-radio-item {
|
| 33 |
+
width: 100% !important;
|
| 34 |
+
white-space: normal !important;
|
| 35 |
+
word-break: break-all !important;
|
| 36 |
+
}
|
| 37 |
+
"""
|
| 38 |
|
| 39 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 40 |
FOLD = 5
|
|
|
|
| 69 |
|
| 70 |
with torch.no_grad():
|
| 71 |
feats = model(sx1, sx2)
|
| 72 |
+
# 確保 metric_fc 初始化時是 ArcMarginProduct(32, 2)
|
| 73 |
logits = metric_fc.predict(feats)
|
| 74 |
probs = torch.softmax(logits, dim=1)
|
| 75 |
pred_idx = torch.argmax(probs, dim=1).item()
|
| 76 |
conf = probs[0, pred_idx].item()
|
| 77 |
|
| 78 |
+
print(f"DEBUG: 推論成功! 索引: {pred_idx}, 信心度: {conf}")
|
| 79 |
|
| 80 |
+
if pred_idx == 0:
|
|
|
|
|
|
|
|
|
|
| 81 |
res_label = "乾眼風險 (DES)"
|
| 82 |
else:
|
| 83 |
res_label = "修格蘭氏症風險 (SJS)"
|
| 84 |
|
| 85 |
prob_dict = {
|
| 86 |
+
"乾眼 (DES)": probs[0, 0].item(),
|
| 87 |
+
"修格蘭氏 (SJS)": probs[0, 1].item()
|
|
|
|
| 88 |
}
|
| 89 |
+
|
| 90 |
return (
|
| 91 |
f"## 診斷結果:{res_label}",
|
| 92 |
+
f"**分析完成**:AI 信心度為 **{conf:.2%}**。本系統已整合 CCMQ 體質與 OSDI 症狀進行二分類篩檢。",
|
| 93 |
+
prob_dict
|
| 94 |
)
|
| 95 |
|
| 96 |
except Exception as e:
|
| 97 |
+
print(f"計算出錯: {e}")
|
| 98 |
+
return "### 計算出錯", f"錯誤原因: {str(e)}", {}
|
| 99 |
|
| 100 |
with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 400px; overflow-y: auto; border: 1px solid #ddd; padding: 15px; border-radius: 8px; }") as demo:
|
| 101 |
gr.Markdown("# 中西醫 AI 診斷系統")
|
|
|
|
| 117 |
|
| 118 |
with gr.Row():
|
| 119 |
back_btn = gr.Button("返回")
|
| 120 |
+
submit_btn = gr.Button(" 生成分析報告", variant="primary")
|
| 121 |
|
| 122 |
|
| 123 |
with gr.Column(scale=1):
|