Clocksp's picture
Update app.py
e50fb6a verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image, ImageOps
import numpy as np
import gradio as gr
import os
import pandas as pd
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    The 3x3 convolutions use padding=1, so the spatial size is unchanged;
    only the channel count goes from ``in_channels`` to ``out_channels``.
    The layers live in ``self.conv_op`` at indices 0..5 -- these names are
    part of the saved checkpoint keys and must not change.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in->out channels, second stage maps out->out.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv_op = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_op(x)
class Downsample(nn.Module):
    """Encoder stage: DoubleConv followed by 2x2 / stride-2 max-pooling.

    ``forward`` returns ``(features, pooled)``: the full-resolution feature
    map (used by UNet as a skip connection) and its pooled, 2x-downsampled
    version (fed to the next encoder stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv(x)
        return features, self.pool(features)
class UpSample(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concat, DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Upsampling halves the channel count; concatenating the skip tensor
        # (expected to carry in_channels // 2 channels) brings it back to
        # in_channels, which is what the DoubleConv consumes.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to skip tensor ``x2``, concat (skip first), convolve."""
        upsampled = self.up(x1)
        # Odd-sized encoder maps lose a row/column when pooled, so the
        # upsampled tensor can be smaller than the skip tensor; pad it to
        # x2's H x W, splitting the difference between both sides.
        dh = x2.size(2) - upsampled.size(2)
        dw = x2.size(3) - upsampled.size(3)
        upsampled = F.pad(upsampled, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return self.conv(torch.cat([x2, upsampled], dim=1))
class UNet(nn.Module):
    """Four-level U-Net producing ``num_classes`` logit maps.

    Encoder widths 64-128-256-512, a 1024-channel bottleneck, a symmetric
    decoder that consumes the encoder features as skip connections, and a
    1x1 convolution head. Because UpSample pads to the skip tensor's size,
    the output has the same H x W as the input. Attribute names are part of
    the checkpoint keys and must not change.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        # Encoder
        self.down1 = Downsample(in_channels, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)
        # Decoder
        self.up1 = UpSample(1024, 512)
        self.up2 = UpSample(512, 256)
        self.up3 = UpSample(256, 128)
        self.up4 = UpSample(128, 64)
        # Per-pixel classification head
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        skips = []
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            skip, x = encoder(x)
            skips.append(skip)
        x = self.bottleneck(x)
        # Decoder stages consume the skips deepest-first.
        for decoder in (self.up1, self.up2, self.up3, self.up4):
            x = decoder(x, skips.pop())
        return self.out(x)
def build_efficientnet_b3(num_output=2, pretrained=False):
    """Build a torchvision EfficientNet-B3 with a fresh ``num_output``-way head.

    Args:
        num_output: size of the new final Linear layer (``classifier[1]``).
        pretrained: if True, start from torchvision's IMAGENET1K_V1 weights;
            otherwise the backbone is randomly initialized.
    """
    weights = models.EfficientNet_B3_Weights.IMAGENET1K_V1 if pretrained else None
    net = models.efficientnet_b3(weights=weights)
    old_head = net.classifier[1]
    net.classifier[1] = nn.Linear(old_head.in_features, num_output)
    return net
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

UNET_PATH = "models/unet.pth"
MODEL_BACT_PATH = "models/model_bacterial.pt"
MODEL_VIRAL_PATH = "models/model_viral.pt"


def _restore(model, checkpoint_path):
    """Move *model* to ``device``, load the state dict at *checkpoint_path*, and set eval mode."""
    model = model.to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    return model


unet = _restore(UNet(in_channels=3, num_classes=1), UNET_PATH)
model_bact = _restore(build_efficientnet_b3(num_output=2), MODEL_BACT_PATH)
model_viral = _restore(build_efficientnet_b3(num_output=2), MODEL_VIRAL_PATH)
# Both pipelines resize to the same size, so a UNet mask lines up
# pixel-for-pixel with the classifier input.
_IMAGE_SIZE = (300, 300)
# Standard ImageNet channel statistics, applied only to the classifier input.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# UNet input: resized [0, 1] tensor, no normalization.
preprocess_unet = transforms.Compose([
    transforms.Resize(_IMAGE_SIZE),
    transforms.ToTensor(),
])

# Classifier input: resized, [0, 1], then mean/std normalized.
preprocess_classifier = transforms.Compose([
    transforms.Resize(_IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
def infer_mask_and_mask_image(pil_img, threshold=0.5):
    """Segment the image with the UNet and apply the mask to the classifier input.

    Args:
        pil_img: input PIL image; converted to RGB if it is in another mode.
        threshold: mask-probability cut-off; pixels with probability
            ``>= threshold`` are kept, all others are zeroed.

    Returns:
        masked_img_tensor: (C, H, W) tensor on ``device`` -- the normalized
            classifier input multiplied by the binary mask.
        mask_np: (H, W) float array of per-pixel mask *probabilities*
            (sigmoid of the UNet logits, not thresholded).
        masked_pil: PIL image of the resized, un-normalized input with the
            binary mask applied, for display.
    """
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    # Resized [0, 1] CPU tensor. It is the UNet input and also the source
    # of the display image, so the preprocessing runs only once.
    unet_input = preprocess_unet(pil_img)
    with torch.no_grad():
        logits = unet(unet_input.unsqueeze(0).to(device))
    mask_np = torch.sigmoid(logits)[0, 0].cpu().numpy()
    bin_mask = (mask_np >= threshold).astype(np.uint8)

    # preprocess_unet and preprocess_classifier resize to the same size, so
    # the mask lines up with the classifier input.
    # NOTE(review): the mask is applied after Normalize, so masked-out pixels
    # become 0 in normalized space (i.e. the per-channel mean), not black.
    # Confirm this matches how the classifiers were trained.
    img_tensor = preprocess_classifier(pil_img).to(device)
    mask_tensor = torch.from_numpy(bin_mask).unsqueeze(0).to(device).float()
    masked_img_tensor = img_tensor * mask_tensor

    # Display image: un-normalized [0, 1] pixels with the binary mask applied.
    img_for_display = unet_input.numpy().transpose(1, 2, 0)
    masked_display = img_for_display * bin_mask[..., None]
    masked_display = np.clip(masked_display * 255, 0, 255).astype(np.uint8)
    masked_pil = Image.fromarray(masked_display)
    return masked_img_tensor, mask_np, masked_pil
def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
    """Score a masked image with both binary classifiers and pick a label.

    Args:
        masked_img_tensor: (C, H, W) normalized classifier input; it is
            batched and moved to ``device`` here.
        thresh_b: threshold on the bacterial model's pneumonia probability.
        thresh_v: threshold on the viral model's pneumonia probability.

    Returns:
        (pb, pv, label): pb / pv are the softmax probabilities of class 1
        (pneumonia) from the bacterial / viral model. label is "NORMAL",
        "BACTERIAL PNEUMONIA" or "VIRAL PNEUMONIA"; when neither single-model
        case applies (both above threshold), the more confident type is
        reported with a " (fallback-high-confidence-overlap)" suffix.
    """
    batch = masked_img_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        logits_b = model_bact(batch)
        logits_v = model_viral(batch)
    pb, pv = (torch.softmax(logits, dim=1)[0, 1].item() for logits in (logits_b, logits_v))

    # "low" and "high" are computed separately (not as negations of each other)
    # so the branch taken is exactly the same as a direct comparison chain.
    bact_low, bact_high = pb < thresh_b, pb >= thresh_b
    viral_low, viral_high = pv < thresh_v, pv >= thresh_v
    if bact_low and viral_low:
        label = "NORMAL"
    elif bact_high and viral_low:
        label = "BACTERIAL PNEUMONIA"
    elif viral_high and bact_low:
        label = "VIRAL PNEUMONIA"
    else:
        # Both models fire: report the more confident one (ties go to viral).
        dominant = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
        label = dominant + " (fallback-high-confidence-overlap)"
    return pb, pv, label
def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
    """Run segmentation and classification on one image from the UI.

    Args:
        img: image as a numpy array of shape (H, W), (H, W, 1), (H, W, 3) or
            (H, W, 4), or ``None`` when the input component has been cleared.
        thresh_b: decision threshold for the bacterial classifier.
        thresh_v: decision threshold for the viral classifier.
        seg_thresh: mask-probability cut-off for the lung segmentation.

    Returns:
        (label, bacterial_prob, viral_prob, masked_image, mask_overlay), where
        the two images are PIL images. All five are ``None`` if *img* is None.
    """
    # The Clear button resets the input image to None, which also fires the
    # image's change event; return empty outputs instead of crashing.
    if img is None:
        return None, None, None, None, None

    arr = np.asarray(img).astype(np.uint8)
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[..., 0]
    # Let Pillow infer the mode from the array shape (L / RGB / RGBA), then
    # normalize to RGB so grayscale and RGBA uploads are handled too.
    pil = Image.fromarray(arr).convert("RGB")

    masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
        pil, threshold=seg_thresh
    )
    pb, pv, pred_label = classify_masked_tensor(
        masked_tensor, thresh_b=thresh_b, thresh_v=thresh_v
    )

    # Red overlay whose intensity and opacity follow the mask probability.
    mask_h, mask_w = mask_np.shape
    mask_vis = (mask_np * 255).astype(np.uint8)
    zeros = np.zeros_like(mask_vis)
    red_mask = Image.fromarray(np.stack([mask_vis, zeros, zeros], axis=2)).convert("RGBA")
    red_mask.putalpha(Image.fromarray((mask_np * 120).astype(np.uint8)))

    # Resize the original to the mask's resolution so the layers can be composited.
    display_orig = pil.resize((mask_w, mask_h)).convert("RGBA")
    blended = Image.alpha_composite(display_orig, red_mask)

    return pred_label, float(pb), float(pv), masked_pil, blended
# Header text shown at the top of the Gradio page.
title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
desc = (
    "Pipeline: UNet -> mask lungs -> two binary classifiers (Normal vs Bacterial, Normal vs Viral). "
    "If both classifiers fire, the stronger probability is chosen (fallback). Thresholds adjustable."
)
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}\n{desc}")

    with gr.Row():
        # Left column: input image, thresholds and actions.
        with gr.Column():
            img_in = gr.Image(type="numpy", label="Upload chest X-ray")
            thresh_b = gr.Slider(0.1, 0.9, 0.5, step=0.01, label="Bacterial threshold")
            thresh_v = gr.Slider(0.1, 0.9, 0.5, step=0.01, label="Viral threshold")
            seg_thresh = gr.Slider(0.1, 0.9, 0.5, step=0.01, label="Segmentation mask threshold")
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")
        # Right column: prediction, probabilities and visualizations.
        with gr.Column():
            pred_out = gr.Label(num_top_classes=1, label="Prediction")
            pb_out = gr.Number(label="Bacterial Probability")
            pv_out = gr.Number(label="Viral Probability")
            masked_img_out = gr.Image(type="pil", label="Masked Image")
            overlay_out = gr.Image(type="pil", label="Segmentation Overlay")

    # Every trigger passes the same thresholds to inference_pipeline and fills
    # the same output components, so define both lists once.
    threshold_inputs = [thresh_b, thresh_v, seg_thresh]
    pipeline_outputs = [pred_out, pb_out, pv_out, masked_img_out, overlay_out]

    submit_btn.click(
        inference_pipeline,
        inputs=[img_in, *threshold_inputs],
        outputs=pipeline_outputs,
    )
    # Reset the input image and all five outputs.
    clear_btn.click(
        lambda: (None, None, None, None, None, None),
        outputs=[img_in, *pipeline_outputs],
    )
    img_in.change(
        inference_pipeline,
        inputs=[img_in, *threshold_inputs],
        outputs=pipeline_outputs,
    )

    gr.Markdown("## Test Samples")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### NORMAL")
            normal_sample = gr.Image("images/NORMAL.jpeg", show_label=False, height=220, interactive=True)
        with gr.Column(scale=1):
            gr.Markdown("### VIRAL")
            viral_sample = gr.Image("images/VIRAL.jpeg", show_label=False, height=220, interactive=True)
        with gr.Column(scale=1):
            gr.Markdown("### BACTERIAL")
            bact_sample = gr.Image("images/BACT.jpeg", show_label=False, height=220, interactive=True)

    # Selecting a sample image runs the pipeline on that sample.
    for sample in (normal_sample, viral_sample, bact_sample):
        sample.select(
            inference_pipeline,
            inputs=[sample, *threshold_inputs],
            outputs=pipeline_outputs,
        )
def _main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()