MedicalAILabo commited on
Commit
6570337
·
verified ·
1 Parent(s): 4abbce7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -139
app.py CHANGED
@@ -1,157 +1,76 @@
1
- #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
-
4
  import os
 
 
5
  import torch
6
- import torch.nn.functional as F
7
- import gradio as gr
8
- import numpy as np
9
  import torchvision.transforms as T
10
- from PIL import Image
11
-
12
- from lib.framework import create_model
13
- from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
14
- from lib.dataloader import ImageMixin
15
-
16
- # ===========================================
17
- # 1) パス設定
18
- # ===========================================
19
- WEIGHT_PATH = "./cxp_projection_rotation.pt"
20
- PARAMETER_JSON = "./parameters.json"
21
-
22
- # ===========================================
23
- # 2) クラスラベル定義
24
- # ===========================================
25
- LABEL_APorPA = ["AP", "PA", "Lateral"]
26
- LABEL_ROUND = ["Upright", "Inverted", "Left rotation", "Right rotation"]
27
-
28
- # ===========================================
29
- # 3) 前処理クラス
30
- # ===========================================
31
- class ImageHandler(ImageMixin):
32
- def __init__(self, params):
33
- self.params = params
34
- self.transform = T.Compose([
35
- # 256×256 前提なら Resize は不要
36
- # T.Resize((256, 256)),
37
- T.ToTensor(),
38
- ])
39
-
40
- def set_image(self, image: Image.Image):
41
- tensor = self.transform(image) # [C,H,W], float32 in [0,1]
42
- return {"image": tensor.unsqueeze(0)} # バッチ次元追加
43
-
44
- # ===========================================
45
- # 4) パラメータロード
46
- # ===========================================
47
- def load_parameter(parameter_path):
48
- _args = ParamSet()
49
- params = _retrieve_parameter(parameter_path)
50
- for k, v in params.items():
51
- setattr(_args, k, v)
52
- # 推論用に上書き
53
- _args.augmentation = "no"
54
- _args.sampler = "no"
55
- _args.pretrained = False
56
- _args.mlp = None
57
- _args.net = _args.model
58
- _args.device = torch.device("cpu")
59
- return (
60
- _dispatch_by_group(_args, "model"),
61
- _dispatch_by_group(_args, "dataloader"),
62
- )
63
 
64
- args_model, args_dataloader = load_parameter(PARAMETER_JSON)
 
 
65
 
66
- # ===========================================
67
- # 5) モデル作成&重みロード
68
- # ===========================================
69
- model = create_model(args_model)
70
- print(f"Loading weights from {WEIGHT_PATH}")
71
- model.load_weight(WEIGHT_PATH)
72
  model.eval()
73
 
74
- # ===========================================
75
- # 6) 推論+HTML生成
76
- # ===========================================
77
- def predict_html(image_path: str) -> str:
78
- # 画像読み込み
79
- img = Image.open(image_path).convert("L")
80
- handler = ImageHandler(args_dataloader)
81
- batch = handler.set_image(img)
82
 
 
 
 
 
 
83
  with torch.no_grad():
84
- outputs = model(batch)
85
- # raw logits
86
- logits_proj = outputs.get("label_APorPA")
87
- logits_rot = outputs.get("label_round")
88
-
89
- # argmax でラベル選択
90
- idx_proj = int(torch.argmax(logits_proj, dim=1).item())
91
- idx_rot = int(torch.argmax(logits_rot, dim=1).item())
92
- pred_proj = LABEL_APorPA[idx_proj]
93
- pred_rot = LABEL_ROUND[idx_rot]
94
-
95
- # ファイル名から元ラベル取得(例: "AP_Upright.png")
96
- base = os.path.splitext(os.path.basename(image_path))[0]
97
- try:
98
- orig_proj, orig_rot = base.split("_", 1)
99
- except ValueError:
100
- orig_proj = orig_rot = None
101
-
102
- # 警告HTML
103
- warn_html = ""
104
- if orig_proj and orig_proj != pred_proj:
105
- warn_html += "<p style='color:red'>⚠ Potential mislabeled projection</p>"
106
- if orig_rot and orig_rot != pred_rot:
107
- warn_html += "<p style='color:red'>⚠ Potential mislabeled rotation</p>"
108
-
109
- # 結果表示用HTML
110
- html = (
111
- f"<p><strong>Projection :</strong> {pred_proj}</p>"
112
- f"<p><strong>Rotation :</strong> {pred_rot}</p>"
113
- f"{warn_html}"
114
- )
115
  return html
116
 
117
- # ===========================================
118
- # 7) Gradio UI
119
- # ===========================================
120
- html_header = """
121
- <div style="padding:10px;border:1px solid #ddd;border-radius:5px">
122
- <h2>Chest X‑ray Projection & Rotation Classification</h2>
123
- <p>Upload a 256×256 grayscale PNG. The model predicts projection (AP/PA/Lateral)
124
- and rotation (Upright/Inverted/Left/Right) and warns if filename label differs.</p>
125
- </div>
126
- """
127
-
128
- with gr.Blocks(title="CXR Projection & Rotation") as demo:
129
- gr.HTML(html_header)
130
-
131
  with gr.Row():
132
- input_image = gr.Image(
133
- label="Upload PNG (256×256)",
134
- type="filepath",
135
- image_mode="L"
136
- )
137
- output_html = gr.HTML()
138
 
139
- send_btn = gr.Button("Run Inference")
140
- send_btn.click(
141
- fn=predict_html,
142
- inputs=input_image,
143
- outputs=output_html
144
- )
145
 
146
- # サンプル例
 
 
 
147
  gr.Examples(
148
- examples=[
149
- "./sample/1_AP_Upright.png",
150
- "./sample/1_PA_Inverted.png",
151
- "./sample/2_AP_Right90.png",
152
- "./sample/2_Lateral_Left90.png.png",
153
- ],
154
- inputs=input_image
155
  )
156
 
157
  if __name__ == "__main__":
 
 
 
 
1
  import os
2
+ import json
3
+ from PIL import Image
4
  import torch
 
 
 
5
  import torchvision.transforms as T
6
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # Load parameters and model
9
+ with open("parameters.json", "r") as f:
10
+ parameters = json.load(f)
11
 
12
+ model = create_model(parameters) # your existing create_model function
13
+ weights = torch.load("cxp_projection_rotation.pt", map_location="cpu")
14
+ model.load_state_dict(weights)
 
 
 
15
  model.eval()
16
 
17
+ # Transformation for grayscale images
18
+ transform = T.Compose([
19
+ T.ToTensor(), # converts [H,W] to [1,H,W]
20
+ ])
 
 
 
 
21
 
22
+ # Prediction and HTML rendering
23
+ def predict_html(image_path):
24
+ # Preprocess and infer
25
+ img = Image.open(image_path)
26
+ x = transform(img).unsqueeze(0)
27
  with torch.no_grad():
28
+ proj_logits, rot_logits = model(x)
29
+ proj_idx = proj_logits.argmax(dim=1).item()
30
+ rot_idx = rot_logits.argmax(dim=1).item()
31
+ proj_pred = parameters["projection_labels"][proj_idx]
32
+ rot_pred = parameters["rotation_labels"][rot_idx]
33
+
34
+ # Parse file name: ID_Projection_Rotation.png
35
+ filename = os.path.basename(image_path)
36
+ name, _ = os.path.splitext(filename)
37
+ parts = name.split("_")
38
+ if len(parts) >= 3:
39
+ orig_proj = parts[1]
40
+ orig_rot = parts[2]
41
+ orig_label = f"{orig_proj}_{orig_rot}"
42
+ else:
43
+ orig_label = None
44
+
45
+ # Build HTML output
46
+ html = f"<h3>Prediction: {proj_pred} / {rot_pred}</h3>"
47
+ if orig_label:
48
+ if orig_label != f"{proj_pred}_{rot_pred}":
49
+ html += f"<p style='color:red;'>Warning: original label '<strong>{orig_label}</strong>' does not match prediction.</p>"
50
+ else:
51
+ html += f"<p>Original label '<strong>{orig_label}</strong>' matches prediction.</p>"
 
 
 
 
 
 
 
52
  return html
53
 
54
+ # Gradio UI
55
+ with gr.Blocks() as demo:
56
+ gr.Markdown("## Chest X-ray Projection and Rotation Classifier")
 
 
 
 
 
 
 
 
 
 
 
57
  with gr.Row():
58
+ img_input = gr.Image(type="filepath", image_mode="L", label="Upload PNG Image (256×256)")
59
+ html_output = gr.HTML(label="Result")
 
 
 
 
60
 
61
+ classify_btn = gr.Button("Classify Image")
62
+ classify_btn.click(fn=predict_html, inputs=img_input, outputs=html_output)
 
 
 
 
63
 
64
+ # Sample images with filenames shown
65
+ sample_dir = "samples"
66
+ sample_files = sorted([f for f in os.listdir(sample_dir) if f.endswith('.png')])
67
+ sample_paths = [os.path.join(sample_dir, f) for f in sample_files]
68
  gr.Examples(
69
+ examples=sample_paths,
70
+ inputs=img_input,
71
+ outputs=html_output,
72
+ fn=predict_html,
73
+ label="Sample Images"
 
 
74
  )
75
 
76
  if __name__ == "__main__":