#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import gradio as gr
import numpy as np
import torchvision.transforms as T

from lib.framework import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin

# ===========================================
# 1) Paths (edit here if the files live elsewhere)
# ===========================================
test_weight = './weight_epoch-011_best.pt'
parameter = './parameters.json'

# ===========================================
# 2) Class label definitions
#    - label_APorPA (3 classes): 0=AP, 1=PA, 2=Lateral
#    - label_round  (4 classes): 0=Upright, 1=Inverted,
#                                2=Left rotation, 3=Right rotation
#    The list index is the class index.
# ===========================================
LABEL_APorPA = [
    "AP",       # class 0
    "PA",       # class 1
    "Lateral",  # class 2
]

LABEL_ROUND = [
    "Upright",         # class 0
    "Inverted",        # class 1
    "Left rotation",   # class 2
    "Right rotation",  # class 3
]


# ===========================================
# 3) ImageHandler: preprocessing
#    - Images are expected to already be 256x256.
# ===========================================
class ImageHandler(ImageMixin):
    """Convert one input image into the batched dict passed to the model."""

    def __init__(self, params):
        """Keep ``params`` and build the preprocessing pipeline.

        Args:
            params: Dataloader settings; stored on the instance as-is.
        """
        self.params = params
        # No resizing is done here. If it becomes necessary, insert
        # T.Resize((256, 256)) before T.ToTensor().
        self.transform = T.Compose([
            T.ToTensor(),
        ])

    def set_image(self, image):
        """Apply the transform and add a leading batch dimension.

        Args:
            image: Input image accepted by ``T.ToTensor``.

        Returns:
            dict: ``{'image': tensor}`` where ``tensor`` is the transformed
            image after ``unsqueeze(0)``.
        """
        tensor = self.transform(image)
        return {'image': tensor.unsqueeze(0)}


# ===========================================
# 4) Parameter loading
# ===========================================
def load_parameter(parameter):
    """Load the saved parameters and adapt them for CPU inference.

    Every entry returned by ``_retrieve_parameter(parameter)`` is copied onto
    a fresh ``ParamSet``; training-related settings are then overridden.

    Args:
        parameter: Path to the parameter file.

    Returns:
        tuple: ``(args_model, args_dataloader)`` as produced by
        ``_dispatch_by_group`` for the 'model' and 'dataloader' groups.
    """
    settings = ParamSet()
    for name, value in _retrieve_parameter(parameter).items():
        setattr(settings, name, value)

    # Overrides for inference: training-only options are disabled or ignored,
    # and ``net`` mirrors the ``model`` entry read from the parameter file.
    inference_overrides = {
        'augmentation': 'no',
        'sampler': 'no',
        'pretrained': False,
        'mlp': None,
        'net': settings.model,
        'device': torch.device('cpu'),
    }
    for name, value in inference_overrides.items():
        setattr(settings, name, value)

    return (
        _dispatch_by_group(settings, 'model'),
        _dispatch_by_group(settings, 'dataloader'),
    )


args_model, args_dataloader = load_parameter(parameter)
モデルを作成し学習済み重みをロード # =========================================== model = create_model(args_model) print(f"Load weight: {test_weight}") model.load_weight(test_weight) model.eval() # 推論モード # =========================================== # 6) 推論関数 # =========================================== def classify_APorPA_and_round(image): """ モデルが以下を出力する想定: outputs["label_APorPA"] -> shape=[1, 3] (3クラス: AP, PA, Lateral) outputs["label_round"] -> shape=[1, 4] (4クラス: Upright, Inverted, Left rotation, Right rotation) """ image_handler = ImageHandler(args_dataloader) image_tensor = image_handler.set_image(image) with torch.no_grad(): outputs = model(image_tensor) # デバッグ用の出力チェック print("keys in outputs =", outputs.keys()) if "label_APorPA" in outputs: print("label_APorPA shape =", outputs["label_APorPA"].shape) if "label_round" in outputs: print("label_round shape =", outputs["label_round"].shape) # --- label_APorPA --- if "label_APorPA" not in outputs: print(f"[ERROR] 'label_APorPA' not found in outputs. Actual keys: {list(outputs.keys())}") return "ERROR: Missing 'label_APorPA'", "ERROR: Missing 'label_round'" scores_APorPA = outputs["label_APorPA"] # shape=[1,3] pred_APorPA_idx = torch.argmax(scores_APorPA, dim=1).item() predicted_APorPA = LABEL_APorPA[pred_APorPA_idx] # --- label_round --- if "label_round" not in outputs: print(f"[ERROR] 'label_round' not found in outputs. Actual keys: {list(outputs.keys())}") return predicted_APorPA, "ERROR: Missing 'label_round'" scores_round = outputs["label_round"] # shape=[1,4] pred_round_idx = torch.argmax(scores_round, dim=1).item() predicted_round = LABEL_ROUND[pred_round_idx] return predicted_APorPA, predicted_round # =========================================== # 7) Gradio UI # =========================================== html_content = """
The input image should be a 256×256 (grayscale) PNG file.
This model predict both the imaging projection (3 classes: AP, PA, Lateral) and rotation (4 classes: Upright, Inverted, Left rotation, Right rotation) for chest radiographs.