# HuggingFace Spaces app.py — web-page residue from the file listing removed
# (was: "Clocksp's picture / Update app.py / 2175e55 verified / raw / history blame / 9.36 kB")
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image, ImageOps
import numpy as np
import gradio as gr
import os
import base64
class DoubleConv(nn.Module):
    """Standard UNet unit: two stacked (3x3 Conv -> BatchNorm -> ReLU) blocks.

    Spatial size is preserved (kernel_size=3, padding=1); only the channel
    count changes, from `in_channels` to `out_channels`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        channels = in_channels
        for _ in range(2):
            layers += [
                nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        # Name kept as `conv_op` so saved state_dict keys still match.
        self.conv_op = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv blocks to x (N, C_in, H, W) -> (N, C_out, H, W)."""
        return self.conv_op(x)
class Downsample(nn.Module):
    """One UNet encoder stage: DoubleConv followed by 2x2 max-pooling.

    Returns both the pre-pool feature map (used later as a skip connection)
    and the pooled, half-resolution output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names kept (`conv`, `pool`) for state_dict compatibility.
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return (skip, pooled): skip at full resolution, pooled halved."""
        skip = self.conv(x)
        return skip, self.pool(skip)
class UpSample(nn.Module):
    """One UNet decoder stage: transposed-conv upsampling, skip concat, DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Doubles spatial resolution while halving the channel count, so the
        # concatenation with the skip tensor brings us back to `in_channels`.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample x1, pad it to x2's spatial size, concatenate, convolve."""
        upsampled = self.up(x1)
        # Odd-sized encoder maps can be 1px larger than the upsampled decoder
        # map; pad symmetrically so the two align before concatenation.
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        upsampled = F.pad(
            upsampled,
            [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2],
        )
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
class UNet(nn.Module):
    """Classic 4-level UNet producing a `num_classes`-channel logit map.

    Encoder channels 64 -> 128 -> 256 -> 512, bottleneck at 1024, symmetric
    decoder with skip connections, then a 1x1 conv head. Output logits are
    raw (no sigmoid) — callers apply the activation themselves.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        # Attribute names unchanged so checkpoints load via load_state_dict.
        self.down1 = Downsample(in_channels, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.bottleneck = DoubleConv(512, 1024)
        self.up1 = UpSample(1024, 512)
        self.up2 = UpSample(512, 256)
        self.up3 = UpSample(256, 128)
        self.up4 = UpSample(128, 64)
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Encoder -> bottleneck -> decoder (skips consumed in reverse order)."""
        skips = []
        pooled = x
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skip, pooled = stage(pooled)
            skips.append(skip)
        features = self.bottleneck(pooled)
        for stage, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            features = stage(features, skip)
        return self.out(features)
def build_efficientnet_b3(num_output=2, pretrained=False):
    """Build a torchvision EfficientNet-B3 with a `num_output`-logit head.

    Args:
        num_output: number of output classes for the replaced final Linear.
        pretrained: if True, initialize the backbone with ImageNet weights;
            otherwise the model is randomly initialized.

    Returns:
        The modified torchvision model.
    """
    weights = models.EfficientNet_B3_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.efficientnet_b3(weights=weights)
    # torchvision's classifier is [Dropout, Linear]; swap only the Linear.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_output)
    return model
# ---- One-time model setup (runs at import time) ----
# Everything below is moved to this device; inference runs under no_grad.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# Checkpoint locations, relative to the app's working directory.
UNET_PATH = "models/unet.pth"
MODEL_BACT_PATH = "models/model_bacterial.pt"
MODEL_VIRAL_PATH = "models/model_viral.pt"
# Lung segmentation network: RGB input, single-channel mask logits.
unet = UNet(in_channels=3, num_classes=1).to(device)
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — acceptable for trusted local checkpoints, but consider
# weights_only=True if the files could ever come from an untrusted source.
unet.load_state_dict(torch.load(UNET_PATH, map_location=device))
unet.eval()
# Two independent binary classifiers: Normal-vs-Bacterial and Normal-vs-Viral.
model_bact = build_efficientnet_b3(num_output=2).to(device)
model_viral = build_efficientnet_b3(num_output=2).to(device)
model_bact.load_state_dict(torch.load(MODEL_BACT_PATH, map_location=device))
model_viral.load_state_dict(torch.load(MODEL_VIRAL_PATH, map_location=device))
model_bact.eval()
model_viral.eval()
# UNet preprocessing: resize + ToTensor only, i.e. raw [0, 1] values
# (presumably matching how the segmentation net was trained — confirm
# against the training pipeline).
preprocess_unet = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
])
# Classifier preprocessing: same 300x300 size (so the mask aligns
# pixel-for-pixel) plus ImageNet mean/std normalization.
preprocess_classifier = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
def infer_mask_and_mask_image(pil_img, threshold=0.5):
    """Segment the lungs with the UNet and mask the input image.

    Args:
        pil_img: input PIL image; converted to RGB if it isn't already.
        threshold: probability cutoff used to binarize the sigmoid mask.

    Returns:
        Tuple of (masked_img_tensor, mask_np, masked_pil):
        - masked_img_tensor: (C, H, W) normalized classifier input with
          non-lung pixels zeroed, on `device`.
        - mask_np: raw sigmoid probability map, (H, W) numpy array.
        - masked_pil: un-normalized masked image (PIL) for display.
    """
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")
    # Segmentation forward pass on a resized [0, 1] tensor.
    unet_input = preprocess_unet(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        mask_prob = torch.sigmoid(unet(unet_input))[0, 0]
    mask_np = mask_prob.cpu().numpy()
    bin_mask = (mask_np >= threshold).astype(np.uint8)
    # Both pipelines resize to 300x300, so the binary mask lines up
    # pixel-for-pixel with the normalized classifier input.
    norm_img = preprocess_classifier(pil_img).to(device)
    mask_tensor = torch.from_numpy(bin_mask).float().unsqueeze(0).to(device)
    masked_img_tensor = norm_img * mask_tensor
    # Display version built from the *unnormalized* tensor so colors are sane.
    display = preprocess_unet(pil_img).cpu().numpy().transpose(1, 2, 0)
    display = np.clip(display * bin_mask[..., None] * 255, 0, 255).astype(np.uint8)
    masked_pil = Image.fromarray(display)
    return masked_img_tensor, mask_np, masked_pil
def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
    """Run both binary classifiers on a masked image and combine verdicts.

    Args:
        masked_img_tensor: (C, H, W) tensor, normalized for the classifiers.
        thresh_b: decision threshold for the bacterial model.
        thresh_v: decision threshold for the viral model.

    Returns:
        (pb, pv, label): pb = P(pneumonia) from the bacterial model,
        pv = P(pneumonia) from the viral model, label = combined verdict.
    """
    batch = masked_img_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        logits_b = model_bact(batch)
        logits_v = model_viral(batch)
    pb = torch.softmax(logits_b, dim=1)[0, 1].item()
    pv = torch.softmax(logits_v, dim=1)[0, 1].item()
    # Decision logic: neither model fires -> NORMAL; exactly one fires ->
    # that pneumonia type; both fire -> the stronger probability wins and the
    # label is flagged as a high-confidence overlap fallback.
    bact_fires = pb >= thresh_b
    viral_fires = pv >= thresh_v
    if not bact_fires and not viral_fires:
        label = "NORMAL"
    elif bact_fires and not viral_fires:
        label = "BACTERIAL PNEUMONIA"
    elif viral_fires and not bact_fires:
        label = "VIRAL PNEUMONIA"
    else:
        label = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
        label += " (fallback-high-confidence-overlap)"
    return pb, pv, label
def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
    """Full pipeline: segment lungs, mask, classify, build visualizations.

    Args:
        img: numpy array from Gradio — HxW grayscale, HxWx3 RGB, or HxWx4
            RGBA; any integer dtype convertible to uint8.
        thresh_b: bacterial classifier decision threshold.
        thresh_v: viral classifier decision threshold.
        seg_thresh: segmentation mask binarization threshold.

    Returns:
        (label, bacterial_prob, viral_prob, masked_image PIL, overlay PIL).

    Raises:
        ValueError: if no image was provided.
    """
    if img is None:
        raise ValueError("No image provided")
    # BUGFIX: the old Image.fromarray(img, 'RGB') crashed on 2-D grayscale
    # arrays even though the UI advertises grayscale support. Let PIL infer
    # the mode, then normalize to RGB.
    pil = Image.fromarray(img.astype('uint8')).convert("RGB")
    masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
        pil, threshold=seg_thresh
    )
    pb, pv, pred_label = classify_masked_tensor(
        masked_tensor,
        thresh_b=thresh_b,
        thresh_v=thresh_v
    )
    # Grayscale visualization of the raw probability mask.
    mask_vis = (mask_np * 255).astype(np.uint8)
    # Resize original to match the 300x300 mask for the overlay.
    display_orig = pil.resize((300, 300))
    # Red overlay whose opacity follows the mask probability.
    # (Removed a dead `np.zeros` assignment that was immediately overwritten,
    # and an unused `mask_pil` intermediate.)
    red_mask = np.stack(
        [mask_vis, np.zeros_like(mask_vis), np.zeros_like(mask_vis)], axis=2
    )
    red_mask = Image.fromarray(red_mask).convert("RGBA")
    alpha = (mask_np * 120).astype(np.uint8)
    red_mask.putalpha(Image.fromarray(alpha))
    blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)
    return (
        pred_label,
        float(pb),
        float(pv),
        masked_pil,
        blended
    )
# UI copy shown above the Gradio interface.
title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
desc = (
    "Pipeline: UNet -> mask lungs -> two binary classifiers (Normal vs Bacterial, Normal vs Viral). "
    "If both classifiers fire, the stronger probability is chosen (fallback). Thresholds adjustable."
)
# (image path, expected label) pairs rendered in the "Test Samples" table;
# paths are relative to the app's working directory.
example_samples = [
    ["images/NORMAL.jpeg", "NORMAL"],
    ["images/VIRAL.jpeg", "VIRAL"],
    ["images/BACT.jpeg", "BACTERIAL"],
]
# Gradio UI: a Blocks page wrapping a classic Interface plus a read-only
# table of sample files.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"### {title}")
    gr.Markdown(desc)
    # NOTE(review): nesting gr.Interface inside gr.Blocks works but is
    # unusual; also `allow_flagging` is deprecated in Gradio 4+ (renamed
    # `flagging_mode`) — confirm against the pinned gradio version.
    iface = gr.Interface(
        fn=inference_pipeline,
        inputs=[
            gr.Image(type="numpy", label="Upload chest X-ray (RGB or grayscale)"),
            # Slider order must match inference_pipeline's parameter order.
            gr.Slider(minimum=0.1, maximum=0.9, step=0.01, value=0.5,
                      label="Bacterial threshold (thresh_b)"),
            gr.Slider(minimum=0.1, maximum=0.9, step=0.01, value=0.5,
                      label="Viral threshold (thresh_v)"),
            gr.Slider(minimum=0.1, maximum=0.9, step=0.01, value=0.5,
                      label="Segmentation mask threshold (seg_thresh)")
        ],
        outputs=[
            gr.Label(num_top_classes=1, label="Prediction"),
            gr.Number(label="Bacterial Probability"),
            gr.Number(label="Viral Probability"),
            gr.Image(type="pil", label="Masked Image (input × mask)"),
            gr.Image(type="pil", label="Segmentation Overlay (red mask)")
        ],
        allow_flagging="never"
    )
    gr.Markdown("### Test Samples")
    # Static preview table of the bundled example images and their labels.
    gr.Dataframe(
        headers=["Image", "Label"],
        value=example_samples,
        datatype=["image", "str"],
        interactive=False,
        row_count=(len(example_samples), "fixed"),
        col_count=(2, "fixed")
    )
# Launch the app only when run as a script (HF Spaces also imports `demo`).
if __name__ == "__main__":
    demo.launch(share=False)