MedicalAILabo commited on
Commit
bfb9d80
·
verified ·
1 Parent(s): a91a8a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -58
app.py CHANGED
@@ -1,80 +1,169 @@
 
 
 
1
  import os
2
- import json
3
- from PIL import Image
4
  import torch
5
- import torchvision.transforms as T
6
  import gradio as gr
 
 
 
 
 
 
 
7
 
8
- # Import your model creation function
9
- from model import create_model # ここで create_model を定義しているファイルを指定してください
 
 
 
10
 
11
- # Load parameters and model
12
- with open("parameters.json", "r") as f:
13
- parameters = json.load(f)
 
 
14
 
15
- model = create_model(parameters)
16
- weights = torch.load("cxp_projection_rotation.pt", map_location="cpu")
17
- model.load_state_dict(weights)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  model.eval()
19
 
20
- # Transformation for grayscale images
21
- transform = T.Compose([
22
- T.ToTensor(), # converts [H,W] to [1,H,W]
23
- ])
 
 
 
 
24
 
25
- # Prediction and HTML rendering
26
- def predict_html(image_path):
27
- # Preprocess and infer
28
- img = Image.open(image_path)
29
- x = transform(img).unsqueeze(0)
30
  with torch.no_grad():
31
- proj_logits, rot_logits = model(x)
32
- proj_idx = proj_logits.argmax(dim=1).item()
33
- rot_idx = rot_logits.argmax(dim=1).item()
34
- proj_pred = parameters["projection_labels"][proj_idx]
35
- rot_pred = parameters["rotation_labels"][rot_idx]
36
-
37
- # Parse file name: ID_Projection_Rotation.png
38
- filename = os.path.basename(image_path)
39
- name, _ = os.path.splitext(filename)
40
- parts = name.split("_")
 
 
 
 
41
  if len(parts) >= 3:
42
  orig_proj = parts[1]
43
- orig_rot = parts[2]
44
- orig_label = f"{orig_proj}_{orig_rot}"
45
  else:
46
- orig_label = None
47
-
48
- # Build HTML output
49
- html = f"<h3>Prediction: {proj_pred} / {rot_pred}</h3>"
50
- if orig_label:
51
- if orig_label != f"{proj_pred}_{rot_pred}":
52
- html += f"<p style='color:red;'>Warning: original label '<strong>{orig_label}</strong>' does not match prediction.</p>"
53
- else:
54
- html += f"<p>Original label '<strong>{orig_label}</strong>' matches prediction.</p>"
 
 
 
 
 
 
55
  return html
56
 
57
- # Gradio UI
58
- with gr.Blocks() as demo:
59
- gr.Markdown("## Chest X-ray Projection and Rotation Classifier")
 
 
 
 
 
 
 
 
 
 
 
60
  with gr.Row():
61
- img_input = gr.Image(type="filepath", image_mode="L", label="Upload PNG Image (256×256)")
62
- html_output = gr.HTML(label="Result")
 
 
 
 
63
 
64
- classify_btn = gr.Button("Classify Image")
65
- classify_btn.click(fn=predict_html, inputs=img_input, outputs=html_output)
 
 
 
 
66
 
67
- # Sample images with filenames shown
68
- sample_dir = "samples"
69
- sample_files = sorted([f for f in os.listdir(sample_dir) if f.endswith('.png')])
70
- sample_paths = [os.path.join(sample_dir, f) for f in sample_files]
71
  gr.Examples(
72
- examples=sample_paths,
73
- inputs=img_input,
74
- outputs=html_output,
75
- fn=predict_html,
76
- label="Sample Images"
 
 
 
 
 
 
 
 
 
 
 
77
  )
78
 
79
  if __name__ == "__main__":
80
- demo.launch(debug=True)
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
  import os
 
 
5
  import torch
6
+ import torch.nn.functional as F
7
  import gradio as gr
8
+ import numpy as np
9
+ import torchvision.transforms as T
10
+ from PIL import Image
11
+
12
+ from lib.framework import create_model
13
+ from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
14
+ from lib.dataloader import ImageMixin
15
 
16
+ # ===========================================
17
+ # 1) パス設定
18
+ # ===========================================
19
+ WEIGHT_PATH = "./cxp_projection_rotation.pt"
20
+ PARAMETER_JSON = "./parameters.json"
21
 
22
+ # ===========================================
23
+ # 2) クラスラベル定義
24
+ # ===========================================
25
+ LABEL_APorPA = ["AP", "PA", "Lateral"]
26
+ LABEL_ROUND = ["Upright", "Inverted", "Left rotation", "Right rotation"]
27
 
28
+ # ===========================================
29
+ # 3) 前処理クラス
30
+ # ===========================================
31
+ class ImageHandler(ImageMixin):
32
+ def __init__(self, params):
33
+ self.params = params
34
+ self.transform = T.Compose([
35
+ # 256×256 前提なら Resize は不要
36
+ # T.Resize((256, 256)),
37
+ T.ToTensor(),
38
+ ])
39
+
40
+ def set_image(self, image: Image.Image):
41
+ tensor = self.transform(image) # [C,H,W], float32 in [0,1]
42
+ return {"image": tensor.unsqueeze(0)} # バッチ次元追加
43
+
44
+ # ===========================================
45
+ # 4) パラメータロード
46
+ # ===========================================
47
+ def load_parameter(parameter_path):
48
+ _args = ParamSet()
49
+ params = _retrieve_parameter(parameter_path)
50
+ for k, v in params.items():
51
+ setattr(_args, k, v)
52
+ # 推論用に上書き
53
+ _args.augmentation = "no"
54
+ _args.sampler = "no"
55
+ _args.pretrained = False
56
+ _args.mlp = None
57
+ _args.net = _args.model
58
+ _args.device = torch.device("cpu")
59
+ return (
60
+ _dispatch_by_group(_args, "model"),
61
+ _dispatch_by_group(_args, "dataloader"),
62
+ )
63
+
64
+ args_model, args_dataloader = load_parameter(PARAMETER_JSON)
65
+
66
+ # ===========================================
67
+ # 5) モデル作成&重みロード
68
+ # ===========================================
69
+ model = create_model(args_model)
70
+ print(f"Loading weights from {WEIGHT_PATH}")
71
+ model.load_weight(WEIGHT_PATH)
72
  model.eval()
73
 
74
+ # ===========================================
75
+ # 6) 推論+HTML生成
76
+ # ===========================================
77
+ def predict_html(image_path: str) -> str:
78
+ # 画像読み込み
79
+ img = Image.open(image_path).convert("L")
80
+ handler = ImageHandler(args_dataloader)
81
+ batch = handler.set_image(img)
82
 
 
 
 
 
 
83
  with torch.no_grad():
84
+ outputs = model(batch)
85
+ # raw logits
86
+ logits_proj = outputs.get("label_APorPA")
87
+ logits_rot = outputs.get("label_round")
88
+
89
+ # argmax でラベル選択
90
+ idx_proj = int(torch.argmax(logits_proj, dim=1).item())
91
+ idx_rot = int(torch.argmax(logits_rot, dim=1).item())
92
+ pred_proj = LABEL_APorPA[idx_proj]
93
+ pred_rot = LABEL_ROUND[idx_rot]
94
+
95
+ # ファイル名から元ラベル取得(例: "1_AP_Upright.png" → orig_proj="AP", orig_rot="Upright")
96
+ base = os.path.splitext(os.path.basename(image_path))[0]
97
+ parts = base.split("_")
98
  if len(parts) >= 3:
99
  orig_proj = parts[1]
100
+ orig_rot = parts[2]
 
101
  else:
102
+ orig_proj = orig_rot = None
103
+
104
+ # 警告HTML
105
+ warn_html = ""
106
+ if orig_proj and orig_proj != pred_proj:
107
+ warn_html += "<p style='color:red'>⚠ Potential mislabeled projection</p>"
108
+ if orig_rot and orig_rot != pred_rot:
109
+ warn_html += "<p style='color:red'>⚠ Potential mislabeled rotation</p>"
110
+
111
+ # 結果表示用HTML
112
+ html = (
113
+ f"<p><strong>Projection :</strong> {pred_proj}</p>"
114
+ f"<p><strong>Rotation :</strong> {pred_rot}</p>"
115
+ f"{warn_html}"
116
+ )
117
  return html
118
 
119
+ # ===========================================
120
+ # 7) Gradio UI
121
+ # ===========================================
122
+ html_header = """
123
+ <div style="padding:10px;border:1px solid #ddd;border-radius:5px">
124
+ <h2>Chest X‑ray Projection & Rotation Classification</h2>
125
+ <p>Upload a 256×256 grayscale PNG. The model predicts projection (AP/PA/Lateral)
126
+ and rotation (Upright/Inverted/Left/Right) and warns if filename label differs.</p>
127
+ </div>
128
+ """
129
+
130
+ with gr.Blocks(title="CXR Projection & Rotation") as demo:
131
+ gr.HTML(html_header)
132
+
133
  with gr.Row():
134
+ input_image = gr.Image(
135
+ label="Upload PNG (256×256)",
136
+ type="filepath",
137
+ image_mode="L"
138
+ )
139
+ output_html = gr.HTML()
140
 
141
+ send_btn = gr.Button("Run Inference")
142
+ send_btn.click(
143
+ fn=predict_html,
144
+ inputs=input_image,
145
+ outputs=output_html
146
+ )
147
 
148
+ # サンプル例
 
 
 
149
  gr.Examples(
150
+ examples=[
151
+ "./sample/1_AP_Upright.png",
152
+ "./sample/1_PA_Inverted.png",
153
+ "./sample/2_AP_Right90.png",
154
+ "./sample/2_Lateral_Left90.png.png",
155
+ ],
156
+ inputs=input_image
157
+ )
158
+
159
+ # サンプルのファイル名を一覧で表示
160
+ gr.Markdown(
161
+ "**Sample filenames:** \n"
162
+ "- 1_AP_Upright.png \n"
163
+ "- 1_PA_Inverted.png \n"
164
+ "- 2_AP_Right90.png \n"
165
+ "- 2_Lateral_Left90.png"
166
  )
167
 
168
  if __name__ == "__main__":
169
+ demo.launch(debug=True)