|
|
|
|
|
|
|
|
|
|
|
import html
import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image

from lib.framework import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Files read at start-up: the trained weights and the JSON holding the
# parameters that are turned into the model/dataloader argument groups.
WEIGHT_PATH = "./cxp_projection_rotation.pt"
PARAMETER_JSON = "./parameters.json"

# Class names, indexed by the argmax of the corresponding model output.
LABEL_APorPA = [
    "AP",
    "PA",
    "Lateral",
]
LABEL_ROUND = [
    "Upright",
    "Inverted",
    "Left-rotation",
    "Right-rotation",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageHandler(ImageMixin):
    """Turn a single PIL image into the batch dictionary fed to the model.

    NOTE(review): ``ImageMixin`` is defined in ``lib.dataloader`` and is not
    visible here; whatever it contributes beyond the two methods below has
    to be checked there.
    """

    def __init__(self, params):
        """Keep *params* and build the image-to-tensor pipeline.

        Args:
            params: Dataloader argument group; stored as ``self.params``.
        """
        self.params = params
        # The pipeline consists of a single step: PIL image -> tensor.
        self.transform = T.Compose([T.ToTensor()])

    def set_image(self, image: Image.Image):
        """Return ``{"image": tensor}`` with a leading batch axis of size 1.

        Args:
            image: PIL image to convert.

        Returns:
            A dict whose only key, ``"image"``, maps to the transformed
            tensor after ``unsqueeze(0)``.
        """
        batched = self.transform(image).unsqueeze(0)
        return {"image": batched}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_parameter(parameter_path):
    """Build the model and dataloader argument groups from a parameter file.

    Every key/value pair returned by ``_retrieve_parameter(parameter_path)``
    is copied onto a fresh ``ParamSet``.  A fixed set of attributes is then
    overwritten for this app (presumably to switch off training-only
    behaviour and force CPU inference -- confirm against ``lib.options``).

    Args:
        parameter_path: Path handed to ``_retrieve_parameter``.

    Returns:
        A 2-tuple ``(model_group, dataloader_group)`` produced by
        ``_dispatch_by_group`` for the ``"model"`` and ``"dataloader"``
        groups respectively.
    """
    param_set = ParamSet()
    for name, value in _retrieve_parameter(parameter_path).items():
        setattr(param_set, name, value)

    # Applied after the file contents so they win over anything loaded above;
    # "net" mirrors the "model" entry that came from the parameter file.
    overrides = {
        "augmentation": "no",
        "sampler": "no",
        "pretrained": False,
        "mlp": None,
        "net": param_set.model,
        "device": torch.device("cpu"),
    }
    for name, value in overrides.items():
        setattr(param_set, name, value)

    model_group = _dispatch_by_group(param_set, "model")
    dataloader_group = _dispatch_by_group(param_set, "dataloader")
    return model_group, dataloader_group
|
|
|
|
|
# Start-up work, executed once at import time: read the parameter file, build
# the network, load the trained weights and switch the model to eval mode.
args_model, args_dataloader = load_parameter(PARAMETER_JSON)

model = create_model(args_model)
print(f"Loading weights from {WEIGHT_PATH}")
model.load_weight(WEIGHT_PATH)
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Filename tokens accepted as synonyms for a rotation label.  The page text
# has asked users for "Left90"/"Right90", whereas the labels compared against
# are "Left-rotation"/"Right-rotation"; both spellings are honoured here.
_FILENAME_LABEL_ALIASES = {
    "left90": "Left-rotation",
    "right90": "Right-rotation",
}

# Confidence cut-offs used when deciding which warning (if any) to show.
_HIGH_CONFIDENCE = 0.8
_MEDIUM_CONFIDENCE = 0.5


def _canonical_label(token, labels):
    """Map a filename token onto its entry in *labels* (case/alias-insensitive).

    Returns ``None`` for an empty token and the stripped token itself when it
    matches no label, so unknown tokens are still reported as a mismatch.
    """
    stripped = token.strip()
    if not stripped:
        return None
    wanted = _FILENAME_LABEL_ALIASES.get(stripped.casefold(), stripped).casefold()
    for label in labels:
        if label.casefold() == wanted:
            return label
    return stripped


def _parse_filename_labels(image_path):
    """Extract ``(projection, rotation)`` from ``<id>_<projection>_<rotation>.<ext>``.

    Returns ``(None, None)`` when the file stem has fewer than three
    underscore-separated parts.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    parts = stem.split("_")
    if len(parts) < 3:
        return None, None
    return (
        _canonical_label(parts[1], LABEL_APorPA),
        _canonical_label(parts[2], LABEL_ROUND),
    )


def _make_warning(kind, orig, pred, conf):
    """Return the warning paragraph for one prediction, or ``""`` if none applies."""
    mismatch = bool(orig) and orig != pred
    if mismatch and conf >= _HIGH_CONFIDENCE:
        # ``orig`` comes from the uploaded file's name, i.e. from the user,
        # so it is escaped before being placed into markup.
        return (
            f"<p style='color:red'>⚠ Potentially mislabeled {kind}: "
            f"filename says {html.escape(orig)}, model predicts {pred} (confidence {conf:.2f})</p>"
        )
    if mismatch and conf >= _MEDIUM_CONFIDENCE:
        return (
            f"<p style='color:orange'>⚠ There is a possibility of mislabeled {kind}: "
            f"model predicts {pred} with moderate confidence ({conf:.2f})</p>"
        )
    if conf < _MEDIUM_CONFIDENCE:
        return (
            f"<p style='color:orange'>⚠ Low confidence for {kind} ({conf:.2f}); "
            f"please check image quality or framing.</p>"
        )
    return ""


def predict_html(image_path: str) -> str:
    """Classify one image and render the result as an HTML fragment.

    The image is opened, converted to 8-bit grayscale (``"L"``), wrapped into
    a one-item batch by ``ImageHandler`` and passed to the module-level
    ``model``.  The model output is expected to expose the logits under the
    keys ``"label_APorPA"`` and ``"label_round"``; softmax is taken over
    dimension 1 and the first batch row is reported.

    If the file name has the form ``<id>_<projection>_<rotation>.<ext>``, the
    two label parts are compared with the predictions and a warning is added
    when they disagree or when the confidence is below 0.5.  Label matching
    ignores case and accepts ``Left90``/``Right90`` as synonyms of
    ``Left-rotation``/``Right-rotation``.

    Args:
        image_path: Path of the image file.  An empty or ``None`` value
            (no image selected) yields a short notice instead of an error.

    Returns:
        HTML with the predicted projection and rotation, the per-class
        scores, and any warnings.

    Raises:
        KeyError: If the model output lacks one of the expected entries.
    """
    if not image_path:
        return (
            "<p style='color:orange'>⚠ Please upload an image before "
            "running inference.</p>"
        )

    # convert() returns an independent image, so the file can be closed here.
    with Image.open(image_path) as source:
        img = source.convert("L")

    handler = ImageHandler(args_dataloader)
    batch = handler.set_image(img)

    with torch.no_grad():
        outputs = model(batch)
        logits_proj = outputs.get("label_APorPA")
        logits_rot = outputs.get("label_round")
        if logits_proj is None or logits_rot is None:
            raise KeyError(
                "model output must contain 'label_APorPA' and 'label_round'"
            )
        probs_proj = F.softmax(logits_proj, dim=1)[0].cpu().numpy()
        probs_rot = F.softmax(logits_rot, dim=1)[0].cpu().numpy()

    idx_proj = int(probs_proj.argmax())
    idx_rot = int(probs_rot.argmax())
    pred_proj = LABEL_APorPA[idx_proj]
    pred_rot = LABEL_ROUND[idx_rot]
    conf_proj = float(probs_proj[idx_proj])
    conf_rot = float(probs_rot[idx_rot])

    orig_proj, orig_rot = _parse_filename_labels(image_path)
    warn_html = "".join(
        (
            _make_warning("projection", orig_proj, pred_proj, conf_proj),
            _make_warning("rotation", orig_rot, pred_rot, conf_rot),
        )
    )

    scores_proj = ", ".join(
        f"{label}: {p:.2f}" for label, p in zip(LABEL_APorPA, probs_proj)
    )
    scores_rot = ", ".join(
        f"{label}: {p:.2f}" for label, p in zip(LABEL_ROUND, probs_rot)
    )

    return (
        f"<p><strong>Projection :</strong> {pred_proj} "
        f"<small>({scores_proj})</small></p>"
        f"<p><strong>Rotation :</strong> {pred_rot} "
        f"<small>({scores_rot})</small></p>"
        f"{warn_html}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Introductory panel shown at the top of the page.  The rotation terms listed
# for file names must be the entries of LABEL_ROUND, because the labels parsed
# from the file name are compared against those strings.
html_header = """
<div style="padding:10px;border:1px solid #ddd;border-radius:5px">
  <h2>Chest X‑ray Projection & Rotation Classification</h2>
  <p>Upload a 256×256 grayscale PNG. The model predicts projection (AP/PA/Lateral)
  and rotation (Upright/Inverted/Left/Right) and shows softmax confidences.
  It warns if filename label differs or if confidence is low. Please name the files using the format: [Number]_projection_rotation.png.
  For the projection part of the filename, please use one of the following three terms: AP/PA/Lateral
  For the rotation part of the filename, please use one of the following four terms: Upright/Inverted/Left-rotation/Right-rotation
  As samples, We have prepared two sets of images: A PA view in the Upright position. An AP view with Left rotation. For each set, we have created two versions, one of which includes a mislabel in either its projection or rotation tag.</p>
</div>
"""
|
|
|
|
|
# Bundled example images.  The list of sample file names shown on the page is
# derived from these paths so the two can never disagree.
_SAMPLE_IMAGES = [
    "./sample/1_AP_Upright.png",
    "./sample/1_PA_Inverted.png",
    "./sample/2_AP_Right-rotation.png",
    "./sample/2_Lateral_Left-rotation.png",
]

# Page layout: header, image input next to the HTML result, a run button,
# clickable examples, and the list of sample file names.
with gr.Blocks(title="CXR Projection & Rotation") as demo:
    gr.HTML(html_header)

    with gr.Row():
        # type="filepath" hands predict_html a path, which it also uses to
        # read the labels encoded in the file name.
        input_image = gr.Image(
            label="Upload PNG (256×256)",
            type="filepath",
            image_mode="L",
        )
        output_html = gr.HTML()

    send_btn = gr.Button("Run Inference")
    send_btn.click(
        fn=predict_html,
        inputs=input_image,
        outputs=output_html,
    )

    gr.Examples(
        examples=_SAMPLE_IMAGES,
        inputs=input_image,
    )

    gr.Markdown(
        "**Sample filenames:** \n"
        + " \n".join(f"- {os.path.basename(path)}" for path in _SAMPLE_IMAGES)
    )


if __name__ == "__main__":
    demo.launch(debug=True)
|
|
|