MedicalAILabo commited on
Commit
db539f9
·
verified ·
1 Parent(s): 13a9e96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -70
app.py CHANGED
@@ -1,95 +1,157 @@
 
 
 
1
  import os
2
- import numpy as np
3
  import torch
4
- import torch.nn as nn
5
  import torch.nn.functional as F
6
- import timm
7
  import gradio as gr
 
 
8
  from PIL import Image
9
 
10
- # --- クラスラベル定義 ---
11
- PROJ_LABELS = ["AP", "PA", "Lateral"]
12
- ROT_LABELS = ["Upright", "Inverted", "Left90", "Right90"]
13
-
14
- # --- モデル定義: EfficientNet backbone + 2-head ---
15
- class CXRModel(nn.Module):
16
- def __init__(self):
17
- super().__init__()
18
- # 1ch 入力, 分類器を持たない backbone
19
- self.backbone = timm.create_model(
20
- "efficientnet_b0", pretrained=False, in_chans=1, num_classes=0
21
- )
22
- nf = self.backbone.num_features
23
- # 2つの出力ヘッド
24
- self.projection_head = nn.Linear(nf, len(PROJ_LABELS))
25
- self.rotation_head = nn.Linear(nf, len(ROT_LABELS))
26
-
27
- def forward(self, x):
28
- # 特徴量抽出
29
- feats = self.backbone.forward_features(x)
30
- # global pool → flatten
31
- pooled = self.backbone.global_pool(feats).flatten(1)
32
- # 2ヘッド出力
33
- return self.projection_head(pooled), self.rotation_head(pooled)
34
-
35
- # --- モデル・重みロード ---
36
- model = CXRModel()
37
- state_dict = torch.load("cxp_projection_rotation.pt", map_location="cpu")
38
- model.load_state_dict(state_dict) # OrderedDict → モデルに流し込む
39
- model.eval() # 評価モード
40
-
41
- # --- 推論 & HTML 結果生成関数 ---
42
- def predict(image_path: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  img = Image.open(image_path).convert("L")
44
- arr = np.array(img, dtype=np.float32) / 255.0
45
- tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0) # [1,1,256,256]
46
 
47
  with torch.no_grad():
48
- proj_logits, rot_logits = model(tensor)
49
-
50
- # softmax で確率化
51
- proj_probs = F.softmax(proj_logits, dim=1)[0].cpu().numpy()
52
- rot_probs = F.softmax(rot_logits, dim=1)[0].cpu().numpy()
53
 
54
- # ベストラベル+確率
55
- pi = int(np.argmax(proj_probs)); ri = int(np.argmax(rot_probs))
56
- pl, rl = PROJ_LABELS[pi], ROT_LABELS[ri]
57
- pp, rp = proj_probs[pi], rot_probs[ri]
 
58
 
59
- # 元ラベル推定(ファイル名から)
60
  base = os.path.splitext(os.path.basename(image_path))[0]
61
  try:
62
- orig_p, orig_r = base.split("_", 1)
63
  except ValueError:
64
- orig_p = orig_r = None
65
 
66
- # 警告生成
67
  warn_html = ""
68
- if orig_p and orig_p != pl:
69
  warn_html += "<p style='color:red'>⚠ Potential mislabeled projection</p>"
70
- if orig_r and orig_r != rl:
71
  warn_html += "<p style='color:red'>⚠ Potential mislabeled rotation</p>"
72
 
73
- return (
74
- f"<p><strong>Projection :</strong> {pl} (p={pp:.3f})</p>"
75
- f"<p><strong>Rotation :</strong> {rl} (p={rp:.3f})</p>"
 
76
  f"{warn_html}"
77
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- # --- Gradio UI ---
80
- with gr.Blocks() as demo:
81
  with gr.Row():
82
- with gr.Column():
83
- img_in = gr.Image(label="Upload PNG (256×256)", type="filepath", tool=None)
84
- samples = sorted(
85
- os.path.join("sample_images", f)
86
- for f in os.listdir("sample_images") if f.lower().endswith(".png")
87
- )
88
- gr.Examples(examples=samples, inputs=img_in, label="Sample Images")
89
- with gr.Column():
90
- output = gr.HTML()
91
-
92
- img_in.change(fn=predict, inputs=img_in, outputs=output)
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  if __name__ == "__main__":
95
- demo.launch()
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
  import os
 
5
  import torch
 
6
  import torch.nn.functional as F
 
7
  import gradio as gr
8
+ import numpy as np
9
+ import torchvision.transforms as T
10
  from PIL import Image
11
 
12
+ from lib.framework import create_model
13
+ from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
14
+ from lib.dataloader import ImageMixin
15
+
16
+ # ===========================================
17
+ # 1) パス設定
18
+ # ===========================================
19
+ WEIGHT_PATH = "./weight_epoch-011_best.pt"
20
+ PARAMETER_JSON = "./parameters.json"
21
+
22
+ # ===========================================
23
+ # 2) クラスラベル定義
24
+ # ===========================================
25
+ LABEL_APorPA = ["AP", "PA", "Lateral"]
26
+ LABEL_ROUND = ["Upright", "Inverted", "Left rotation", "Right rotation"]
27
+
28
+ # ===========================================
29
+ # 3) 前処理クラス
30
+ # ===========================================
31
+ class ImageHandler(ImageMixin):
32
+ def __init__(self, params):
33
+ self.params = params
34
+ self.transform = T.Compose([
35
+ # 256×256 前提なら Resize は不要
36
+ # T.Resize((256, 256)),
37
+ T.ToTensor(),
38
+ ])
39
+
40
+ def set_image(self, image: Image.Image):
41
+ tensor = self.transform(image) # [C,H,W], float32 in [0,1]
42
+ return {"image": tensor.unsqueeze(0)} # バッチ次元追加
43
+
44
+ # ===========================================
45
+ # 4) パラメータロード
46
+ # ===========================================
47
+ def load_parameter(parameter_path):
48
+ _args = ParamSet()
49
+ params = _retrieve_parameter(parameter_path)
50
+ for k, v in params.items():
51
+ setattr(_args, k, v)
52
+ # 推論用に上書き
53
+ _args.augmentation = "no"
54
+ _args.sampler = "no"
55
+ _args.pretrained = False
56
+ _args.mlp = None
57
+ _args.net = _args.model
58
+ _args.device = torch.device("cpu")
59
+ return (
60
+ _dispatch_by_group(_args, "model"),
61
+ _dispatch_by_group(_args, "dataloader"),
62
+ )
63
+
64
+ args_model, args_dataloader = load_parameter(PARAMETER_JSON)
65
+
66
+ # ===========================================
67
+ # 5) モデル作成&重みロード
68
+ # ===========================================
69
+ model = create_model(args_model)
70
+ print(f"Loading weights from {WEIGHT_PATH}")
71
+ model.load_weight(WEIGHT_PATH)
72
+ model.eval()
73
+
74
+ # ===========================================
75
+ # 6) 推論+HTML生成
76
+ # ===========================================
77
+ def predict_html(image_path: str) -> str:
78
+ # 画像読み込み
79
  img = Image.open(image_path).convert("L")
80
+ handler = ImageHandler(args_dataloader)
81
+ batch = handler.set_image(img)
82
 
83
  with torch.no_grad():
84
+ outputs = model(batch)
85
+ # raw logits
86
+ logits_proj = outputs.get("label_APorPA")
87
+ logits_rot = outputs.get("label_round")
 
88
 
89
+ # argmax でラベル選択
90
+ idx_proj = int(torch.argmax(logits_proj, dim=1).item())
91
+ idx_rot = int(torch.argmax(logits_rot, dim=1).item())
92
+ pred_proj = LABEL_APorPA[idx_proj]
93
+ pred_rot = LABEL_ROUND[idx_rot]
94
 
95
+ # ファイ���名から元ラベル取得(例: "AP_Upright.png")
96
  base = os.path.splitext(os.path.basename(image_path))[0]
97
  try:
98
+ orig_proj, orig_rot = base.split("_", 1)
99
  except ValueError:
100
+ orig_proj = orig_rot = None
101
 
102
+ # 警告HTML
103
  warn_html = ""
104
+ if orig_proj and orig_proj != pred_proj:
105
  warn_html += "<p style='color:red'>⚠ Potential mislabeled projection</p>"
106
+ if orig_rot and orig_rot != pred_rot:
107
  warn_html += "<p style='color:red'>⚠ Potential mislabeled rotation</p>"
108
 
109
+ # 結果表示用HTML
110
+ html = (
111
+ f"<p><strong>Projection :</strong> {pred_proj}</p>"
112
+ f"<p><strong>Rotation :</strong> {pred_rot}</p>"
113
  f"{warn_html}"
114
  )
115
+ return html
116
+
117
+ # ===========================================
118
+ # 7) Gradio UI
119
+ # ===========================================
120
+ html_header = """
121
+ <div style="padding:10px;border:1px solid #ddd;border-radius:5px">
122
+ <h2>Chest X‑ray Projection & Rotation Classification</h2>
123
+ <p>Upload a 256×256 grayscale PNG. The model predicts projection (AP/PA/Lateral)
124
+ and rotation (Upright/Inverted/Left/Right) and warns if filename label differs.</p>
125
+ </div>
126
+ """
127
+
128
+ with gr.Blocks(title="CXR Projection & Rotation") as demo:
129
+ gr.HTML(html_header)
130
 
 
 
131
  with gr.Row():
132
+ input_image = gr.Image(
133
+ label="Upload PNG (256×256)",
134
+ type="filepath",
135
+ image_mode="L"
136
+ )
137
+ output_html = gr.HTML()
138
+
139
+ send_btn = gr.Button("Run Inference")
140
+ send_btn.click(
141
+ fn=predict_html,
142
+ inputs=input_image,
143
+ outputs=output_html
144
+ )
145
+
146
+ # サンプル例
147
+ gr.Examples(
148
+ examples=[
149
+ "./sample/sample_AP_Upright.png",
150
+ "./sample/sample_PA_Inverted.png",
151
+ "./sample/sample_Lateral_Right rotation.png",
152
+ ],
153
+ inputs=input_image
154
+ )
155
 
156
  if __name__ == "__main__":
157
+ demo.launch(debug=True)