MedicalAILabo commited on
Commit
321fa37
·
verified ·
1 Parent(s): f9cd9f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -15
app.py CHANGED
@@ -16,7 +16,7 @@ from lib.dataloader import ImageMixin
16
  # ===========================================
17
  # 1) パス設定
18
  # ===========================================
19
- WEIGHT_PATH = "./cxp_projection_rotation.pt"
20
  PARAMETER_JSON = "./parameters.json"
21
 
22
  # ===========================================
@@ -32,8 +32,6 @@ class ImageHandler(ImageMixin):
32
  def __init__(self, params):
33
  self.params = params
34
  self.transform = T.Compose([
35
- # 256×256 前提なら Resize は不要
36
- # T.Resize((256, 256)),
37
  T.ToTensor(),
38
  ])
39
 
@@ -82,15 +80,20 @@ def predict_html(image_path: str) -> str:
82
 
83
  with torch.no_grad():
84
  outputs = model(batch)
85
- # raw logits
86
  logits_proj = outputs.get("label_APorPA")
87
  logits_rot = outputs.get("label_round")
88
 
 
 
 
 
89
  # argmax でラベル選択
90
- idx_proj = int(torch.argmax(logits_proj, dim=1).item())
91
- idx_rot = int(torch.argmax(logits_rot, dim=1).item())
92
  pred_proj = LABEL_APorPA[idx_proj]
93
  pred_rot = LABEL_ROUND[idx_rot]
 
 
94
 
95
  # ファイル名から元ラベル取得(例: "1_AP_Upright.png" → orig_proj="AP", orig_rot="Upright")
96
  base = os.path.splitext(os.path.basename(image_path))[0]
@@ -101,17 +104,48 @@ def predict_html(image_path: str) -> str:
101
  else:
102
  orig_proj = orig_rot = None
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  # 警告HTML
105
  warn_html = ""
106
- if orig_proj and orig_proj != pred_proj:
107
- warn_html += "<p style='color:red'>⚠ Potential mislabeled projection</p>"
108
- if orig_rot and orig_rot != pred_rot:
109
- warn_html += "<p style='color:red'>⚠ Potential mislabeled rotation</p>"
 
 
 
 
 
 
110
 
111
  # 結果表示用HTML
112
  html = (
113
- f"<p><strong>Projection :</strong> {pred_proj}</p>"
114
- f"<p><strong>Rotation :</strong> {pred_rot}</p>"
 
 
115
  f"{warn_html}"
116
  )
117
  return html
@@ -123,7 +157,8 @@ html_header = """
123
  <div style="padding:10px;border:1px solid #ddd;border-radius:5px">
124
  <h2>Chest X‑ray Projection & Rotation Classification</h2>
125
  <p>Upload a 256×256 grayscale PNG. The model predicts projection (AP/PA/Lateral)
126
- and rotation (Upright/Inverted/Left/Right) and warns if filename label differs.</p>
 
127
  </div>
128
  """
129
 
@@ -158,7 +193,7 @@ with gr.Blocks(title="CXR Projection & Rotation") as demo:
158
 
159
  # サンプルのファイル名を一覧で表示
160
  gr.Markdown(
161
- "**Sample filenames:** \n"
162
  "- 1_AP_Upright.png \n"
163
  "- 1_PA_Inverted.png \n"
164
  "- 2_AP_Right90.png \n"
@@ -166,4 +201,4 @@ with gr.Blocks(title="CXR Projection & Rotation") as demo:
166
  )
167
 
168
  if __name__ == "__main__":
169
- demo.launch(debug=True)
 
16
  # ===========================================
17
  # 1) パス設定
18
  # ===========================================
19
+ WEIGHT_PATH = "./cxp_projection_rotation.pt"
20
  PARAMETER_JSON = "./parameters.json"
21
 
22
  # ===========================================
 
32
  def __init__(self, params):
33
  self.params = params
34
  self.transform = T.Compose([
 
 
35
  T.ToTensor(),
36
  ])
37
 
 
80
 
81
  with torch.no_grad():
82
  outputs = model(batch)
 
83
  logits_proj = outputs.get("label_APorPA")
84
  logits_rot = outputs.get("label_round")
85
 
86
+ # softmax で確率に変換
87
+ probs_proj = F.softmax(logits_proj, dim=1)[0].cpu().numpy()
88
+ probs_rot = F.softmax(logits_rot, dim=1)[0].cpu().numpy()
89
+
90
  # argmax でラベル選択
91
+ idx_proj = int(probs_proj.argmax())
92
+ idx_rot = int(probs_rot.argmax())
93
  pred_proj = LABEL_APorPA[idx_proj]
94
  pred_rot = LABEL_ROUND[idx_rot]
95
+ conf_proj = float(probs_proj[idx_proj])
96
+ conf_rot = float(probs_rot[idx_rot])
97
 
98
  # ファイル名から元ラベル取得(例: "1_AP_Upright.png" → orig_proj="AP", orig_rot="Upright")
99
  base = os.path.splitext(os.path.basename(image_path))[0]
 
104
  else:
105
  orig_proj = orig_rot = None
106
 
107
+ # 警告HTML作成用ヘルパー
108
+ def make_warning(kind, orig, pred, conf):
109
+ # kind: "projection" or "rotation"
110
+ high_thr = 0.8
111
+ med_thr = 0.5
112
+ if orig and orig != pred:
113
+ if conf >= high_thr:
114
+ return (
115
+ f"<p style='color:red'>⚠ Potentially mislabeled {kind}: "
116
+ f"filename says {orig}, model predicts {pred} (confidence {conf:.2f})</p>"
117
+ )
118
+ elif conf >= med_thr:
119
+ return (
120
+ f"<p style='color:orange'>⚠ There is a possibility of mislabeled {kind}: "
121
+ f"model predicts {pred} with moderate confidence ({conf:.2f})</p>"
122
+ )
123
+ if conf < med_thr:
124
+ return (
125
+ f"<p style='color:orange'>⚠ Low confidence for {kind} ({conf:.2f}); "
126
+ f"please check image quality or framing.</p>"
127
+ )
128
+ return ""
129
+
130
  # 警告HTML
131
  warn_html = ""
132
+ warn_html += make_warning("projection", orig_proj, pred_proj, conf_proj)
133
+ warn_html += make_warning("rotation", orig_rot, pred_rot, conf_rot)
134
+
135
+ # クラスごとのスコア表示用HTML
136
+ scores_proj = ", ".join(
137
+ f"{LABEL_APorPA[i]}: {p:.2f}" for i, p in enumerate(probs_proj)
138
+ )
139
+ scores_rot = ", ".join(
140
+ f"{LABEL_ROUND[i]}: {p:.2f}" for i, p in enumerate(probs_rot)
141
+ )
142
 
143
  # 結果表示用HTML
144
  html = (
145
+ f"<p><strong>Projection :</strong> {pred_proj} "
146
+ f"<small>({scores_proj})</small></p>"
147
+ f"<p><strong>Rotation :</strong> {pred_rot} "
148
+ f"<small>({scores_rot})</small></p>"
149
  f"{warn_html}"
150
  )
151
  return html
 
157
  <div style="padding:10px;border:1px solid #ddd;border-radius:5px">
158
  <h2>Chest X‑ray Projection & Rotation Classification</h2>
159
  <p>Upload a 256×256 grayscale PNG. The model predicts projection (AP/PA/Lateral)
160
+ and rotation (Upright/Inverted/Left/Right) and shows softmax confidences.
161
+ It warns if filename label differs or if confidence is low.</p>
162
  </div>
163
  """
164
 
 
193
 
194
  # サンプルのファイル名を一覧で表示
195
  gr.Markdown(
196
+ "**Sample filenames:** 𝚮\n"
197
  "- 1_AP_Upright.png \n"
198
  "- 1_PA_Inverted.png \n"
199
  "- 2_AP_Right90.png \n"
 
201
  )
202
 
203
  if __name__ == "__main__":
204
+ demo.launch(debug=True)