# NOTE: the following header is residue from a Hugging Face Space web page
# ("Clocksp", "Update app.py", commit b889c28, 10.4 kB) that was pasted in
# with the code; it is kept only as a comment so the file remains valid Python.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image, ImageOps
import numpy as np
import gradio as gr
import os
import pandas as pd
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm and an in-place ReLU.

    Spatial size is preserved (padding=1); channel count goes
    in_channels -> out_channels -> out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Keep the attribute name `conv_op` and layer order so checkpoint
        # state-dict keys (conv_op.0, conv_op.1, ...) remain unchanged.
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv_op = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv_op(x)
class Downsample(nn.Module):
    """Encoder stage: DoubleConv then 2x2 max-pool.

    forward() returns (features, pooled): the pre-pool feature map is kept
    for the decoder's skip connection; the pooled tensor feeds the next stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names `conv` / `pool` preserved for state-dict compatibility.
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv(x)
        return features, self.pool(features)
class UpSample(nn.Module):
    """Decoder stage: transposed-conv upsample, pad-to-match, skip-concat, DoubleConv.

    forward(x1, x2) upsamples x1, pads it to x2's spatial size (handles odd
    input dimensions), concatenates [x2, x1] on the channel axis and convolves.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names `up` / `conv` preserved for state-dict compatibility.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so its H/W match the skip tensor x2 exactly.
        dh = x2.size(2) - x1.size(2)
        dw = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))
        merged = torch.cat((x2, x1), dim=1)
        return self.conv(merged)
class UNet(nn.Module):
    """Standard 4-level U-Net: 64/128/256/512 encoder, 1024 bottleneck,
    symmetric decoder with skip connections, 1x1 output head.

    Output is raw logits with `num_classes` channels at the input resolution
    (input H/W must be divisible by 16).
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        # Attribute names preserved so existing checkpoints keep loading.
        self.down1 = Downsample(in_channels, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.bottleneck = DoubleConv(512, 1024)
        self.up1 = UpSample(1024, 512)
        self.up2 = UpSample(512, 256)
        self.up3 = UpSample(256, 128)
        self.up4 = UpSample(128, 64)
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: collect skip features, carry the pooled tensor forward.
        skips = []
        pooled = x
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skip, pooled = stage(pooled)
            skips.append(skip)
        # Bottleneck, then decoder consuming skips in reverse (deepest first).
        feat = self.bottleneck(pooled)
        for stage, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            feat = stage(feat, skip)
        return self.out(feat)
def build_efficientnet_b3(num_output=2, pretrained=False):
    """Build a torchvision EfficientNet-B3 with a `num_output`-way head.

    When `pretrained` is True the ImageNet-1k weights are loaded; otherwise
    the backbone is randomly initialized. The final classifier Linear layer
    is replaced to emit `num_output` logits.
    """
    weights = models.EfficientNet_B3_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.efficientnet_b3(weights=weights)
    head_in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(head_in_features, num_output)
    return model
# Run everything on the GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# Checkpoint paths, relative to the app's working directory.
UNET_PATH = "models/unet.pth"
MODEL_BACT_PATH = "models/model_bacterial.pt"
MODEL_VIRAL_PATH = "models/model_viral.pt"
# Lung-segmentation UNet (single-channel logit mask), loaded in eval mode.
unet = UNet(in_channels=3, num_classes=1).to(device)
unet.load_state_dict(torch.load(UNET_PATH, map_location=device))
unet.eval()
# Two independent 2-class EfficientNet-B3 classifiers:
# one for Normal-vs-Bacterial, one for Normal-vs-Viral.
model_bact = build_efficientnet_b3(num_output=2).to(device)
model_viral = build_efficientnet_b3(num_output=2).to(device)
model_bact.load_state_dict(torch.load(MODEL_BACT_PATH, map_location=device))
model_viral.load_state_dict(torch.load(MODEL_VIRAL_PATH, map_location=device))
model_bact.eval()
model_viral.eval()
# UNet input: resize + ToTensor only (no normalization).
preprocess_unet = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
])
# Classifier input: same resize, plus ImageNet mean/std normalization.
# Both pipelines use 300x300, so the UNet mask aligns with classifier input.
preprocess_classifier = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
def infer_mask_and_mask_image(pil_img, threshold=0.5):
    """Segment the lungs with the UNet and zero out the background.

    Args:
        pil_img: input PIL image (converted to RGB if needed).
        threshold: probability cutoff for binarizing the sigmoid mask.

    Returns:
        masked_img_tensor: (C, H, W) classifier-normalized tensor on `device`
            with background pixels zeroed.
        mask_np: (H, W) float numpy array of sigmoid mask probabilities.
        masked_pil: PIL image of the masked, un-normalized input for display.
    """
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    # Segmentation forward pass; take the single mask channel.
    unet_input = preprocess_unet(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        mask_prob = torch.sigmoid(unet(unet_input))[0, 0]
    mask_np = mask_prob.cpu().numpy()
    binary_mask = (mask_np >= threshold).astype(np.uint8)

    # Mask the normalized classifier input. Both preprocessing pipelines
    # resize to 300x300, so the mask lines up pixel-for-pixel.
    normalized = preprocess_classifier(pil_img).to(device)
    mask_tensor = torch.from_numpy(binary_mask).float().unsqueeze(0).to(device)
    masked_img_tensor = normalized * mask_tensor

    # Build a human-viewable masked image from the *un-normalized* tensor.
    raw = preprocess_unet(pil_img).cpu().numpy().transpose(1, 2, 0)
    display = np.clip(raw * binary_mask[..., None] * 255, 0, 255).astype(np.uint8)
    masked_pil = Image.fromarray(display)

    return masked_img_tensor, mask_np, masked_pil
def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
    """Run both binary classifiers on one masked, normalized image.

    Args:
        masked_img_tensor: (C, H, W) tensor on `device`, classifier-normalized.
        thresh_b / thresh_v: decision thresholds for the two models.

    Returns:
        (pb, pv, label) where pb is the pneumonia probability from the
        bacterial model and pv the pneumonia probability from the viral model.
    """
    batch = masked_img_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        pb = torch.softmax(model_bact(batch), dim=1)[0, 1].item()
        pv = torch.softmax(model_viral(batch), dim=1)[0, 1].item()

    bacterial_fires = pb >= thresh_b
    viral_fires = pv >= thresh_v
    # Decision logic: neither fires -> normal; exactly one fires -> that
    # class; both fire -> fall back to the stronger probability.
    if not bacterial_fires and not viral_fires:
        label = "NORMAL"
    elif bacterial_fires and not viral_fires:
        label = "BACTERIAL PNEUMONIA"
    elif viral_fires and not bacterial_fires:
        label = "VIRAL PNEUMONIA"
    else:
        label = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
        label += " (fallback-high-confidence-overlap)"
    return pb, pv, label
def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
    """Full pipeline: segment lungs, mask, classify, and build visualizations.

    Args:
        img: numpy image from Gradio (H x W grayscale, H x W x 3 RGB, or
            H x W x 4 RGBA all accepted).
        thresh_b / thresh_v: classifier decision thresholds.
        seg_thresh: probability threshold for binarizing the UNet mask.

    Returns:
        (label, bacterial_prob, viral_prob, masked_image PIL, overlay PIL)

    Raises:
        ValueError: if no image was provided.
    """
    if img is None:
        # Previously this crashed with an opaque AttributeError in Gradio.
        raise ValueError("No image provided - please upload a chest X-ray.")
    # Image.fromarray(img, 'RGB') fails for 2-D grayscale or RGBA arrays;
    # let PIL infer the mode, then normalize to RGB.
    pil = Image.fromarray(np.asarray(img).astype(np.uint8)).convert("RGB")

    masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
        pil, threshold=seg_thresh
    )
    pb, pv, pred_label = classify_masked_tensor(
        masked_tensor,
        thresh_b=thresh_b,
        thresh_v=thresh_v
    )

    # Probability mask scaled to 8-bit for the overlay's red channel.
    mask_vis = (mask_np * 255).astype(np.uint8)

    # Red overlay on the resized original: red intensity and alpha both
    # track the mask probability (alpha capped at ~47% opacity).
    display_orig = pil.resize((300, 300))
    red_rgb = np.stack(
        [mask_vis, np.zeros_like(mask_vis), np.zeros_like(mask_vis)], axis=2
    )
    red_mask = Image.fromarray(red_rgb).convert("RGBA")
    alpha = (mask_np * 120).astype(np.uint8)
    red_mask.putalpha(Image.fromarray(alpha))
    blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)

    return (
        pred_label,
        float(pb),
        float(pv),
        masked_pil,
        blended
    )
# UI copy and the example table shown in the "Try Examples" accordion.
title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
desc = (
    "Pipeline: UNet -> mask lungs -> two binary classifiers (Normal vs Bacterial, Normal vs Viral). "
    "If both classifiers fire, the stronger probability is chosen (fallback). Thresholds adjustable."
)
_example_rows = [
    ("images/NORMAL.jpeg", "NORMAL"),
    ("images/VIRAL.jpeg", "VIRAL"),
    ("images/BACT.jpeg", "BACTERIAL"),
]
examples_df = pd.DataFrame(_example_rows, columns=["Image", "Label"])
# --- Gradio UI: two-column layout (inputs left, results right) ---
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"### {title}")
    gr.Markdown(desc)
    with gr.Row():
        # Left column: image upload plus the three threshold sliders.
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="numpy",
                label="Upload Chest X-ray"
            )
            thresh_b = gr.Slider(
                minimum=0.1, maximum=0.9, step=0.01, value=0.5,
                label="Bacterial Threshold"
            )
            thresh_v = gr.Slider(
                minimum=0.1, maximum=0.9, step=0.01, value=0.5,
                label="Viral Threshold"
            )
            seg_thresh = gr.Slider(
                minimum=0.1, maximum=0.9, step=0.01, value=0.5,
                label="Segmentation Mask Threshold"
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                submit_btn = gr.Button("Submit", variant="primary")
        # Right column: prediction label, raw probabilities, and visuals.
        with gr.Column(scale=1):
            pred_label = gr.Label(num_top_classes=1, label="Prediction")
            prob_b = gr.Number(label="Bacterial Probability")
            prob_v = gr.Number(label="Viral Probability")
            masked_img = gr.Image(type="pil", label="Masked Image")
            seg_overlay = gr.Image(type="pil", label="Segmentation Overlay")
    # Submit runs the full pipeline; outputs map 1:1 to its return tuple.
    submit_btn.click(
        fn=inference_pipeline,
        inputs=[image_input, thresh_b, thresh_v, seg_thresh],
        outputs=[pred_label, prob_b, prob_v, masked_img, seg_overlay]
    )
    # Clear resets every output component to empty.
    clear_btn.click(
        fn=lambda: (None, None, None, None, None),
        inputs=None,
        outputs=[pred_label, prob_b, prob_v, masked_img, seg_overlay]
    )
    # Read-only table of bundled example images.
    with gr.Accordion("Try Examples", open=False):
        examples_table = gr.Dataframe(
            value=examples_df,
            headers=["Image", "Label"],
            datatype=["str", "str"],
            interactive=False,
            wrap=True,
            height=300
        )
        load_btn = gr.Button("Load Selected Example")
        def load_example(example_row):
            # NOTE(review): gr.Dataframe inputs arrive as a pandas DataFrame
            # by default, so `example_row[0]["Image"]` likely raises here
            # (df[0] is column access, not row access); also this receives
            # the whole table, not a "selected" row. Verify with a
            # `.select` event or the component's `type` setting.
            if example_row is None or len(example_row) == 0:
                return None
            img_path = example_row[0]["Image"]
            return img_path
        load_btn.click(
            fn=load_example,
            inputs=examples_table,
            outputs=image_input
        )
# Launch only when run as a script (share disabled for local/Space use).
if __name__ == "__main__":
    demo.launch(share=False)