File size: 7,211 Bytes
bfb9d80
 
 
22be147
58487cb
bfb9d80
6570337
bfb9d80
 
 
 
 
 
 
db539f9
bfb9d80
 
 
321fa37
bfb9d80
a91a8a9
bfb9d80
 
 
 
3b0f8d2
db539f9
bfb9d80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db539f9
 
bfb9d80
 
 
 
 
 
 
 
22be147
58487cb
bfb9d80
 
 
 
321fa37
 
 
 
bfb9d80
321fa37
 
bfb9d80
 
321fa37
 
bfb9d80
 
 
 
6570337
 
bfb9d80
6570337
bfb9d80
 
321fa37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfb9d80
 
321fa37
 
 
 
 
 
 
 
 
 
bfb9d80
 
 
321fa37
 
 
 
bfb9d80
 
db539f9
 
bfb9d80
 
 
 
 
 
 
321fa37
cfdde14
 
 
 
bfb9d80
 
 
 
 
 
22be147
bfb9d80
 
 
 
 
 
db539f9
bfb9d80
 
 
 
 
 
db539f9
bfb9d80
db539f9
bfb9d80
 
 
3b0f8d2
 
bfb9d80
 
 
 
 
 
321fa37
bfb9d80
 
 
 
db539f9
22be147
 
321fa37
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import html
import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image

from lib.framework import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin

# ===========================================
# 1) Path settings
# ===========================================
# WEIGHT_PATH is handed to model.load_weight() and PARAMETER_JSON to
# load_parameter() further down in this module. Both are relative paths, so
# they resolve against the process's current working directory.
WEIGHT_PATH    = "./cxp_projection_rotation.pt"
PARAMETER_JSON = "./parameters.json"

# ===========================================
# 2) Class label definitions
# ===========================================
# Display names for the two classification heads. predict_html() indexes these
# lists with the argmax of each head's softmax output.
# NOTE(review): the list order presumably has to match the class-index order
# used when the weights were trained — confirm against the training config.
LABEL_APorPA = ["AP", "PA", "Lateral"]
LABEL_ROUND  = ["Upright", "Inverted", "Left-rotation", "Right-rotation"]

# ===========================================
# 3) Preprocessing class
# ===========================================
class ImageHandler(ImageMixin):
    """Turn a single PIL image into the batch dict that is fed to the model.

    ``params`` is kept on the instance as ``self.params``; the methods defined
    here do not read it.
    """

    def __init__(self, params):
        self.params = params
        # Only tensor conversion is applied; no resizing or normalisation.
        self.transform = T.Compose([T.ToTensor()])

    def set_image(self, image: Image.Image):
        """Return ``{"image": tensor}`` where the tensor has a leading batch axis."""
        # ToTensor yields a [C,H,W] float tensor scaled to [0,1] for PIL input.
        as_tensor = self.transform(image)
        # Add the batch dimension in front.
        batched = as_tensor.unsqueeze(0)
        return {"image": batched}

# ===========================================
# 4) Parameter loading
# ===========================================
def load_parameter(parameter_path):
    """Build the model and dataloader argument groups used for CPU inference.

    The parameters returned by ``_retrieve_parameter(parameter_path)`` are
    copied onto a fresh ``ParamSet``; a fixed set of attributes is then
    overwritten for inference.

    Args:
        parameter_path: Location of the stored parameter file.

    Returns:
        A 2-tuple ``(model_args, dataloader_args)``: the results of
        ``_dispatch_by_group`` for the ``"model"`` and ``"dataloader"`` groups.
    """
    settings = ParamSet()
    stored = _retrieve_parameter(parameter_path)
    for name in stored:
        setattr(settings, name, stored[name])

    # Overrides for inference, applied in a fixed order.
    fixed_overrides = (
        ("augmentation", "no"),
        ("sampler", "no"),
        ("pretrained", False),
        ("mlp", None),
    )
    for name, value in fixed_overrides:
        setattr(settings, name, value)
    # "net" mirrors whatever "model" was set to by the stored parameters.
    settings.net = settings.model
    settings.device = torch.device("cpu")

    model_args = _dispatch_by_group(settings, "model")
    dataloader_args = _dispatch_by_group(settings, "dataloader")
    return (model_args, dataloader_args)

# Runs once at import time. args_model feeds create_model() just below;
# args_dataloader is handed to ImageHandler inside predict_html().
args_model, args_dataloader = load_parameter(PARAMETER_JSON)

# ===========================================
# 5) Model creation & weight loading
# ===========================================
# Module-level side effects: the model is built, its weights are read from
# WEIGHT_PATH, and it is switched to eval mode before any request is served.
model = create_model(args_model)
print(f"Loading weights from {WEIGHT_PATH}")
model.load_weight(WEIGHT_PATH)
model.eval()

# ===========================================
# 6) Inference + HTML generation
# ===========================================
# Extra filename spellings accepted for rotation classes, keyed by case-folded
# token. The page text asks users to write "Left90"/"Right90" while the class
# labels are "Left-rotation"/"Right-rotation"; without this mapping a correctly
# named file would always be reported as mislabeled.
_ROTATION_ALIASES = {
    "left90": "Left-rotation",
    "right90": "Right-rotation",
}


def _canonical_label(token, labels, aliases=None):
    """Return the entry of *labels* that *token* spells, else *token* unchanged.

    Matching ignores case and surrounding whitespace; *aliases* maps additional
    case-folded spellings onto labels.
    """
    folded = token.strip().casefold()
    for label in labels:
        if label.casefold() == folded:
            return label
    if aliases:
        return aliases.get(folded, token)
    return token


def _labels_from_filename(image_path):
    """Extract (projection, rotation) from a "<n>_<projection>_<rotation>.<ext>" name.

    Returns ``(None, None)`` when the file stem has fewer than three
    underscore-separated parts.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    parts = stem.split("_")
    if len(parts) < 3:
        return None, None
    return (
        _canonical_label(parts[1], LABEL_APorPA),
        _canonical_label(parts[2], LABEL_ROUND, _ROTATION_ALIASES),
    )


def _make_warning(kind, orig, pred, conf, high_thr=0.8, med_thr=0.5):
    """Return a warning paragraph for one head, or "" when nothing is worth flagging.

    Args:
        kind: "projection" or "rotation"; only used in the message text.
        orig: Label taken from the filename, or None/"" when unavailable.
        pred: Label predicted by the model.
        conf: Softmax probability of ``pred``.
        high_thr: Confidence at or above which a mismatch is reported in red.
        med_thr: Confidence below which a low-confidence warning is emitted.
    """
    if orig and orig != pred:
        if conf >= high_thr:
            # orig comes from a user-supplied filename, so escape it before
            # embedding it in markup.
            return (
                f"<p style='color:red'>⚠ Potentially mislabeled {kind}: "
                f"filename says {html.escape(orig)}, model predicts {pred} "
                f"(confidence {conf:.2f})</p>"
            )
        if conf >= med_thr:
            return (
                f"<p style='color:orange'>⚠ There is a possibility of mislabeled {kind}: "
                f"model predicts {pred} with moderate confidence ({conf:.2f})</p>"
            )
    if conf < med_thr:
        return (
            f"<p style='color:orange'>⚠ Low confidence for {kind} ({conf:.2f}); "
            f"please check image quality or framing.</p>"
        )
    return ""


def predict_html(image_path: str) -> str:
    """Classify one image file and return the result as an HTML fragment.

    The image is converted to single-channel ("L") mode, passed through the
    module-level ``model``, and the softmax of the ``"label_APorPA"`` and
    ``"label_round"`` outputs is reported per class. Labels parsed from the
    filename are compared with the predictions to produce mislabel warnings.

    Args:
        image_path: Path of the image to classify. ``None`` or ``""`` (what the
            UI passes when nothing has been uploaded) yields a short notice
            instead of raising.

    Returns:
        HTML with the predicted projection and rotation, the per-class scores,
        and any warnings.
    """
    if not image_path:
        return (
            "<p style='color:orange'>⚠ No image provided. "
            "Please upload a PNG and run inference again.</p>"
        )

    # convert() returns a fully loaded copy, so the file handle can be released
    # as soon as the conversion is done.
    with Image.open(image_path) as source:
        img = source.convert("L")
    batch = ImageHandler(args_dataloader).set_image(img)

    with torch.no_grad():
        outputs = model(batch)
        logits_proj = outputs.get("label_APorPA")
        logits_rot = outputs.get("label_round")

        # Convert logits to probabilities; [0] selects the single batch item.
        probs_proj = F.softmax(logits_proj, dim=1)[0].cpu().numpy()
        probs_rot = F.softmax(logits_rot, dim=1)[0].cpu().numpy()

    # Pick the most probable class of each head.
    idx_proj = int(probs_proj.argmax())
    idx_rot = int(probs_rot.argmax())
    pred_proj = LABEL_APorPA[idx_proj]
    pred_rot = LABEL_ROUND[idx_rot]
    conf_proj = float(probs_proj[idx_proj])
    conf_rot = float(probs_rot[idx_rot])

    # e.g. "1_AP_Upright.png" -> ("AP", "Upright")
    orig_proj, orig_rot = _labels_from_filename(image_path)

    warn_html = "".join(
        (
            _make_warning("projection", orig_proj, pred_proj, conf_proj),
            _make_warning("rotation", orig_rot, pred_rot, conf_rot),
        )
    )

    # Per-class score summaries.
    scores_proj = ", ".join(
        f"{label}: {p:.2f}" for label, p in zip(LABEL_APorPA, probs_proj)
    )
    scores_rot = ", ".join(
        f"{label}: {p:.2f}" for label, p in zip(LABEL_ROUND, probs_rot)
    )

    return (
        f"<p><strong>Projection :</strong> {pred_proj} "
        f"<small>({scores_proj})</small></p>"
        f"<p><strong>Rotation   :</strong> {pred_rot} "
        f"<small>({scores_rot})</small></p>"
        f"{warn_html}"
    )

# ===========================================
# 7) Gradio UI
# ===========================================
# Static introduction shown at the top of the page (passed to gr.HTML below).
# It describes the expected upload, the predicted classes, and the filename
# convention that predict_html() parses for its mislabel warnings.
# NOTE(review): this text asks for "Left90"/"Right90" in filenames, whereas
# LABEL_ROUND spells those classes "Left-rotation"/"Right-rotation" — confirm
# which spelling users are meant to follow.
html_header = """
<div style="padding:10px;border:1px solid #ddd;border-radius:5px">
  <h2>Chest X‑ray Projection & Rotation Classification</h2>
  <p>Upload a 256×256 grayscale PNG. The model predicts projection (AP/PA/Lateral)
     and rotation (Upright/Inverted/Left/Right) and shows softmax confidences. 
     It warns if filename label differs or if confidence is low. Please name the files using the format: [Number]_projection_rotation.png.
     For the projection part of the filename, please use one of the following three terms: AP/PA/Lateral
     For the rotation part of the filename, please use one of the following four terms: Upright/Inverted/Left90/Right90
     As samples, We have prepared two sets of images: A PA view in the Upright position. An AP view with Left rotation. For each set, we have created two versions, one of which includes a mislabel in either its projection or rotation tag.</p>
</div>
"""

# Page layout: intro banner, an upload widget next to the result pane, a run
# button, and the bundled sample images.
with gr.Blocks(title="CXR Projection & Rotation") as demo:
    gr.HTML(html_header)

    with gr.Row():
        # type="filepath" makes the widget hand predict_html a path string, so
        # the original filename stays available for label parsing.
        input_image = gr.Image(
            label="Upload PNG (256×256)",
            type="filepath",
            image_mode="L",
        )
        output_html = gr.HTML()

    send_btn = gr.Button("Run Inference")
    send_btn.click(
        fn=predict_html,
        inputs=input_image,
        outputs=output_html,
    )

    # Bundled sample images. This single list drives both the clickable
    # examples and the filename listing, so the two cannot drift apart.
    sample_paths = [
        "./sample/1_AP_Upright.png",
        "./sample/1_PA_Inverted.png",
        "./sample/2_AP_Right-rotation.png",
        "./sample/2_Lateral_Left-rotation.png",
    ]
    gr.Examples(
        examples=sample_paths,
        inputs=input_image,
    )

    # List the sample filenames; the two trailing spaces before each newline
    # are Markdown hard line breaks.
    gr.Markdown(
        "**Sample filenames:**  \n"
        + "  \n".join(f"- {os.path.basename(path)}" for path in sample_paths)
    )

if __name__ == "__main__":
    demo.launch(debug=True)