Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| import torch | |
| import gradio as gr | |
| import numpy as np | |
| import torchvision.transforms as T | |
| from lib.framework import create_model | |
| from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group | |
| from lib.dataloader import ImageMixin | |
| # =========================================== | |
| # 1) パスなど(修正があれば適宜変更) | |
| # =========================================== | |
| test_weight = './weight_epoch-003_best.pt' | |
| parameter = './parameters.json' | |
| # =========================================== | |
| # 2) 7クラスに対応するラベル名 (0~6) | |
| # =========================================== | |
| LABELS = [ | |
| "Head", # class 0 | |
| "Neck", # class 1 | |
| "Chest", # class 2 | |
| "Incomplete Chest", # class 3 | |
| "Abdomen", # class 4 | |
| "Pelvis", # class 5 | |
| "Extremities" # class 6 | |
| ] | |
| # =========================================== | |
| # 3) 前処理用の ImageHandlerクラス | |
| # - 画像が既に256×256前提。Resizeはコメントアウトで残す | |
| # =========================================== | |
| class ImageHandler(ImageMixin): | |
| def __init__(self, params): | |
| self.params = params | |
| # ここでリサイズは省略(推論が重いので) | |
| # 入力画像は既に256×256であることを想定 | |
| self.transform = T.Compose([ | |
| # T.Resize((256, 256)), # コメントアウト: 画像を256×256にリサイズ | |
| T.ToTensor(), # Tensor化 (0~1, shape: C,H,W) | |
| ]) | |
| def set_image(self, image): | |
| # PIL画像 -> transform -> バッチ次元を付ける | |
| image = self.transform(image) | |
| image = {'image': image.unsqueeze(0)} | |
| return image | |
| # =========================================== | |
| # 4) パラメータのロード (推論時にほぼ使用しない) | |
| # =========================================== | |
| def load_parameter(parameter): | |
| _args = ParamSet() | |
| params = _retrieve_parameter(parameter) | |
| for _param, _arg in params.items(): | |
| setattr(_args, _param, _arg) | |
| # 推論用に書き換え (学習関連は無効化または無視) | |
| _args.augmentation = 'no' | |
| _args.sampler = 'no' | |
| _args.pretrained = False | |
| _args.mlp = None | |
| _args.net = _args.model | |
| _args.device = torch.device('cpu') | |
| args_model = _dispatch_by_group(_args, 'model') | |
| args_dataloader = _dispatch_by_group(_args, 'dataloader') | |
| return args_model, args_dataloader | |
| args_model, args_dataloader = load_parameter(parameter) | |
| # =========================================== | |
| # 5) モデルを作成し学習済み重みをロード | |
| # =========================================== | |
| model = create_model(args_model) | |
| print(f"Load weight: {test_weight}") | |
| model.load_weight(test_weight) | |
| model.eval() # 推論モード | |
| # =========================================== | |
| # 6) 推論関数 | |
| # =========================================== | |
| def classify_bodypart(image): | |
| """ | |
| モデル出力: | |
| outputs["label_bodypart"] | |
| のみが存在し、そこに shape=[1,7] のスコア (logits) が入っている想定。 | |
| そのうち最大スコアのクラスを返す。 | |
| """ | |
| image_handler = ImageHandler(args_dataloader) | |
| image_tensor = image_handler.set_image(image) | |
| with torch.no_grad(): | |
| outputs = model(image_tensor) | |
| # "label_bodypart" キーの有無を確認 | |
| if "label_bodypart" not in outputs: | |
| print(f"[ERROR] 'label_bodypart' not found in outputs. Actual keys: {list(outputs.keys())}") | |
| return "ERROR: Missing 'label_bodypart' in output." | |
| scores = outputs["label_bodypart"] # 例: shape=[1,7] | |
| print("[DEBUG] shape of label_bodypart:", scores.shape) | |
| print("[DEBUG] scores:", scores) | |
| # argmaxを取る (shape=[1,7] 前提) | |
| pred_idx = torch.argmax(scores, dim=1).item() | |
| predicted_label = LABELS[pred_idx] | |
| return predicted_label | |
| # =========================================== | |
| # 7) Gradio UI | |
| # =========================================== | |
| html_content = """ | |
| <div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;"> | |
| <h3>X-ray Bodypart Classification (7 classes)</h3> | |
| <p>The input image should be a 256×256 (grayscale) PNG file.</p> | |
| <p>This model predict one of the following 7 classes.:</p> | |
| <ul> | |
| <li>Head</li> | |
| <li>Neck</li> | |
| <li>Chest</li> | |
| <li>Incomplete Chest</li> | |
| <li>Abdomen</li> | |
| <li>Pelvis</li> | |
| <li>Extremities</li> | |
| </ul> | |
| </div> | |
| """ | |
| with gr.Blocks(title="X-ray Bodypart Classification") as demo: | |
| gr.HTML("<div style='text-align:center'><h2>X-ray Bodypart Classification</h2></div>") | |
| gr.HTML(html_content) | |
| with gr.Row(): | |
| # グレースケール画像 (L) をそのまま入力 (256×256想定) | |
| input_image = gr.Image(type="pil", image_mode="L") | |
| output_label = gr.Label(label="Predicted Bodypart") | |
| send_btn = gr.Button("Inference") | |
| send_btn.click(fn=classify_bodypart, inputs=input_image, outputs=output_label) | |
| with gr.Row(): | |
| # サンプルファイルのパスはご自身の配置にあわせて修正 | |
| gr.Examples( | |
| examples=[ | |
| './sample/sample_chest.png', | |
| './sample/sample_incomplete_chest.png', | |
| './sample/sample_abdomen.png' | |
| ], | |
| inputs=input_image | |
| ) | |
| demo.launch(debug=True) | |