#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import torch
import torch.nn.functional as F
import gradio as gr
import numpy as np
import torchvision.transforms as T
from PIL import Image
from lib.framework import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin

# ===========================================
# 1) Path settings
# ===========================================
WEIGHT_PATH = "./cxp_projection_rotation.pt"
PARAMETER_JSON = "./parameters.json"

# ===========================================
# 2) Class label definitions
# ===========================================
LABEL_APorPA = ["AP", "PA", "Lateral"]
LABEL_ROUND = ["Upright", "Inverted", "Left-rotation", "Right-rotation"]


# ===========================================
# 3) Preprocessing class
# ===========================================
class ImageHandler(ImageMixin):
    """Turn a single PIL image into the batch dictionary passed to the model."""

    def __init__(self, params):
        """Keep *params* and build the (ToTensor-only) transform pipeline."""
        self.params = params
        self.transform = T.Compose([T.ToTensor()])

    def set_image(self, image: Image.Image):
        """Return ``{"image": tensor}`` with a leading batch dimension added.

        The original note describes the transformed tensor as [C,H,W],
        float32 in [0,1], before the batch dimension is inserted.
        """
        batched = self.transform(image).unsqueeze(0)
        return {"image": batched}


# ===========================================
# 4) Parameter loading
# ===========================================
def load_parameter(parameter_path):
    """Load the parameter file and return ``(model_args, dataloader_args)``.

    Every key/value pair read from *parameter_path* is copied onto a fresh
    ``ParamSet``; a handful of attributes are then overwritten for inference
    before the set is split per group with ``_dispatch_by_group``.
    """
    param_set = ParamSet()
    for name, value in _retrieve_parameter(parameter_path).items():
        setattr(param_set, name, value)

    # Overrides applied for inference, in the same order as before.
    inference_overrides = (
        ("augmentation", "no"),
        ("sampler", "no"),
        ("pretrained", False),
        ("mlp", None),
    )
    for name, value in inference_overrides:
        setattr(param_set, name, value)
    param_set.net = param_set.model
    param_set.device = torch.device("cpu")

    model_args = _dispatch_by_group(param_set, "model")
    dataloader_args = _dispatch_by_group(param_set, "dataloader")
    return model_args, dataloader_args


args_model, args_dataloader = load_parameter(PARAMETER_JSON)

# ===========================================
# 5) Model creation & weight loading
# ===========================================
model = create_model(args_model)
print(f"Loading weights from {WEIGHT_PATH}")
model.load_weight(WEIGHT_PATH)
model.eval()

# ===========================================
# 6) Inference + HTML generation
=========================================== def predict_html(image_path: str) -> str: # 画像読み込み img = Image.open(image_path).convert("L") handler = ImageHandler(args_dataloader) batch = handler.set_image(img) with torch.no_grad(): outputs = model(batch) logits_proj = outputs.get("label_APorPA") logits_rot = outputs.get("label_round") # softmax で確率に変換 probs_proj = F.softmax(logits_proj, dim=1)[0].cpu().numpy() probs_rot = F.softmax(logits_rot, dim=1)[0].cpu().numpy() # argmax でラベル選択 idx_proj = int(probs_proj.argmax()) idx_rot = int(probs_rot.argmax()) pred_proj = LABEL_APorPA[idx_proj] pred_rot = LABEL_ROUND[idx_rot] conf_proj = float(probs_proj[idx_proj]) conf_rot = float(probs_rot[idx_rot]) # ファイル名から元ラベル取得(例: "1_AP_Upright.png" → orig_proj="AP", orig_rot="Upright") base = os.path.splitext(os.path.basename(image_path))[0] parts = base.split("_") if len(parts) >= 3: orig_proj = parts[1] orig_rot = parts[2] else: orig_proj = orig_rot = None # 警告HTML作成用ヘルパー def make_warning(kind, orig, pred, conf): # kind: "projection" or "rotation" high_thr = 0.8 med_thr = 0.5 if orig and orig != pred: if conf >= high_thr: return ( f"
⚠ Potentially mislabeled {kind}: " f"filename says {orig}, model predicts {pred} (confidence {conf:.2f})
" ) elif conf >= med_thr: return ( f"⚠ There is a possibility of mislabeled {kind}: " f"model predicts {pred} with moderate confidence ({conf:.2f})
" ) if conf < med_thr: return ( f"⚠ Low confidence for {kind} ({conf:.2f}); " f"please check image quality or framing.
" ) return "" # 警告HTML warn_html = "" warn_html += make_warning("projection", orig_proj, pred_proj, conf_proj) warn_html += make_warning("rotation", orig_rot, pred_rot, conf_rot) # クラスごとのスコア表示用HTML scores_proj = ", ".join( f"{LABEL_APorPA[i]}: {p:.2f}" for i, p in enumerate(probs_proj) ) scores_rot = ", ".join( f"{LABEL_ROUND[i]}: {p:.2f}" for i, p in enumerate(probs_rot) ) # 結果表示用HTML html = ( f"Projection : {pred_proj} " f"({scores_proj})
" f"Rotation : {pred_rot} " f"({scores_rot})
" f"{warn_html}" ) return html # =========================================== # 7) Gradio UI # =========================================== html_header = """Upload a 256×256 grayscale PNG. The model predicts projection (AP/PA/Lateral) and rotation (Upright/Inverted/Left/Right) and shows softmax confidences. It warns if filename label differs or if confidence is low. Please name the files using the format: [Number]_projection_rotation.png. For the projection part of the filename, please use one of the following three terms: AP/PA/Lateral For the rotation part of the filename, please use one of the following four terms: Upright/Inverted/Left90/Right90 As samples, We have prepared two sets of images: A PA view in the Upright position. An AP view with Left rotation. For each set, we have created two versions, one of which includes a mislabel in either its projection or rotation tag.