#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import gradio as gr
import numpy as np
import torchvision.transforms as T
from lib.framework import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin
# ===========================================
# 1) Paths (adjust as needed)
# ===========================================
test_weight = './weight_epoch-011_best.pt'
parameter = './parameters.json'
# ===========================================
# 2) Class-label definitions
#    - label_APorPA (3 classes): 0=AP, 1=PA, 2=Lateral
#    - label_round (4 classes): 0=Upright, 1=Inverted, 2=Left rotation, 3=Right rotation
# ===========================================
LABEL_APorPA = [
    "AP",       # class 0
    "PA",       # class 1
    "Lateral",  # class 2
]
LABEL_ROUND = [
    "Upright",          # class 0
    "Inverted",         # class 1
    "Left rotation",    # class 2
    "Right rotation"    # class 3
]
# ===========================================
# 3) ImageHandler class for preprocessing
#    - assumes the input image is already 256x256
# ===========================================
class ImageHandler(ImageMixin):
    """Preprocessing helper that converts a PIL image into the model's input dict.

    The input is assumed to be already sized 256x256 (see the UI note), so no
    resize step is applied by default.
    """

    def __init__(self, params):
        self.params = params
        # Prepend T.Resize((256, 256)) here if inputs are not pre-sized.
        self.transform = T.Compose([T.ToTensor()])

    def set_image(self, image):
        """Apply the transform and wrap the tensor (with batch dim) in a dict."""
        tensor = self.transform(image)
        return {'image': tensor.unsqueeze(0)}
# ===========================================
# 4) Load parameters
# ===========================================
def load_parameter(parameter):
    """Load a JSON parameter file and split it into model/dataloader arg groups.

    Training-only options are overridden so the resulting configuration is
    suitable for CPU inference.
    """
    args = ParamSet()
    for name, value in _retrieve_parameter(parameter).items():
        setattr(args, name, value)
    # Rewrite for inference (training-related settings disabled or ignored).
    args.augmentation = 'no'
    args.sampler = 'no'
    args.pretrained = False
    args.mlp = None
    args.net = args.model
    args.device = torch.device('cpu')
    return _dispatch_by_group(args, 'model'), _dispatch_by_group(args, 'dataloader')
# Split the loaded parameters into model- and dataloader-specific groups.
args_model, args_dataloader = load_parameter(parameter)
# ===========================================
# 5) Create the model and load the trained weights
# ===========================================
# Build the model from the parsed parameters and restore the checkpoint.
model = create_model(args_model)
print(f"Load weight: {test_weight}")
model.load_weight(test_weight)
model.eval()  # inference mode
# ===========================================
# 6) Inference function
# ===========================================
def classify_APorPA_and_round(image):
    """Run inference and return the predicted projection and rotation labels.

    The model is expected to output:
        outputs["label_APorPA"] -> shape=[1, 3] (AP, PA, Lateral)
        outputs["label_round"]  -> shape=[1, 4] (Upright, Inverted, Left rotation, Right rotation)
    """
    handler = ImageHandler(args_dataloader)
    inputs = handler.set_image(image)
    with torch.no_grad():
        outputs = model(inputs)

    # Debug dump of the raw model outputs.
    print("keys in outputs =", outputs.keys())
    for key in ("label_APorPA", "label_round"):
        if key in outputs:
            print(f"{key} shape =", outputs[key].shape)

    # --- label_APorPA ---
    if "label_APorPA" not in outputs:
        print(f"[ERROR] 'label_APorPA' not found in outputs. Actual keys: {list(outputs.keys())}")
        return "ERROR: Missing 'label_APorPA'", "ERROR: Missing 'label_round'"
    predicted_APorPA = LABEL_APorPA[outputs["label_APorPA"].argmax(dim=1).item()]

    # --- label_round ---
    if "label_round" not in outputs:
        print(f"[ERROR] 'label_round' not found in outputs. Actual keys: {list(outputs.keys())}")
        return predicted_APorPA, "ERROR: Missing 'label_round'"
    predicted_round = LABEL_ROUND[outputs["label_round"].argmax(dim=1).item()]

    return predicted_APorPA, predicted_round
# ===========================================
# 7) Gradio UI
# ===========================================
# ===========================================
# 7) Gradio UI
# ===========================================
# Description panel shown above the demo.
# Fix: "This model predict" -> "This model predicts" (user-facing grammar error).
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
    <h3>Chest X-ray: Projection & Rotation Classification</h3>
    <p>The input image should be a 256×256 (grayscale) PNG file.</p>
    <p>This model predicts both the imaging projection (3 classes: AP, PA, Lateral) and rotation (4 classes: Upright, Inverted, Left rotation, Right rotation) for chest radiographs.</p>
</div>
"""

with gr.Blocks(title="Chest X-ray: Projection & Rotation Classification") as demo:
    gr.HTML("<div style='text-align:center'><h2>Chest X-ray Projection & Rotation Classification</h2></div>")
    gr.HTML(html_content)
    with gr.Row():
        # Grayscale PIL input on the left, the two predicted labels on the right.
        input_image = gr.Image(type="pil", image_mode="L")
        output_APorPA = gr.Label(label="Predicted AP/PA/Lateral")
        output_round = gr.Label(label="Predicted Rotation")
    send_btn = gr.Button("Inference")
    send_btn.click(
        fn=classify_APorPA_and_round,
        inputs=input_image,
        outputs=[output_APorPA, output_round]
    )
    with gr.Row():
        # Replace the sample files with real paths before deployment.
        gr.Examples(
            examples=[
                './sample/sample_AP_inverted.png',
                './sample/sample_PA_right.png',
                './sample/sample_lateral_upright.png'
            ],
            inputs=input_image
        )

demo.launch(debug=True)