MedicalAILabo commited on
Commit
982fbd3
·
verified ·
1 Parent(s): 58487cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -56
app.py CHANGED
@@ -1,93 +1,95 @@
1
  import os
2
  import numpy as np
3
  import torch
 
4
  import torch.nn.functional as F
 
5
  import gradio as gr
6
  from PIL import Image
7
 
8
- # PyTorch モデルをロード
9
- # map_location="cpu" で CPU 上にロード
10
- model = torch.load("cxp_projection_rotation.pt", map_location="cpu")
11
- model.eval() # 評価モードに切り替え
12
-
13
- # クラスラベル定義
14
  PROJ_LABELS = ["AP", "PA", "Lateral"]
15
  ROT_LABELS = ["Upright", "Inverted", "Left90", "Right90"]
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def predict(image_path: str) -> str:
18
- # 画像をグレースケールで読み込み(Lモード)
19
  img = Image.open(image_path).convert("L")
20
- arr = np.array(img, dtype=np.float32) / 255.0 # 0-1 正規化
21
- # バッチ&チャンネル次元を追加して Tensor 化
22
  tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0) # [1,1,256,256]
23
 
24
- # 推論
25
  with torch.no_grad():
26
- proj_logits, rot_logits = model(tensor) # 2 ヘッド出力
27
 
28
- # ソフトマックスで確率化
29
  proj_probs = F.softmax(proj_logits, dim=1)[0].cpu().numpy()
30
  rot_probs = F.softmax(rot_logits, dim=1)[0].cpu().numpy()
31
 
32
- # 最も確率の高いラベルと確率
33
- proj_idx = int(np.argmax(proj_probs))
34
- rot_idx = int(np.argmax(rot_probs))
35
- proj_lbl = PROJ_LABELS[proj_idx]
36
- rot_lbl = ROT_LABELS[rot_idx]
37
- proj_p = proj_probs[proj_idx]
38
- rot_p = rot_probs[rot_idx]
39
 
40
- # ファイル名から元ラベルを取得(例: "AP_Upright.png" → ["AP","Upright"])
41
  base = os.path.splitext(os.path.basename(image_path))[0]
42
  try:
43
- orig_proj, orig_rot = base.split("_", 1)
44
  except ValueError:
45
- orig_proj = orig_rot = None
46
 
47
- # 警告メッセージ(必要なら赤字で表示)
48
- warnings_html = ""
49
- if orig_proj and orig_proj != proj_lbl:
50
- warnings_html += "<p style='color:red'>⚠ Potential mislabeled projection</p>"
51
- if orig_rot and orig_rot != rot_lbl:
52
- warnings_html += "<p style='color:red'>⚠ Potential mislabeled rotation</p>"
53
 
54
- # 結果を HTML 形式で返す
55
- html = (
56
- f"<p><strong>Projection :</strong> {proj_lbl} (p={proj_p:.3f})</p>"
57
- f"<p><strong>Rotation :</strong> {rot_lbl} (p={rot_p:.3f})</p>"
58
- f"{warnings_html}"
59
  )
60
- return html
61
 
62
- # Gradio UI 定義
63
  with gr.Blocks() as demo:
64
  with gr.Row():
65
  with gr.Column():
66
- image_input = gr.Image(
67
- label="Upload PNG (256×256)",
68
- type="filepath",
69
- tool=None
70
- )
71
- # sample_images フォルダからサンプル画像を読み込む
72
- sample_list = sorted([
73
  os.path.join("sample_images", f)
74
- for f in os.listdir("sample_images")
75
- if f.lower().endswith(".png")
76
- ])
77
- gr.Examples(
78
- examples=sample_list,
79
- inputs=image_input,
80
- label="Sample Images"
81
  )
 
82
  with gr.Column():
83
- result = gr.HTML()
84
 
85
- # 画像を選択/アップロードしたら自動で predict を実行
86
- image_input.change(
87
- fn=predict,
88
- inputs=image_input,
89
- outputs=result
90
- )
91
 
92
  if __name__ == "__main__":
93
  demo.launch()
 
1
  import os
2
  import numpy as np
3
  import torch
4
+ import torch.nn as nn
5
  import torch.nn.functional as F
6
+ import timm
7
  import gradio as gr
8
  from PIL import Image
9
 
10
+ # --- クラスラベル定義 ---
 
 
 
 
 
11
  PROJ_LABELS = ["AP", "PA", "Lateral"]
12
  ROT_LABELS = ["Upright", "Inverted", "Left90", "Right90"]
13
 
14
+ # --- モデル定義: EfficientNet backbone + 2-head ---
15
+ class CXRModel(nn.Module):
16
+ def __init__(self):
17
+ super().__init__()
18
+ # 1ch 入力, 分類器を持たない backbone
19
+ self.backbone = timm.create_model(
20
+ "efficientnet_b0", pretrained=False, in_chans=1, num_classes=0
21
+ )
22
+ nf = self.backbone.num_features
23
+ # 2つの出力ヘッド
24
+ self.projection_head = nn.Linear(nf, len(PROJ_LABELS))
25
+ self.rotation_head = nn.Linear(nf, len(ROT_LABELS))
26
+
27
+ def forward(self, x):
28
+ # 特徴量抽出
29
+ feats = self.backbone.forward_features(x)
30
+ # global pool → flatten
31
+ pooled = self.backbone.global_pool(feats).flatten(1)
32
+ # 2ヘッド出力
33
+ return self.projection_head(pooled), self.rotation_head(pooled)
34
+
35
+ # --- モデル・重みロード ---
36
+ model = CXRModel()
37
+ state_dict = torch.load("cxp_projection_rotation.pt", map_location="cpu")
38
+ model.load_state_dict(state_dict) # OrderedDict → モデルに流し込む
39
+ model.eval() # 評価モード
40
+
41
+ # --- 推論 & HTML 結果生成関数 ---
42
  def predict(image_path: str) -> str:
 
43
  img = Image.open(image_path).convert("L")
44
+ arr = np.array(img, dtype=np.float32) / 255.0
 
45
  tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0) # [1,1,256,256]
46
 
 
47
  with torch.no_grad():
48
+ proj_logits, rot_logits = model(tensor)
49
 
50
+ # softmax で確率化
51
  proj_probs = F.softmax(proj_logits, dim=1)[0].cpu().numpy()
52
  rot_probs = F.softmax(rot_logits, dim=1)[0].cpu().numpy()
53
 
54
+ # ベストラベル+確率
55
+ pi = int(np.argmax(proj_probs)); ri = int(np.argmax(rot_probs))
56
+ pl, rl = PROJ_LABELS[pi], ROT_LABELS[ri]
57
+ pp, rp = proj_probs[pi], rot_probs[ri]
 
 
 
58
 
59
+ # 元ラベル推定(ファイル名から)
60
  base = os.path.splitext(os.path.basename(image_path))[0]
61
  try:
62
+ orig_p, orig_r = base.split("_", 1)
63
  except ValueError:
64
+ orig_p = orig_r = None
65
 
66
+ # 警告生成
67
+ warn_html = ""
68
+ if orig_p and orig_p != pl:
69
+ warn_html += "<p style='color:red'>⚠ Potential mislabeled projection</p>"
70
+ if orig_r and orig_r != rl:
71
+ warn_html += "<p style='color:red'>⚠ Potential mislabeled rotation</p>"
72
 
73
+ return (
74
+ f"<p><strong>Projection :</strong> {pl} (p={pp:.3f})</p>"
75
+ f"<p><strong>Rotation :</strong> {rl} (p={rp:.3f})</p>"
76
+ f"{warn_html}"
 
77
  )
 
78
 
79
+ # --- Gradio UI ---
80
  with gr.Blocks() as demo:
81
  with gr.Row():
82
  with gr.Column():
83
+ img_in = gr.Image(label="Upload PNG (256×256)", type="filepath", tool=None)
84
+ samples = sorted(
 
 
 
 
 
85
  os.path.join("sample_images", f)
86
+ for f in os.listdir("sample_images") if f.lower().endswith(".png")
 
 
 
 
 
 
87
  )
88
+ gr.Examples(examples=samples, inputs=img_in, label="Sample Images")
89
  with gr.Column():
90
+ output = gr.HTML()
91
 
92
+ img_in.change(fn=predict, inputs=img_in, outputs=output)
 
 
 
 
 
93
 
94
  if __name__ == "__main__":
95
  demo.launch()