"""Gradio demo: UNet segmentation followed by two binary pneumonia classifiers.

Pipeline: a UNet predicts a mask, the mask is applied to the (normalized)
input image, and two EfficientNet-B3 binary classifiers (bacterial, viral)
score the masked image. A small rule set turns both scores into one label.
"""

import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageOps
from torchvision import models, transforms


class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # NOTE: attribute names are part of the checkpoint's state_dict keys.
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv_op(x)


class Downsample(nn.Module):
    """Encoder stage: DoubleConv followed by a 2x2 max-pool."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(features_before_pooling, pooled_features)``."""
        down = self.conv(x)
        return down, self.pool(down)


class UpSample(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concat, DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=2, stride=2
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size and fuse both."""
        x1 = self.up(x1)
        # Odd sizes lose a pixel when pooled, so the upsampled map can be
        # smaller than the skip connection; pad it back symmetrically.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        # F.pad order for the last two dims: (left, right, top, bottom).
        x1 = F.pad(
            x1,
            [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
        )
        return self.conv(torch.cat([x2, x1], dim=1))


class UNet(nn.Module):
    """Four-level UNet producing ``num_classes`` logit maps."""

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        self.down1 = Downsample(in_channels, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)

        self.bottleneck = DoubleConv(512, 1024)

        self.up1 = UpSample(1024, 512)
        self.up2 = UpSample(512, 256)
        self.up3 = UpSample(256, 128)
        self.up4 = UpSample(128, 64)

        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return raw logits (no sigmoid applied)."""
        d1, p1 = self.down1(x)
        d2, p2 = self.down2(p1)
        d3, p3 = self.down3(p2)
        d4, p4 = self.down4(p3)

        b = self.bottleneck(p4)

        u1 = self.up1(b, d4)
        u2 = self.up2(u1, d3)
        u3 = self.up3(u2, d2)
        u4 = self.up4(u3, d1)
        return self.out(u4)


def build_efficientnet_b3(num_output=2, pretrained=False):
    """Build a torchvision EfficientNet-B3 with a ``num_output``-way head.

    Args:
        num_output: Number of output units of the replaced final linear layer.
        pretrained: If true, start from the IMAGENET1K_V1 weights; otherwise
            the backbone is randomly initialised.
    """
    weights = models.EfficientNet_B3_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.efficientnet_b3(weights=weights)
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_output)
    return model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

UNET_PATH = "models/unet.pth"
MODEL_BACT_PATH = "models/model_bacterial.pt"
MODEL_VIRAL_PATH = "models/model_viral.pt"

# Side length (pixels) shared by the segmentation and classifier inputs, so
# the predicted mask can be multiplied onto the classifier tensor directly.
INPUT_SIZE = 300

# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the torch version allows).
unet = UNet(in_channels=3, num_classes=1).to(device)
unet.load_state_dict(torch.load(UNET_PATH, map_location=device))
unet.eval()

model_bact = build_efficientnet_b3(num_output=2).to(device)
model_viral = build_efficientnet_b3(num_output=2).to(device)
model_bact.load_state_dict(torch.load(MODEL_BACT_PATH, map_location=device))
model_viral.load_state_dict(torch.load(MODEL_VIRAL_PATH, map_location=device))
model_bact.eval()
model_viral.eval()

preprocess_unet = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.ToTensor(),
])

preprocess_classifier = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def infer_mask_and_mask_image(pil_img, threshold=0.5):
    """Segment ``pil_img`` with the UNet and apply the binary mask to it.

    Args:
        pil_img: Input PIL image; converted to RGB when it is not already.
        threshold: Probability at or above which a pixel belongs to the mask.

    Returns:
        Tuple ``(masked_img_tensor, mask_np, masked_pil)``:
        the normalized classifier input multiplied by the binary mask
        (C, H, W, on ``device``), the sigmoid probability map (H, W) as a
        NumPy array, and an un-normalized masked PIL image for display.
    """
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    # Computed once and reused below for the display image.
    unet_input = preprocess_unet(pil_img)

    with torch.no_grad():
        logits = unet(unet_input.unsqueeze(0).to(device))
        mask_prob = torch.sigmoid(logits)[0, 0]

    mask_np = mask_prob.cpu().numpy()
    bin_mask = (mask_np >= threshold).astype(np.uint8)

    # Both transforms resize to the same size, so the mask lines up with the
    # classifier tensor pixel for pixel. The mask is applied AFTER
    # normalization, i.e. masked-out pixels become 0 in normalized space.
    img_tensor = preprocess_classifier(pil_img).to(device)
    mask_tensor = torch.from_numpy(bin_mask).unsqueeze(0).to(device).float()
    masked_img_tensor = img_tensor * mask_tensor

    # Display version: built from the un-normalized [0, 1] tensor.
    img_for_display = unet_input.cpu().numpy().transpose(1, 2, 0)
    masked_display = img_for_display * bin_mask[..., None]
    masked_display = np.clip(masked_display * 255, 0, 255).astype(np.uint8)
    masked_pil = Image.fromarray(masked_display)

    return masked_img_tensor, mask_np, masked_pil


def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
    """Score a masked image with both classifiers and derive a label.

    Args:
        masked_img_tensor: (C, H, W) tensor, normalized for the classifiers.
        thresh_b: Decision threshold for the bacterial model.
        thresh_v: Decision threshold for the viral model.

    Returns:
        Tuple ``(pb, pv, label)`` where ``pb`` / ``pv`` are the softmax
        probabilities of class index 1 from the bacterial / viral model.
    """
    x = masked_img_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        out_b = model_bact(x)
        out_v = model_viral(x)
        pb = torch.softmax(out_b, dim=1)[0, 1].item()
        pv = torch.softmax(out_v, dim=1)[0, 1].item()

    if pb < thresh_b and pv < thresh_v:
        # Neither model fires.
        label = "NORMAL"
    elif pb >= thresh_b and pv < thresh_v:
        # Only the bacterial model fires.
        label = "BACTERIAL PNEUMONIA"
    elif pv >= thresh_v and pb < thresh_b:
        # Only the viral model fires.
        label = "VIRAL PNEUMONIA"
    else:
        # Both fire: keep the larger probability (ties go to viral).
        label = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
        label += " (fallback-high-confidence-overlap)"

    return pb, pv, label


def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
    """Run segmentation + classification on an image coming from the UI.

    Args:
        img: Image as a NumPy array (grayscale, RGB or RGBA), or ``None``
            when the input component has been cleared.
        thresh_b: Decision threshold for the bacterial model.
        thresh_v: Decision threshold for the viral model.
        seg_thresh: Threshold applied to the segmentation probabilities.

    Returns:
        Tuple ``(label, bacterial_prob, viral_prob, masked_image, overlay)``;
        five ``None`` values when ``img`` is ``None``.
    """
    # The image "change" event also fires when the input is cleared; without
    # this guard the Clear button would raise on ``None.astype``.
    if img is None:
        return None, None, None, None, None

    # Let PIL infer the mode from the array shape, then normalise to RGB so
    # grayscale and RGBA inputs are accepted as well.
    pil = Image.fromarray(np.asarray(img).astype("uint8")).convert("RGB")

    masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
        pil, threshold=seg_thresh
    )
    pb, pv, pred_label = classify_masked_tensor(
        masked_tensor, thresh_b=thresh_b, thresh_v=thresh_v
    )

    # Size the background from the mask so alpha_composite always gets two
    # images of equal size (PIL expects (width, height)).
    height, width = mask_np.shape
    display_orig = pil.resize((width, height))

    # Red overlay whose intensity and opacity follow the (un-thresholded)
    # mask probability.
    mask_vis = (mask_np * 255).astype(np.uint8)
    empty = np.zeros_like(mask_vis)
    red_mask = Image.fromarray(np.stack([mask_vis, empty, empty], axis=2))
    red_mask = red_mask.convert("RGBA")
    alpha = (mask_np * 120).astype(np.uint8)
    red_mask.putalpha(Image.fromarray(alpha))

    blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)

    return pred_label, float(pb), float(pv), masked_pil, blended


title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
desc = (
    "Pipeline: UNet -> mask lungs -> two binary classifiers (Normal vs Bacterial, Normal vs Viral). "
    "If both classifiers fire, the stronger probability is chosen (fallback). Thresholds adjustable."
)

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}\n{desc}")

    with gr.Row():
        with gr.Column():
            img_in = gr.Image(type="numpy", label="Upload chest X-ray")
            thresh_b = gr.Slider(0.1, 0.9, 0.5, step=0.01, label="Bacterial threshold")
            thresh_v = gr.Slider(0.1, 0.9, 0.5, step=0.01, label="Viral threshold")
            seg_thresh = gr.Slider(0.1, 0.9, 0.5, step=0.01, label="Segmentation mask threshold")
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Column():
            pred_out = gr.Label(num_top_classes=1, label="Prediction")
            pb_out = gr.Number(label="Bacterial Probability")
            pv_out = gr.Number(label="Viral Probability")
            masked_img_out = gr.Image(type="pil", label="Masked Image")
            overlay_out = gr.Image(type="pil", label="Segmentation Overlay")

    # Shared wiring for every event that runs the pipeline.
    threshold_inputs = [thresh_b, thresh_v, seg_thresh]
    pipeline_outputs = [pred_out, pb_out, pv_out, masked_img_out, overlay_out]

    submit_btn.click(
        inference_pipeline,
        inputs=[img_in, *threshold_inputs],
        outputs=pipeline_outputs,
    )

    clear_btn.click(
        lambda: (None, None, None, None, None, None),
        outputs=[img_in, *pipeline_outputs],
    )

    img_in.change(
        inference_pipeline,
        inputs=[img_in, *threshold_inputs],
        outputs=pipeline_outputs,
    )

    gr.Markdown("## Test Samples")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### NORMAL")
            normal_sample = gr.Image("images/NORMAL.jpeg", show_label=False, height=220, interactive=True)

        with gr.Column(scale=1):
            gr.Markdown("### VIRAL")
            viral_sample = gr.Image("images/VIRAL.jpeg", show_label=False, height=220, interactive=True)

        with gr.Column(scale=1):
            gr.Markdown("### BACTERIAL")
            bact_sample = gr.Image("images/BACT.jpeg", show_label=False, height=220, interactive=True)

    # Clicking a sample image runs the pipeline on that sample.
    for sample in (normal_sample, viral_sample, bact_sample):
        sample.select(
            inference_pipeline,
            inputs=[sample, *threshold_inputs],
            outputs=pipeline_outputs,
        )

if __name__ == "__main__":
    demo.launch()