|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torchvision.transforms as T |
|
|
|
|
|
from lib.framework import create_model |
|
|
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group |
|
|
from lib.dataloader import ImageMixin |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint file restored into the model at start-up (relative to the working directory).
test_weight = "./weight_epoch-011_best.pt"
# JSON file with the hyper-parameters saved at training time (read by load_parameter).
parameter = "./parameters.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Human-readable class names, indexed by the arg-max position of the matching
# model output head ("label_APorPA" and "label_round" respectively).
LABEL_APorPA = ["AP", "PA", "Lateral"]
LABEL_ROUND = ["Upright", "Inverted", "Left rotation", "Right rotation"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageHandler(ImageMixin):
    """Turn a single input image into the batched dict passed to the model.

    The only preprocessing step is torchvision's ``ToTensor``. No resizing or
    normalisation is done here.
    """

    def __init__(self, params):
        # Stored but not read by the methods below; ImageMixin may use it
        # (unverified, since the base class is defined elsewhere).
        self.params = params
        # Kept as a Compose pipeline so further steps can be appended later.
        self.transform = T.Compose([T.ToTensor()])

    def set_image(self, image):
        """Return ``{'image': tensor}``, with a leading batch axis of size 1 added."""
        batched = self.transform(image).unsqueeze(0)
        return {'image': batched}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_parameter(parameter):
    """Load saved hyper-parameters and split them for model / data-loader construction.

    Args:
        parameter: Source handed to ``_retrieve_parameter`` (in this script, the
            path of the parameters JSON file).

    Returns:
        Tuple ``(args_model, args_dataloader)``: the results of
        ``_dispatch_by_group`` for the ``'model'`` and ``'dataloader'`` groups.
    """
    args = ParamSet()

    # Copy every stored hyper-parameter onto the ParamSet as an attribute.
    for name, value in _retrieve_parameter(parameter).items():
        setattr(args, name, value)

    # Inference-time overrides applied on top of the training configuration:
    # augmentation/sampling disabled, pretrained flag off (presumably so no
    # backbone weights are fetched, since a checkpoint is loaded later; verify),
    # `net` aliased to `model`, and everything pinned to the CPU.
    inference_overrides = {
        'augmentation': 'no',
        'sampler': 'no',
        'pretrained': False,
        'mlp': None,
        'net': args.model,
        'device': torch.device('cpu'),
    }
    for name, value in inference_overrides.items():
        setattr(args, name, value)

    return (
        _dispatch_by_group(args, 'model'),
        _dispatch_by_group(args, 'dataloader'),
    )
|
|
|
|
|
# Load the training-time hyper-parameters, already split into the model group
# and the data-loader group (the latter is used by classify_APorPA_and_round).
args_model, args_dataloader = load_parameter(parameter)

# Build the network and restore the trained checkpoint once, at import time,
# before the UI is constructed. create_model / load_weight are project code.
model = create_model(args_model)
print(f"Load weight: {test_weight}")
model.load_weight(test_weight)
# Switch to evaluation mode. This assumes the returned object is (or wraps)
# a torch.nn.Module; TODO confirm against lib.framework.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _top1_label(outputs, key, labels):
    """Return the class name predicted by output head ``key``, or an error string.

    Each head is resolved independently, so a missing or malformed head does not
    discard the prediction of another head.

    Args:
        outputs: Mapping returned by the model, keyed by head name.
        key: Head name to read, e.g. ``"label_APorPA"``.
        labels: Class names indexed by output position.

    Returns:
        ``labels[argmax]`` on success. Otherwise returns an ``"ERROR: ..."``
        message, which is shown in the UI.
    """
    if key not in outputs:
        print(f"[ERROR] '{key}' not found in outputs. Actual keys: {list(outputs.keys())}")
        return f"ERROR: Missing '{key}'"

    # .item() requires a single-element result, i.e. a batch of one image
    # (the image handler adds a batch axis of size 1).
    idx = torch.argmax(outputs[key], dim=1).item()
    if idx >= len(labels):
        # Guards against a model/label-table mismatch instead of an IndexError.
        print(f"[ERROR] '{key}' predicted index {idx} but only {len(labels)} labels are defined")
        return f"ERROR: Unexpected class index {idx} for '{key}'"
    return labels[idx]


def classify_APorPA_and_round(image):
    """Predict the projection and rotation of a chest radiograph.

    The model is assumed to produce:
        outputs["label_APorPA"] -> shape [1, 3] (3 classes: AP, PA, Lateral)
        outputs["label_round"]  -> shape [1, 4] (4 classes: Upright, Inverted,
                                   Left rotation, Right rotation)

    Args:
        image: Image supplied by the Gradio ``Image`` component. This is a PIL
            image, or ``None`` when nothing has been uploaded.

    Returns:
        Tuple ``(projection, rotation)`` of class names. Either element is an
        ``"ERROR: ..."`` string when that prediction could not be made.
    """
    if image is None:
        # Gradio passes None when "Inference" is pressed without an image, and
        # T.ToTensor() would otherwise fail with an opaque TypeError.
        message = "ERROR: No image provided"
        return message, message

    image_handler = ImageHandler(args_dataloader)
    image_tensor = image_handler.set_image(image)

    with torch.no_grad():
        outputs = model(image_tensor)

    # Diagnostic output showing which heads the model produced.
    print("keys in outputs =", outputs.keys())
    for key in ("label_APorPA", "label_round"):
        if key in outputs:
            print(f"{key} shape =", outputs[key].shape)

    predicted_APorPA = _top1_label(outputs, "label_APorPA", LABEL_APorPA)
    predicted_round = _top1_label(outputs, "label_round", LABEL_ROUND)
    return predicted_APorPA, predicted_round
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Description panel rendered under the page title in the Gradio UI.
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
    <h3>Chest X-ray: Projection & Rotation Classification</h3>
    <p>The input image should be a 256×256 (grayscale) PNG file.</p>
    <p>This model predicts both the imaging projection (3 classes: AP, PA, Lateral) and rotation (4 classes: Upright, Inverted, Left rotation, Right rotation) for chest radiographs.</p>
</div>
"""
|
|
|
|
|
# Gradio UI, from top to bottom: a title, the description panel, a row holding
# the input image and both prediction outputs, an inference button, and a row
# of example images. Component order below defines the on-screen layout.
with gr.Blocks(title="Chest X-ray: Projection & Rotation Classification") as demo:
    gr.HTML("<div style='text-align:center'><h2>Chest X-ray Projection & Rotation Classification</h2></div>")
    gr.HTML(html_content)

    # Input image plus the two prediction outputs, side by side.
    with gr.Row():
        # Requests a PIL image in mode "L" (single-channel grayscale).
        input_image = gr.Image(type="pil", image_mode="L")
        output_APorPA = gr.Label(label="Predicted AP/PA/Lateral")
        output_round = gr.Label(label="Predicted Rotation")

    send_btn = gr.Button("Inference")
    # The classifier returns (projection, rotation); the two values map in
    # order onto the two Label outputs.
    send_btn.click(
        fn=classify_APorPA_and_round,
        inputs=input_image,
        outputs=[output_APorPA, output_round]
    )

    with gr.Row():
        # Sample images under ./sample (paths are relative to the working
        # directory). No `fn` is given, so selecting an example presumably
        # only fills `input_image`; inference still needs the button.
        gr.Examples(
            examples=[
                './sample/sample_AP_inverted.png',
                './sample/sample_PA_right.png',
                './sample/sample_lateral_upright.png'
            ],
            inputs=input_image
        )

# Start the web server when this module is executed. This is unguarded, so
# importing the module also launches it. See Gradio's launch() docs for the
# effect of debug=True.
demo.launch(debug=True)
|
|
|