| |
| |
|
|
| import torch |
| import gradio as gr |
| import numpy as np |
| import torchvision.transforms as T |
|
|
| from lib.framework import create_model |
| from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group |
| from lib.dataloader import ImageMixin |
|
|
| |
| |
| |
# Checkpoint and hyper-parameter files used to rebuild the model; both are
# relative paths, so they resolve against the current working directory.
test_weight = "./weight_epoch-011_best.pt"
parameter = "./parameters.json"
|
|
| |
| |
| |
| |
# Human-readable class names, indexed by the argmax of the matching output
# head. NOTE(review): the order presumably mirrors the label order used at
# training time — confirm before reordering.
LABEL_APorPA = ["AP", "PA"]
LABEL_ROUND = [f"{degrees}° Rotation" for degrees in (0, 90, 180, 270)]
|
|
| |
| |
| |
| |
class ImageHandler(ImageMixin):
    """Turn one input image into the batched dict the model is called with.

    Args (constructor):
        params: Dataloader argument group. It is stored on the instance but not
            read by the methods shown here; ImageMixin may use it (unverified).
    """

    def __init__(self, params):
        self.params = params
        # Tensor conversion only: no resizing or normalization is applied.
        self.transform = T.Compose([T.ToTensor()])

    def set_image(self, image):
        """Convert ``image`` and wrap it as ``{'image': tensor}`` with a batch axis.

        The tensor gets a new leading dimension of size 1 via ``unsqueeze(0)``.
        """
        tensor = self.transform(image)
        batched = tensor.unsqueeze(0)
        return {'image': batched}
|
|
| |
| |
| |
def load_parameter(parameter):
    """Rebuild the model and dataloader argument groups from a saved parameter file.

    Args:
        parameter: Path handed to ``_retrieve_parameter`` (this script passes
            ``'./parameters.json'``). Its result must expose ``.items()``.

    Returns:
        ``(args_model, args_dataloader)``, as produced by ``_dispatch_by_group``
        for the ``'model'`` and ``'dataloader'`` groups.
    """
    settings = ParamSet()
    for name, value in _retrieve_parameter(parameter).items():
        setattr(settings, name, value)

    # Inference-time overrides, applied after the stored values so they win.
    inference_overrides = {
        'augmentation': 'no',
        'sampler': 'no',
        'pretrained': False,
        'mlp': None,
    }
    for name, value in inference_overrides.items():
        setattr(settings, name, value)
    settings.net = settings.model
    settings.device = torch.device('cpu')

    return (
        _dispatch_by_group(settings, 'model'),
        _dispatch_by_group(settings, 'dataloader'),
    )
|
|
# Resolve the argument groups once at import time; the request handler
# reuses args_dataloader and the loaded model for every inference.
args_model, args_dataloader = load_parameter(parameter)

# Build the network, load the checkpoint, then switch to evaluation mode.
# NOTE(review): load_weight/eval come from the project's model wrapper
# returned by create_model — confirm their exact semantics there.
model = create_model(args_model)
print("Load weight:", test_weight)
model.load_weight(test_weight)
model.eval()
|
|
| |
| |
| |
def _top1_label(outputs, key, labels):
    """Return the label with the highest score for one output head, or an error string."""
    if key not in outputs:
        print(f"[ERROR] '{key}' not found in outputs. Actual keys: {list(outputs.keys())}")
        return f"ERROR: Missing '{key}'"
    # Scores are taken as [batch, classes]; the batch holds the single image.
    pred_idx = torch.argmax(outputs[key], dim=1).item()
    return labels[pred_idx]


def classify_APorPA_and_round(image):
    """Predict the projection (AP/PA) and the rotation of a chest X-ray.

    The model is expected to return:
        outputs["label_APorPA"] -> shape [1, 2]  (2 classes: AP / PA)
        outputs["label_round"]  -> shape [1, 4]  (4 classes: 0°, 90°, 180°, 270°)

    Args:
        image: The PIL image from the Gradio ``gr.Image`` input, or ``None`` when
            the button is pressed before an image has been supplied.

    Returns:
        ``(predicted_APorPA, predicted_round)`` label strings. Each output head
        is checked independently. A missing head yields
        ``"ERROR: Missing '<key>'"`` only for that head, so the other
        prediction is still reported.
    """
    # Without an upload Gradio passes None, which ToTensor cannot convert.
    if image is None:
        message = "ERROR: No input image"
        return message, message

    image_handler = ImageHandler(args_dataloader)
    image_tensor = image_handler.set_image(image)

    with torch.no_grad():
        outputs = model(image_tensor)

    predicted_APorPA = _top1_label(outputs, "label_APorPA", LABEL_APorPA)
    predicted_round = _top1_label(outputs, "label_round", LABEL_ROUND)
    return predicted_APorPA, predicted_round
|
|
| |
| |
| |
# Description panel shown through gr.HTML(html_content) in the Gradio UI.
# The (Japanese) user-facing text says that the input image is assumed to
# already be 256x256 grayscale and is not resized internally, and that the app
# estimates the projection (AP or PA) and the rotation (0°, 90°, 180°, 270°)
# of a chest X-ray at the same time. It is runtime text, so it is left as is.
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
<h3>Chest X-ray: AP/PA & Rotation Classification</h3>
<p>入力画像は既に256×256(グレースケール)であることを想定し、内部でのリサイズは行いません。</p>
<p>胸部レントゲン画像に対して、撮像方向(AP or PA)と回転方向(0°, 90°, 180°, 270°)を同時に推定します。</p>
</div>
"""
|
|
# Gradio front end: header, description, a grayscale image input with two
# label outputs, an inference button, and clickable sample images.
with gr.Blocks(title="Chest X-ray: AP/PA & Rotation Classification") as demo:
    gr.HTML("<div style='text-align:center'><h2>Chest X-ray AP/PA & Rotation Classification</h2></div>")
    gr.HTML(html_content)

    with gr.Row():
        # image_mode="L" hands the handler a single-channel PIL image.
        input_image = gr.Image(type="pil", image_mode="L")
        output_APorPA = gr.Label(label="Predicted AP or PA")
        output_round = gr.Label(label="Predicted Rotation")

    send_btn = gr.Button("Inference")
    # The two outputs receive the two strings returned by the handler, in order.
    send_btn.click(
        classify_APorPA_and_round,
        inputs=[input_image],
        outputs=[output_APorPA, output_round],
    )

    sample_images = [
        './sample/sample_AP_inverted.png',
        './sample/sample_PA_right.png',
        './sample/sample_lateral_upright.png',
    ]
    with gr.Row():
        # Clicking a sample only fills the input; inference still needs the button.
        gr.Examples(examples=sample_images, inputs=input_image)


demo.launch(debug=True)
|
|