MedicalAILabo's picture
Update app.py
543a456 verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import gradio as gr
import numpy as np
import torchvision.transforms as T
from lib.framework import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin
# ===========================================
# 1) File paths (change these if your files live elsewhere)
# ===========================================
test_weight = './weight_epoch-011_best.pt'  # trained checkpoint to load
parameter = './parameters.json'  # parameter file saved at training time

# ===========================================
# 2) Class label definitions
#    - label_APorPA (3 classes): 0=AP, 1=PA, 2=Lateral
#    - label_round  (4 classes): 0=Upright, 1=Inverted,
#                                2=Left rotation, 3=Right rotation
# ===========================================
# Position i in each list is the display name for class index i of the
# corresponding model output head.
LABEL_APorPA = ["AP", "PA", "Lateral"]
LABEL_ROUND = ["Upright", "Inverted", "Left rotation", "Right rotation"]
# ===========================================
# 3) ImageHandler: inference-time preprocessing
#    - assumes the input image is already 256x256
# ===========================================
class ImageHandler(ImageMixin):
    """Turn a single PIL image into the input dict the model is called with.

    The image is converted with ``torchvision.transforms.ToTensor`` and given a
    leading batch dimension of size 1, then wrapped as ``{'image': tensor}``.
    """

    def __init__(self, params):
        # Kept on the instance; not read by the methods defined here
        # (NOTE(review): possibly used by ImageMixin — confirm).
        self.params = params
        pipeline = [
            # T.Resize((256, 256)),  # uncomment if inputs are not already 256x256
            T.ToTensor(),
        ]
        self.transform = T.Compose(pipeline)

    def set_image(self, image):
        """Return ``{'image': batch}`` where ``batch`` is the transformed image
        with an extra leading dimension of size 1."""
        tensor = self.transform(image)
        batch = tensor.unsqueeze(0)
        return {'image': batch}
# ===========================================
# 4) Load training-time parameters
# ===========================================
def load_parameter(parameter):
    """Rebuild the argument groups needed for inference from a parameter file.

    Every entry returned by ``_retrieve_parameter`` is copied onto a fresh
    ``ParamSet``. Training-only options are then overridden so the model is
    built for CPU inference without augmentation, sampling, pretrained
    weights or an MLP head setting.

    Args:
        parameter: Path of the saved parameter file (e.g. ``parameters.json``).

    Returns:
        ``(args_model, args_dataloader)``: the results of ``_dispatch_by_group``
        for the ``'model'`` and ``'dataloader'`` groups respectively.
    """
    args = ParamSet()
    saved = _retrieve_parameter(parameter)
    for name, value in saved.items():
        setattr(args, name, value)

    # Inference overrides (training-related options are disabled or ignored),
    # applied in this order. 'net' mirrors the saved 'model' value.
    inference_overrides = {
        'augmentation': 'no',
        'sampler': 'no',
        'pretrained': False,
        'mlp': None,
        'net': args.model,
        'device': torch.device('cpu'),
    }
    for name, value in inference_overrides.items():
        setattr(args, name, value)

    model_group = _dispatch_by_group(args, 'model')
    dataloader_group = _dispatch_by_group(args, 'dataloader')
    return model_group, dataloader_group


args_model, args_dataloader = load_parameter(parameter)
# ===========================================
# 5) Build the model and load the trained weights
# ===========================================
model = create_model(args_model)
print("Load weight:", test_weight)
model.load_weight(test_weight)
# Put the model into evaluation (inference) mode before serving predictions.
model.eval()
# ===========================================
# 6) Inference function
# ===========================================
def _predict_label(outputs, key, labels):
    """Return the name of the highest-scoring class for one output head.

    Args:
        outputs: Mapping returned by ``model(...)``, keyed by head name.
        key: Head name to read, e.g. ``"label_APorPA"``.
        labels: Class names, indexed by class id.

    Returns:
        The predicted class name, or an ``"ERROR: ..."`` string (shown in the
        UI) when the head is missing or predicts an index with no label.
    """
    if key not in outputs:
        print(f"[ERROR] '{key}' not found in outputs. Actual keys: {list(outputs.keys())}")
        return f"ERROR: Missing '{key}'"
    # Scores are expected to be [1, num_classes] for a single image; argmax
    # over the class dimension gives the predicted class id.
    pred_idx = torch.argmax(outputs[key], dim=1).item()
    if pred_idx >= len(labels):
        # The head has more classes than we have names for; report instead of
        # raising IndexError.
        print(f"[ERROR] '{key}' predicted class {pred_idx}, but only {len(labels)} labels are defined.")
        return f"ERROR: Unknown class {pred_idx} for '{key}'"
    return labels[pred_idx]


def classify_APorPA_and_round(image):
    """Predict the projection (AP/PA/Lateral) and rotation of a chest X-ray.

    The model is expected to return:
        outputs["label_APorPA"] -> shape [1, 3] (AP, PA, Lateral)
        outputs["label_round"]  -> shape [1, 4] (Upright, Inverted,
                                                 Left rotation, Right rotation)

    Args:
        image: The uploaded image from the Gradio ``Image`` component, or
            ``None`` when no image has been provided.

    Returns:
        ``(projection, rotation)`` display strings. Each head is evaluated
        independently, so a problem with one head does not hide the other's
        prediction; problems are reported as ``"ERROR: ..."`` strings.
    """
    if image is None:
        # Gradio passes None when "Inference" is clicked without an image.
        no_image = "ERROR: No image provided"
        return no_image, no_image

    image_tensor = ImageHandler(args_dataloader).set_image(image)
    with torch.no_grad():
        outputs = model(image_tensor)

    # Debug output: which heads the model produced and their shapes.
    print("keys in outputs =", outputs.keys())
    for key in ("label_APorPA", "label_round"):
        if key in outputs:
            print(f"{key} shape =", outputs[key].shape)

    predicted_APorPA = _predict_label(outputs, "label_APorPA", LABEL_APorPA)
    predicted_round = _predict_label(outputs, "label_round", LABEL_ROUND)
    return predicted_APorPA, predicted_round
# ===========================================
# 7) Gradio UI
# ===========================================
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
<h3>Chest X-ray: Projection & Rotation Classification</h3>
<p>The input image should be a 256×256 (grayscale) PNG file.</p>
<p>This model predicts both the imaging projection (3 classes: AP, PA, Lateral) and rotation (4 classes: Upright, Inverted, Left rotation, Right rotation) for chest radiographs.</p>
</div>
"""

with gr.Blocks(title="Chest X-ray: Projection & Rotation Classification") as demo:
    gr.HTML("<div style='text-align:center'><h2>Chest X-ray Projection & Rotation Classification</h2></div>")
    gr.HTML(html_content)
    with gr.Row():
        # Images are delivered to the handler as grayscale ("L") PIL images.
        input_image = gr.Image(type="pil", image_mode="L")
        output_APorPA = gr.Label(label="Predicted AP/PA/Lateral")
        output_round = gr.Label(label="Predicted Rotation")
    send_btn = gr.Button("Inference")
    send_btn.click(
        fn=classify_APorPA_and_round,
        inputs=input_image,
        outputs=[output_APorPA, output_round],
    )
    with gr.Row():
        # Replace these with the actual paths of your sample files if they differ.
        gr.Examples(
            examples=[
                './sample/sample_AP_inverted.png',
                './sample/sample_PA_right.png',
                './sample/sample_lateral_upright.png',
            ],
            inputs=input_image,
        )

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)