|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision import transforms, models |
|
|
from PIL import Image, ImageOps |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
import os |
|
|
import pandas as pd |
|
|
|
|
|
class DoubleConv(nn.Module):
    """Two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    ``padding=1`` keeps height and width unchanged. The first convolution maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # The attribute name and layer order determine the state_dict keys
        # (conv_op.0.*, conv_op.1.*, ...), which must match loaded checkpoints.
        self.conv_op = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv stages."""
        features = self.conv_op(x)
        return features
|
|
|
|
|
class Downsample(nn.Module):
    """UNet encoder stage: a DoubleConv followed by 2x2 max-pooling.

    ``forward`` returns ``(features, pooled)``. ``features`` is the pre-pool
    output (used as the skip connection); ``pooled`` feeds the next stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are part of the checkpoint's state_dict keys.
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skip = self.conv(x)
        return skip, self.pool(skip)
|
|
|
|
|
class UpSample(nn.Module):
    """UNet decoder stage: upsample, align with the skip tensor, concat, DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Doubles H and W and halves the channel count. Concatenating with a
        # skip tensor of in_channels // 2 channels (assumed) restores
        # in_channels for the DoubleConv.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size, concat ``[x2, x1]``, convolve."""
        upsampled = self.up(x1)

        # If the encoder pooled an odd-sized map, the upsampled tensor is
        # smaller than the skip tensor. Pad it, putting any odd pixel on the
        # right/bottom.
        dh = x2.shape[2] - upsampled.shape[2]
        dw = x2.shape[3] - upsampled.shape[3]
        top, left = dh // 2, dw // 2
        upsampled = F.pad(upsampled, [left, dw - left, top, dh - top])

        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)
|
|
|
|
|
class UNet(nn.Module):
    """Four-level UNet returning ``num_classes`` logit maps.

    ``UpSample`` pads each decoder output to its skip tensor's size, so the
    output has the same height and width as the input.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        # Attribute names and registration order match the saved checkpoint.
        # Encoder: channel width doubles at each level.
        self.down1 = Downsample(in_channels, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)

        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: mirrors the encoder, consuming skips from deepest to shallowest.
        self.up1 = UpSample(1024, 512)
        self.up2 = UpSample(512, 256)
        self.up3 = UpSample(256, 128)
        self.up4 = UpSample(128, 64)

        # 1x1 projection to per-class logits.
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        skips = []
        h = x
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skip, h = stage(h)
            skips.append(skip)

        h = self.bottleneck(h)

        decoder = (self.up1, self.up2, self.up3, self.up4)
        for stage, skip in zip(decoder, reversed(skips)):
            h = stage(h, skip)

        return self.out(h)
|
|
|
|
|
|
|
|
def build_efficientnet_b3(num_output=2, pretrained=False):
    """Create a torchvision EfficientNet-B3 with a fresh ``num_output``-way head.

    Args:
        num_output: Number of logits produced by the replacement linear layer.
        pretrained: If truthy, start from ``EfficientNet_B3_Weights.IMAGENET1K_V1``;
            otherwise the network is built without pretrained weights.

    Returns:
        The model, with ``classifier[1]`` replaced by a new ``nn.Linear``
        that keeps the original input width.
    """
    if pretrained:
        weights = models.EfficientNet_B3_Weights.IMAGENET1K_V1
    else:
        weights = None
    model = models.efficientnet_b3(weights=weights)

    head = model.classifier
    head[1] = nn.Linear(head[1].in_features, num_output)
    return model
|
|
|
|
|
|
|
|
# Use the GPU when CUDA is available, otherwise the CPU. The models and tensors
# below are moved onto this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
|
|
|
|
|
# Relative checkpoint paths, resolved against the current working directory.
UNET_PATH, MODEL_BACT_PATH, MODEL_VIRAL_PATH = (
    "models/unet.pth",             # lung segmentation UNet
    "models/model_bacterial.pt",   # bacterial-pneumonia binary classifier
    "models/model_viral.pt",       # viral-pneumonia binary classifier
)
|
|
|
|
|
|
|
|
def _load_checkpoint(net, path):
    """Load the state dict at *path* into *net*, switch it to eval mode, and return it.

    ``map_location=device`` places the stored tensors on the selected device.
    """
    # NOTE(review): torch.load is pickle-based; only point these paths at
    # trusted checkpoint files.
    net.load_state_dict(torch.load(path, map_location=device))
    net.eval()
    return net


# Segmentation model: RGB input, single-channel mask logits.
unet = _load_checkpoint(UNet(in_channels=3, num_classes=1).to(device), UNET_PATH)

# Two independent 2-class classifiers sharing the same architecture.
model_bact = build_efficientnet_b3(num_output=2).to(device)
model_viral = build_efficientnet_b3(num_output=2).to(device)
_load_checkpoint(model_bact, MODEL_BACT_PATH)
_load_checkpoint(model_viral, MODEL_VIRAL_PATH)
|
|
|
|
|
|
|
|
# Per-channel mean/std passed to Normalize for the classifier input.
# Presumably these match the classifiers' training preprocessing; confirm.
_CLASSIFIER_MEAN = [0.485, 0.456, 0.406]
_CLASSIFIER_STD = [0.229, 0.224, 0.225]

# UNet input: resize to 300x300 and convert to a float tensor (no normalization).
preprocess_unet = transforms.Compose([transforms.Resize((300, 300)), transforms.ToTensor()])

# Classifier input: same resize and tensor conversion, then mean/std normalization.
preprocess_classifier = transforms.Compose(
    [
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(_CLASSIFIER_MEAN, _CLASSIFIER_STD),
    ]
)
|
|
|
|
|
def infer_mask_and_mask_image(pil_img, threshold=0.5):
    """Segment the lungs with the UNet and apply the binarized mask to the image.

    Args:
        pil_img: Input PIL image; converted to RGB if it is in another mode.
        threshold: Cut-off on the UNet's sigmoid output. Pixels with
            probability >= threshold are kept, all others are zeroed.

    Returns:
        ``(masked_img_tensor, mask_np, masked_pil)``:
        - masked_img_tensor: ``preprocess_classifier`` output (C, H, W) on
          ``device``, multiplied by the binary mask.
        - mask_np: the UNet's per-pixel probabilities (H, W) as a NumPy array.
          These are the soft values, not the thresholded mask.
        - masked_pil: the resized, un-normalized image with masked-out pixels
          set to black, for display.
    """
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    # preprocess_unet = Resize + ToTensor. Keep the tensor so the display image
    # below reuses it instead of running the same preprocessing a second time.
    unet_input = preprocess_unet(pil_img)
    with torch.no_grad():
        logits = unet(unet_input.unsqueeze(0).to(device))
        mask_prob = torch.sigmoid(logits)[0, 0]
    mask_np = mask_prob.cpu().numpy()

    bin_mask = (mask_np >= threshold).astype(np.uint8)

    img_tensor = preprocess_classifier(pil_img).to(device)

    # (1, H, W) so the mask broadcasts across every channel.
    mask_tensor = torch.from_numpy(bin_mask).unsqueeze(0).to(device).float()
    # NOTE(review): the mask is applied after Normalize, so masked-out pixels
    # become 0 in normalized space rather than black. Presumably this matches
    # how the classifiers were trained; confirm.
    masked_img_tensor = img_tensor * mask_tensor

    # Display copy: HWC float image times the mask, scaled to uint8.
    img_for_display = unet_input.cpu().numpy().transpose(1, 2, 0)
    masked_display = img_for_display * bin_mask[..., None]
    masked_display = np.clip(masked_display * 255, 0, 255).astype(np.uint8)
    masked_pil = Image.fromarray(masked_display)

    return masked_img_tensor, mask_np, masked_pil
|
|
|
|
|
def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
    """Score a masked image with both binary classifiers and derive a label.

    Args:
        masked_img_tensor: (C, H, W) tensor already normalized for the
            classifiers. A batch axis is added and it is moved to ``device``.
        thresh_b: Probability at or above which the bacterial model "fires".
        thresh_v: Probability at or above which the viral model "fires".

    Returns:
        ``(pb, pv, label)``:
        - pb: softmax probability of class index 1 (pneumonia) from the
          bacterial model.
        - pv: the same quantity from the viral model.
        - label: "NORMAL", "BACTERIAL PNEUMONIA" or "VIRAL PNEUMONIA". If both
          models fire, it is the disease with the larger probability (ties go
          to viral) plus the suffix " (fallback-high-confidence-overlap)".
    """
    batch = masked_img_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        logits_b = model_bact(batch)
        logits_v = model_viral(batch)

    pb = torch.softmax(logits_b, dim=1)[0, 1].item()
    pv = torch.softmax(logits_v, dim=1)[0, 1].item()

    bact_low = pb < thresh_b
    viral_low = pv < thresh_v
    bact_high = pb >= thresh_b
    viral_high = pv >= thresh_v

    if bact_low and viral_low:
        label = "NORMAL"
    elif bact_high and viral_low:
        label = "BACTERIAL PNEUMONIA"
    elif viral_high and bact_low:
        label = "VIRAL PNEUMONIA"
    else:
        # Both models fired (or a comparison was undecidable, e.g. NaN):
        # report the stronger one.
        winner = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
        label = f"{winner} (fallback-high-confidence-overlap)"

    return pb, pv, label
|
|
|
|
|
|
|
|
|
|
|
def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
    """Run segmentation and classification on one image for the Gradio UI.

    Args:
        img: The image as a NumPy array (grayscale 2-D, RGB or RGBA), or
            ``None`` when the input component is empty or has just been cleared.
        thresh_b: Decision threshold for the bacterial classifier.
        thresh_v: Decision threshold for the viral classifier.
        seg_thresh: Probability threshold used to binarize the UNet mask.

    Returns:
        ``(label, bacterial_prob, viral_prob, masked_image, mask_overlay)``.
        The last two are PIL images. The overlay tints the resized input red
        in proportion to the soft UNet probability, so ``seg_thresh`` does not
        affect it. If ``img`` is ``None``, all five values are ``None`` so the
        output components are cleared instead of the handler raising.
    """
    # The Clear button resets the input image to None, which also triggers its
    # change handler. Submit can also be clicked with no image loaded.
    if img is None:
        return None, None, None, None, None

    # Let Pillow infer the mode from the array shape (L / LA / RGB / RGBA),
    # then normalize to RGB. Forcing mode "RGB" misreads 2-D or 4-channel arrays.
    pil = Image.fromarray(np.asarray(img).astype(np.uint8)).convert("RGB")

    masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
        pil, threshold=seg_thresh
    )

    pb, pv, pred_label = classify_masked_tensor(
        masked_tensor,
        thresh_b=thresh_b,
        thresh_v=thresh_v,
    )

    # Soft mask scaled to 0..255 drives the red channel of the tint.
    mask_vis = (mask_np * 255).astype(np.uint8)

    # Resize the original to the mask's resolution so alpha_composite gets
    # two images of identical size.
    height, width = mask_vis.shape
    display_orig = pil.resize((width, height))

    zeros = np.zeros_like(mask_vis)
    red_mask = Image.fromarray(np.stack([mask_vis, zeros, zeros], axis=2)).convert("RGBA")
    # Opacity follows the probability, capped at 120/255 so the X-ray stays visible.
    alpha = (mask_np * 120).astype(np.uint8)
    red_mask.putalpha(Image.fromarray(alpha))

    blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)

    return (
        pred_label,
        float(pb),
        float(pv),
        masked_pil,
        blended,
    )
|
|
|
|
|
# Heading and description rendered at the top of the Gradio page.
title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
desc = (
    "Pipeline: UNet -> mask lungs -> two binary classifiers "
    "(Normal vs Bacterial, Normal vs Viral). "
    "If both classifiers fire, the stronger probability is chosen (fallback). "
    "Thresholds adjustable."
)
|
|
|
|
|
# Gradio UI. The left column holds the image input, three threshold sliders
# and the buttons; the right column shows the pipeline outputs; a row of three
# sample images sits below. Components render in the order they are created.
with gr.Blocks(title=title) as demo:

    gr.Markdown(f"## {title}\n{desc}")

    with gr.Row():
        # Inputs: image plus sliders (min 0.1, max 0.9, default 0.5, step 0.01).
        with gr.Column():
            img_in = gr.Image(type="numpy", label="Upload chest X-ray")
            thresh_b = gr.Slider(0.1, 0.9, 0.5, step=0.01, label="Bacterial threshold")
            thresh_v = gr.Slider(0.1, 0.9, 0.5, step=0.01, label="Viral threshold")
            seg_thresh = gr.Slider(0.1, 0.9, 0.5, step=0.01, label="Segmentation mask threshold")

            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")

        # Outputs, in the same order as inference_pipeline's return tuple.
        with gr.Column():
            pred_out = gr.Label(num_top_classes=1, label="Prediction")
            pb_out = gr.Number(label="Bacterial Probability")
            pv_out = gr.Number(label="Viral Probability")
            masked_img_out = gr.Image(type="pil", label="Masked Image")
            overlay_out = gr.Image(type="pil", label="Segmentation Overlay")

    submit_btn.click(
        inference_pipeline,
        inputs=[img_in, thresh_b, thresh_v, seg_thresh],
        outputs=[pred_out, pb_out, pv_out, masked_img_out, overlay_out]
    )

    # Reset the input image and all five outputs to empty.
    clear_btn.click(
        lambda: (None, None, None, None, None, None),
        outputs=[img_in, pred_out, pb_out, pv_out, masked_img_out, overlay_out]
    )

    # Run inference automatically whenever the input image changes.
    # NOTE(review): the Clear handler above sets img_in to None, which
    # presumably also fires this change event. inference_pipeline must
    # therefore tolerate img=None.
    img_in.change(
        inference_pipeline,
        inputs=[img_in, thresh_b, thresh_v, seg_thresh],
        outputs=[pred_out, pb_out, pv_out, masked_img_out, overlay_out]
    )

    gr.Markdown("## Test Samples")

    # Bundled example images (paths relative to the working directory).
    with gr.Row():

        with gr.Column(scale=1):
            gr.Markdown("### NORMAL")
            normal_sample = gr.Image("images/NORMAL.jpeg", show_label=False, height=220, interactive=True)

        with gr.Column(scale=1):
            gr.Markdown("### VIRAL")
            viral_sample = gr.Image("images/VIRAL.jpeg", show_label=False, height=220, interactive=True)

        with gr.Column(scale=1):
            gr.Markdown("### BACTERIAL")
            bact_sample = gr.Image("images/BACT.jpeg", show_label=False, height=220, interactive=True)

    # Selecting a sample runs the pipeline on that image with the current
    # slider values and writes to the same output components.
    normal_sample.select(
        inference_pipeline,
        inputs=[normal_sample, thresh_b, thresh_v, seg_thresh],
        outputs=[pred_out, pb_out, pv_out, masked_img_out, overlay_out]
    )

    viral_sample.select(
        inference_pipeline,
        inputs=[viral_sample, thresh_b, thresh_v, seg_thresh],
        outputs=[pred_out, pb_out, pv_out, masked_img_out, overlay_out]
    )

    bact_sample.select(
        inference_pipeline,
        inputs=[bact_sample, thresh_b, thresh_v, seg_thresh],
        outputs=[pred_out, pb_out, pv_out, masked_img_out, overlay_out]
    )
|
|
|
|
|
# Launch the Gradio server only when this file is executed directly.
# Note that the models above are still loaded at import time.
if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|