| | |
| | |
| |
|
| | import torch |
| | import gradio as gr |
| | import numpy as np |
| | import torchvision.transforms as T |
| |
|
| | from lib.framework import create_model |
| | from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group |
| | from lib.dataloader import ImageMixin |
| |
|
| | |
| | |
| | |
# Relative paths: both files are looked up from the current working directory.
test_weight = "./weight_epoch-011_best.pt"  # checkpoint passed to model.load_weight()
parameter = "./parameters.json"  # parameter file passed to load_parameter()
| |
|
| | |
| | |
| | |
| | |
| | |
# Class names indexed by the argmax position of the corresponding model head;
# the order must match the order the model was trained with.
LABEL_APorPA = ["AP", "PA", "Lateral"]  # projection head ("label_APorPA")

LABEL_ROUND = ["Upright", "Inverted", "Left rotation", "Right rotation"]  # rotation head ("label_round")
| |
|
| | |
| | |
| | |
| | |
class ImageHandler(ImageMixin):
    """Turn a single PIL image into the input dict passed to ``model(...)``.

    Only tensor conversion is applied: no resizing and no normalization
    (the UI text states inputs are expected to already be 256x256 grayscale).
    """

    def __init__(self, params):
        # Dataloader parameter group; stored but not read by the methods here.
        self.params = params
        self.transform = T.Compose([T.ToTensor()])

    def set_image(self, image):
        """Return ``{'image': tensor}`` with a leading batch dimension of 1."""
        tensor = self.transform(image)
        return {'image': tensor.unsqueeze(dim=0)}
| |
|
| | |
| | |
| | |
def load_parameter(parameter):
    """Build the argument groups used for inference from a saved parameter file.

    Args:
        parameter: path handed to ``_retrieve_parameter``.

    Returns:
        Tuple ``(args_model, args_dataloader)``: the results of
        ``_dispatch_by_group`` for the 'model' and 'dataloader' groups.
    """
    args = ParamSet()
    saved = _retrieve_parameter(parameter)
    for key, value in saved.items():
        setattr(args, key, value)

    # Inference-time overrides applied on top of the saved values.
    # NOTE(review): presumably these switch off training-only features
    # (augmentation, sampler) and pretrained-weight loading, since the
    # checkpoint is loaded separately — confirm against lib.options.
    for attr, value in (
        ('augmentation', 'no'),
        ('sampler', 'no'),
        ('pretrained', False),
        ('mlp', None),
    ):
        setattr(args, attr, value)
    args.net = args.model  # expose the saved model name under `net` as well
    args.device = torch.device('cpu')  # always run on CPU

    return tuple(_dispatch_by_group(args, group) for group in ('model', 'dataloader'))
| |
|
# Resolve the argument groups once at import time (load_parameter sets the
# device argument to CPU), build the network, restore the trained weights,
# and switch to evaluation mode for inference.
args_model, args_dataloader = load_parameter(parameter)

model = create_model(args_model)
print("Load weight: {}".format(test_weight))
model.load_weight(test_weight)
model.eval()
| |
|
| | |
| | |
| | |
def _predict_label(outputs, key, labels):
    """Return the class name with the highest score in ``outputs[key]``.

    Args:
        outputs: dict-like model output keyed by head name.
        key: head to read, e.g. "label_APorPA".
        labels: class names indexed by argmax position.

    Returns:
        The predicted class name, or ``"ERROR: Missing '<key>'"`` when the
        head is absent from ``outputs``.
    """
    if key not in outputs:
        print(f"[ERROR] '{key}' not found in outputs. Actual keys: {list(outputs.keys())}")
        return f"ERROR: Missing '{key}'"
    # argmax over the class dimension; .item() requires a single-element
    # result, i.e. a batch containing exactly one image.
    pred_idx = torch.argmax(outputs[key], dim=1).item()
    return labels[pred_idx]


def classify_APorPA_and_round(image):
    """Predict the projection and the rotation of one chest X-ray image.

    The model is assumed to return:
        outputs["label_APorPA"] -> shape [1, 3] (3 classes: AP, PA, Lateral)
        outputs["label_round"]  -> shape [1, 4] (4 classes: Upright, Inverted,
                                   Left rotation, Right rotation)

    Args:
        image: PIL image from the Gradio input, or None when no image was given.

    Returns:
        Tuple ``(predicted_APorPA, predicted_round)`` of strings. Each element
        is either a class name or an ``"ERROR: ..."`` message. The two heads are
        resolved independently, so a missing head does not hide the other's
        prediction.
    """
    # Gradio passes None when "Inference" is clicked without an image;
    # ToTensor would raise on it.
    if image is None:
        return "ERROR: No input image", "ERROR: No input image"

    image_handler = ImageHandler(args_dataloader)
    image_tensor = image_handler.set_image(image)

    with torch.no_grad():
        outputs = model(image_tensor)

    # Debug output: which heads the model produced and their shapes.
    print("keys in outputs =", outputs.keys())
    for key in ("label_APorPA", "label_round"):
        if key in outputs:
            print(f"{key} shape =", outputs[key].shape)

    predicted_APorPA = _predict_label(outputs, "label_APorPA", LABEL_APorPA)
    predicted_round = _predict_label(outputs, "label_round", LABEL_ROUND)
    return predicted_APorPA, predicted_round
| |
|
| | |
| | |
| | |
# Description panel shown under the page header (user-facing text, in Japanese).
# English translation of the text:
#   "The input image is assumed to already be 256x256 (grayscale); no resizing is
#    performed internally.
#    For a chest X-ray image, jointly estimates the projection (3 classes: AP, PA,
#    Lateral) and the rotation (4 classes: Upright, Inverted, Left rotation,
#    Right rotation)."
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
<h3>Chest X-ray: AP/PA & Rotation Classification</h3>
<p>入力画像は既に256×256(グレースケール)であることを想定し、内部でのリサイズは行いません。</p>
<p>胸部レントゲン画像に対して、撮像方向(3クラス: AP, PA, Lateral) と
回転方向(4クラス: Upright, Inverted, Left rotation, Right rotation)を同時に推定します。</p>
</div>
"""
| |
|
# Gradio UI: header, description, grayscale image input, two label outputs,
# an inference button, and clickable example images.
with gr.Blocks(title="Chest X-ray: AP/PA & Rotation Classification") as demo:
    gr.HTML("<div style='text-align:center'><h2>Chest X-ray AP/PA & Rotation Classification</h2></div>")
    gr.HTML(html_content)

    with gr.Row():
        # type="pil" + image_mode="L": the handler receives a single-channel PIL image.
        input_image = gr.Image(type="pil", image_mode="L")
        output_APorPA = gr.Label(label="Predicted AP/PA/Lateral")
        output_round = gr.Label(label="Predicted Rotation")

    send_btn = gr.Button("Inference")
    send_btn.click(
        fn=classify_APorPA_and_round,
        inputs=input_image,
        outputs=[output_APorPA, output_round],
    )

    with gr.Row():
        # Example images under ./sample/; selecting one fills input_image.
        gr.Examples(
            examples=[
                './sample/sample_AP_inverted.png',
                './sample/sample_PA_right.png',
                './sample/sample_lateral_upright.png',
            ],
            inputs=input_image,
        )

# Start the server only when run as a script. Importing this module (for tests,
# or Gradio's reload mode, which looks up `demo`) must not block in launch().
if __name__ == "__main__":
    demo.launch(debug=True)
| |
|