MedicalAILabo's picture
Update app.py
6e5a64b verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import gradio as gr
import numpy as np
import torchvision.transforms as T
from lib.framework import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin
# ===========================================
# 1) パスなど(修正があれば適宜変更)
# ===========================================
test_weight = './weight_epoch-003_best.pt'
parameter = './parameters.json'
# ===========================================
# 2) 7クラスに対応するラベル名 (0~6)
# ===========================================
LABELS = [
"Head", # class 0
"Neck", # class 1
"Chest", # class 2
"Incomplete Chest", # class 3
"Abdomen", # class 4
"Pelvis", # class 5
"Extremities" # class 6
]
# ===========================================
# 3) 前処理用の ImageHandlerクラス
# - 画像が既に256×256前提。Resizeはコメントアウトで残す
# ===========================================
class ImageHandler(ImageMixin):
def __init__(self, params):
self.params = params
# ここでリサイズは省略(推論が重いので)
# 入力画像は既に256×256であることを想定
self.transform = T.Compose([
# T.Resize((256, 256)), # コメントアウト: 画像を256×256にリサイズ
T.ToTensor(), # Tensor化 (0~1, shape: C,H,W)
])
def set_image(self, image):
# PIL画像 -> transform -> バッチ次元を付ける
image = self.transform(image)
image = {'image': image.unsqueeze(0)}
return image
# ===========================================
# 4) パラメータのロード (推論時にほぼ使用しない)
# ===========================================
def load_parameter(parameter):
_args = ParamSet()
params = _retrieve_parameter(parameter)
for _param, _arg in params.items():
setattr(_args, _param, _arg)
# 推論用に書き換え (学習関連は無効化または無視)
_args.augmentation = 'no'
_args.sampler = 'no'
_args.pretrained = False
_args.mlp = None
_args.net = _args.model
_args.device = torch.device('cpu')
args_model = _dispatch_by_group(_args, 'model')
args_dataloader = _dispatch_by_group(_args, 'dataloader')
return args_model, args_dataloader
args_model, args_dataloader = load_parameter(parameter)
# ===========================================
# 5) モデルを作成し学習済み重みをロード
# ===========================================
model = create_model(args_model)
print(f"Load weight: {test_weight}")
model.load_weight(test_weight)
model.eval() # 推論モード
# ===========================================
# 6) 推論関数
# ===========================================
def classify_bodypart(image):
"""
モデル出力:
outputs["label_bodypart"]
のみが存在し、そこに shape=[1,7] のスコア (logits) が入っている想定。
そのうち最大スコアのクラスを返す。
"""
image_handler = ImageHandler(args_dataloader)
image_tensor = image_handler.set_image(image)
with torch.no_grad():
outputs = model(image_tensor)
# "label_bodypart" キーの有無を確認
if "label_bodypart" not in outputs:
print(f"[ERROR] 'label_bodypart' not found in outputs. Actual keys: {list(outputs.keys())}")
return "ERROR: Missing 'label_bodypart' in output."
scores = outputs["label_bodypart"] # 例: shape=[1,7]
print("[DEBUG] shape of label_bodypart:", scores.shape)
print("[DEBUG] scores:", scores)
# argmaxを取る (shape=[1,7] 前提)
pred_idx = torch.argmax(scores, dim=1).item()
predicted_label = LABELS[pred_idx]
return predicted_label
# ===========================================
# 7) Gradio UI
# ===========================================
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
<h3>X-ray Bodypart Classification (7 classes)</h3>
<p>The input image should be a 256×256 (grayscale) PNG file.</p>
<p>This model predict one of the following 7 classes.:</p>
<ul>
<li>Head</li>
<li>Neck</li>
<li>Chest</li>
<li>Incomplete Chest</li>
<li>Abdomen</li>
<li>Pelvis</li>
<li>Extremities</li>
</ul>
</div>
"""
with gr.Blocks(title="X-ray Bodypart Classification") as demo:
gr.HTML("<div style='text-align:center'><h2>X-ray Bodypart Classification</h2></div>")
gr.HTML(html_content)
with gr.Row():
# グレースケール画像 (L) をそのまま入力 (256×256想定)
input_image = gr.Image(type="pil", image_mode="L")
output_label = gr.Label(label="Predicted Bodypart")
send_btn = gr.Button("Inference")
send_btn.click(fn=classify_bodypart, inputs=input_image, outputs=output_label)
with gr.Row():
# サンプルファイルのパスはご自身の配置にあわせて修正
gr.Examples(
examples=[
'./sample/sample_chest.png',
'./sample/sample_incomplete_chest.png',
'./sample/sample_abdomen.png'
],
inputs=input_image
)
demo.launch(debug=True)