PinHsuan commited on
Commit
39c14aa
·
verified ·
1 Parent(s): d95eecf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -14
app.py CHANGED
@@ -6,6 +6,35 @@ import os
6
  import joblib
7
  from model import DualStreamTransformer, ArcMarginProduct
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  FOLD = 5
@@ -40,36 +69,33 @@ def analyze_and_predict(*all_answers):
40
 
41
  with torch.no_grad():
42
  feats = model(sx1, sx2)
 
43
  logits = metric_fc.predict(feats)
44
  probs = torch.softmax(logits, dim=1)
45
  pred_idx = torch.argmax(probs, dim=1).item()
46
  conf = probs[0, pred_idx].item()
47
 
48
- print(f"DEBUG: Prediction successful! Pred: {pred_idx}")
49
 
50
-
51
- if pred_idx == 1:
52
- res_label = "正常 / 健康"
53
- elif pred_idx == 0:
54
  res_label = "乾眼風險 (DES)"
55
  else:
56
  res_label = "修格蘭氏症風險 (SJS)"
57
 
58
  prob_dict = {
59
- "健康": probs[0, 0].item(),
60
- "乾眼 (DES)": probs[0, 1].item(),
61
- "修格蘭氏 (SJS)": probs[0, 2].item()
62
  }
63
-
64
  return (
65
  f"## 診斷結果:{res_label}",
66
- f"**分析完成**:AI 信心度為 **{conf:.2%}**。本系統整合了中醫 24 體質特徵西醫 10 項 OSDI 症狀進行多模態計算。",
67
- {"Risk": conf if pred_idx==1 else 1-conf, "Healthy": 1 - (conf if pred_idx==1 else 1-conf)}
68
  )
69
 
70
  except Exception as e:
71
- print(f"Error: {e}")
72
- return f"### 計算出錯", str(e), {}
73
 
74
  with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 400px; overflow-y: auto; border: 1px solid #ddd; padding: 15px; border-radius: 8px; }") as demo:
75
  gr.Markdown("# 中西醫 AI 診斷系統")
@@ -91,7 +117,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 400px; overflo
91
 
92
  with gr.Row():
93
  back_btn = gr.Button("返回")
94
- submit_btn = gr.Button("🚀 生成分析報告", variant="primary")
95
 
96
 
97
  with gr.Column(scale=1):
 
6
  import joblib
7
  from model import DualStreamTransformer, ArcMarginProduct
8
 
9
+ css = """
10
+ .scroll-box {
11
+ height: 300px;
12
+ overflow-y: auto !important;
13
+ overflow-x: hidden !important;
14
+ display: block !important;
15
+ width: 100% !important;
16
+ max-width: 100% !important;
17
+ }
18
+ .scroll-box * {
19
+ max-width: 100% !important;
20
+ box-sizing: border-box !important;
21
+ }
22
+ .vertical-radio {
23
+ display: block !important;
24
+ width: 100% !important;
25
+ }
26
+ .vertical-radio .wrap {
27
+ display: flex !important;
28
+ flex-direction: column !important;
29
+ width: 100% !important;
30
+ min-width: 0 !important;
31
+ }
32
+ .vertical-radio .gradio-radio-item {
33
+ width: 100% !important;
34
+ white-space: normal !important;
35
+ word-break: break-all !important;
36
+ }
37
+ """
38
 
39
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
  FOLD = 5
 
69
 
70
  with torch.no_grad():
71
  feats = model(sx1, sx2)
72
+ # 確保 metric_fc 初始化時是 ArcMarginProduct(32, 2)
73
  logits = metric_fc.predict(feats)
74
  probs = torch.softmax(logits, dim=1)
75
  pred_idx = torch.argmax(probs, dim=1).item()
76
  conf = probs[0, pred_idx].item()
77
 
78
+ print(f"DEBUG: 推論成功! 索引: {pred_idx}, 信心度: {conf}")
79
 
80
+ if pred_idx == 0:
 
 
 
81
  res_label = "乾眼風險 (DES)"
82
  else:
83
  res_label = "修格蘭氏症風險 (SJS)"
84
 
85
  prob_dict = {
86
+ "乾眼 (DES)": probs[0, 0].item(),
87
+ "修格蘭氏 (SJS)": probs[0, 1].item()
 
88
  }
89
+
90
  return (
91
  f"## 診斷結果:{res_label}",
92
+ f"**分析完成**:AI 信心度為 **{conf:.2%}**。本系統整合 CCMQ 體質與 OSDI 症狀進行二分類篩檢。",
93
+ prob_dict
94
  )
95
 
96
  except Exception as e:
97
+ print(f"計算出錯: {e}")
98
+ return "### 計算出錯", f"錯誤原因: {str(e)}", {}
99
 
100
  with gr.Blocks(theme=gr.themes.Soft(), css=".scroll-box { height: 400px; overflow-y: auto; border: 1px solid #ddd; padding: 15px; border-radius: 8px; }") as demo:
101
  gr.Markdown("# 中西醫 AI 診斷系統")
 
117
 
118
  with gr.Row():
119
  back_btn = gr.Button("返回")
120
+ submit_btn = gr.Button(" 生成分析報告", variant="primary")
121
 
122
 
123
  with gr.Column(scale=1):