File size: 5,953 Bytes
4bb0fa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80c8afb
4bb0fa4
 
 
 
ad18100
 
4bb0fa4
 
ad18100
 
 
4bb0fa4
 
 
ad18100
 
 
 
4bb0fa4
 
 
 
ad18100
4bb0fa4
 
 
 
 
ad18100
5113ba3
4bb0fa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad18100
 
4bb0fa4
 
 
 
 
 
 
5113ba3
 
 
 
 
 
 
4bb0fa4
 
 
 
 
ad18100
4bb0fa4
 
 
 
 
 
 
 
ad18100
4bb0fa4
 
 
 
 
 
 
 
 
 
968c6d5
a660e42
 
4bb0fa4
 
 
968c6d5
543a456
4bb0fa4
 
 
 
ad18100
4bb0fa4
 
 
 
 
 
 
 
 
 
 
 
 
d47fc30
 
 
4bb0fa4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import gradio as gr
import numpy as np
import torchvision.transforms as T

from lib.framework import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin

# ===========================================
# 1) Paths (adjust these if your files live elsewhere)
# ===========================================
test_weight = "./weight_epoch-011_best.pt"  # trained checkpoint to load
parameter = "./parameters.json"             # saved training parameters

# ===========================================
# 2) Class label definitions
#    - label_APorPA (3 classes): 0=AP, 1=PA, 2=Lateral
#    - label_round  (4 classes): 0=Upright, 1=Inverted,
#                                2=Left rotation, 3=Right rotation
# ===========================================
# Position i in each list is the display name of class index i.
LABEL_APorPA = ["AP", "PA", "Lateral"]

LABEL_ROUND = ["Upright", "Inverted", "Left rotation", "Right rotation"]

# ===========================================
# 3) Preprocessing: ImageHandler
#    - Assumes the input image is already 256x256
# ===========================================
class ImageHandler(ImageMixin):
    """Turn a PIL image into the input dict passed to the model."""

    def __init__(self, params):
        # Kept on the instance; not read by the methods defined here.
        self.params = params
        # No resizing is done: inputs are expected to be 256x256 already.
        # Prepend T.Resize((256, 256)) to `steps` if that does not hold.
        steps = [T.ToTensor()]
        self.transform = T.Compose(steps)

    def set_image(self, image):
        """Return ``{'image': tensor}`` with a leading batch axis of size 1.

        ``image`` goes through ``self.transform`` (``ToTensor``), and the
        result is unsqueezed at dim 0 so the model sees a batch of one.
        """
        batched = self.transform(image).unsqueeze(0)
        return {'image': batched}

# ===========================================
# 4) Parameter loading
# ===========================================
def load_parameter(parameter):
    """Build the model and dataloader argument groups from a parameter file.

    Every key/value pair returned by ``_retrieve_parameter(parameter)`` is
    copied onto a fresh ``ParamSet``. A few fields are then overridden for
    CPU inference. The result is split into its 'model' and 'dataloader'
    groups with ``_dispatch_by_group``.

    Args:
        parameter: Path/identifier handed to ``_retrieve_parameter``.

    Returns:
        tuple: ``(args_model, args_dataloader)``.
    """
    args = ParamSet()
    for name, value in _retrieve_parameter(parameter).items():
        setattr(args, name, value)

    # Inference-time overrides: training-related options are disabled.
    inference_overrides = {
        'augmentation': 'no',
        'sampler': 'no',
        'pretrained': False,
        'mlp': None,
    }
    for name, value in inference_overrides.items():
        setattr(args, name, value)
    # Copy 'model' into 'net' (presumably read by create_model — verify).
    args.net = args.model
    args.device = torch.device('cpu')

    return _dispatch_by_group(args, 'model'), _dispatch_by_group(args, 'dataloader')

# Parse the parameter file once, at import time, into the 'model' and
# 'dataloader' argument groups (load_parameter applies the CPU/inference
# overrides).
args_model, args_dataloader = load_parameter(parameter)

# ===========================================
# 5) Build the model and load the trained weights
#    - Executed at import time, before the UI is built.
# ===========================================
model = create_model(args_model)
print(f"Load weight: {test_weight}")
model.load_weight(test_weight)
model.eval()  # switch to inference mode

# ===========================================
# 6) Inference
# ===========================================
def _predict_label(outputs, key, labels):
    """Return the display name of the arg-max class of output head ``key``.

    Instead of raising, returns an "ERROR: ..." string when the head is
    absent from ``outputs`` or predicts an index with no entry in ``labels``.
    """
    if key not in outputs:
        print(f"[ERROR] '{key}' not found in outputs. Actual keys: {list(outputs.keys())}")
        return f"ERROR: Missing '{key}'"

    scores = outputs[key]  # expected shape [1, len(labels)]
    idx = torch.argmax(scores, dim=1).item()
    if idx >= len(labels):
        # Model has more classes than we have names for; don't IndexError.
        print(f"[ERROR] '{key}' predicted class {idx}, but only {len(labels)} labels are defined.")
        return f"ERROR: Unknown class {idx} for '{key}'"
    return labels[idx]


def classify_APorPA_and_round(image):
    """Predict projection (AP/PA/Lateral) and rotation for one image.

    The model is expected to output:
      outputs["label_APorPA"] -> shape=[1, 3] (3 classes: AP, PA, Lateral)
      outputs["label_round"]  -> shape=[1, 4] (4 classes: Upright, Inverted,
                                               Left rotation, Right rotation)

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was
            uploaded.

    Returns:
        tuple[str, str]: ``(projection_label, rotation_label)``. Each entry
        is an "ERROR: ..." string if that head could not be decoded. The two
        heads are decoded independently.
    """
    if image is None:
        # gr.Image yields None when Inference is clicked with no upload.
        message = "ERROR: No image provided"
        return message, message

    image_tensor = ImageHandler(args_dataloader).set_image(image)

    with torch.no_grad():
        outputs = model(image_tensor)

    # Debug output: which heads the model returned and their shapes.
    print("keys in outputs =", outputs.keys())
    for key in ("label_APorPA", "label_round"):
        if key in outputs:
            print(f"{key} shape =", outputs[key].shape)

    # Decode each head separately so a missing head never masks (or
    # misreports) the other one's prediction.
    predicted_APorPA = _predict_label(outputs, "label_APorPA", LABEL_APorPA)
    predicted_round = _predict_label(outputs, "label_round", LABEL_ROUND)
    return predicted_APorPA, predicted_round

# ===========================================
# 7) Gradio UI
# ===========================================
# Description panel shown under the page title (passed to gr.HTML).
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
    <h3>Chest X-ray: Projection & Rotation Classification</h3>
    <p>The input image should be a 256×256 (grayscale) PNG file.</p>
    <p>This model predicts both the imaging projection (3 classes: AP, PA, Lateral) and rotation (4 classes: Upright, Inverted, Left rotation, Right rotation) for chest radiographs.</p>
</div>
"""

# Single-page UI: grayscale image in, two predicted labels out.
with gr.Blocks(title="Chest X-ray: Projection & Rotation Classification") as demo:
    gr.HTML("<div style='text-align:center'><h2>Chest X-ray Projection & Rotation Classification</h2></div>")
    gr.HTML(html_content)

    with gr.Row():
        # image_mode="L" delivers a single-channel (grayscale) PIL image.
        input_image = gr.Image(type="pil", image_mode="L")
        output_APorPA = gr.Label(label="Predicted AP/PA/Lateral")
        output_round = gr.Label(label="Predicted Rotation")

    send_btn = gr.Button("Inference")
    send_btn.click(
        fn=classify_APorPA_and_round,
        inputs=input_image,
        outputs=[output_APorPA, output_round],
    )

    with gr.Row():
        # Sample files: replace these with the actual paths.
        gr.Examples(
            examples=[
                './sample/sample_AP_inverted.png',
                './sample/sample_PA_right.png',
                './sample/sample_lateral_upright.png',
            ],
            inputs=input_image,
        )

# Start the (blocking) server only when run as a script, so importing this
# module (e.g. for tests or Gradio reload mode) does not launch it.
if __name__ == "__main__":
    demo.launch(debug=True)