| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import transforms, models |
| from PIL import Image, ImageOps |
| import numpy as np |
| import gradio as gr |
| import os |
|
|
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv -> BatchNorm -> ReLU) stages, the basic UNet unit.

    The stages live in a single ``nn.Sequential`` named ``conv_op`` so the
    state-dict keys (``conv_op.0`` … ``conv_op.5``) match the saved checkpoint.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        channels = in_channels
        for _ in range(2):
            layers += [
                nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        self.conv_op = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages; spatial size is preserved (padding=1)."""
        return self.conv_op(x)
|
|
class Downsample(nn.Module):
    """Encoder stage: DoubleConv followed by a 2x2 max-pool.

    Returns both the pre-pool feature map (kept as the skip connection)
    and the pooled map (fed to the next, deeper stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv(x)
        return features, self.pool(features)
|
|
class UpSample(nn.Module):
    """Decoder stage: transposed-conv upsample, concatenate skip, DoubleConv.

    The upsampled tensor is zero-padded to the skip tensor's spatial size
    before concatenation, so odd input sizes still line up.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
                                     kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: deeper feature map to upsample; x2: encoder skip connection.
        x1 = self.up(x1)

        # Pad x1 on both sides so its H/W match x2 exactly.
        dh = x2.size(2) - x1.size(2)
        dw = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))

        merged = torch.cat((x2, x1), dim=1)
        return self.conv(merged)
|
|
class UNet(nn.Module):
    """Standard 4-level UNet (64→128→256→512 encoder, 1024 bottleneck).

    Attribute names (``down1`` … ``out``) are fixed by the saved checkpoint's
    state-dict keys and must not be renamed.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        self.down1 = Downsample(in_channels, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.bottleneck = DoubleConv(512, 1024)
        self.up1 = UpSample(1024, 512)
        self.up2 = UpSample(512, 256)
        self.up3 = UpSample(256, 128)
        self.up4 = UpSample(128, 64)
        # 1x1 conv maps 64 channels to per-pixel class logits.
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return raw per-pixel logits of shape (N, num_classes, H, W)."""
        skips = []
        pooled = x
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skip, pooled = stage(pooled)
            skips.append(skip)

        feats = self.bottleneck(pooled)

        # Decode, consuming skip connections deepest-first.
        decoders = (self.up1, self.up2, self.up3, self.up4)
        for stage, skip in zip(decoders, reversed(skips)):
            feats = stage(feats, skip)

        return self.out(feats)
|
|
|
|
def build_efficientnet_b3(num_output=2, pretrained=False):
    """Build an EfficientNet-B3 with its final linear layer resized.

    Args:
        num_output: number of output classes for the replacement head.
        pretrained: when True, load torchvision's ImageNet-1k weights.

    Returns:
        A torchvision EfficientNet-B3 whose ``classifier[1]`` is a fresh
        ``nn.Linear`` with ``num_output`` outputs.
    """
    weights = models.EfficientNet_B3_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.efficientnet_b3(weights=weights)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_output)
    return model
|
|
|
|
# Run on GPU when available; every model and tensor below is moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# Checkpoint locations, relative to the working directory.
UNET_PATH = "models/unet.pth"
MODEL_BACT_PATH = "models/model_bacterial.pt"
MODEL_VIRAL_PATH = "models/model_viral.pt"


# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# trusted checkpoint files (newer PyTorch supports weights_only=True; confirm
# the installed version before adding it).

# Segmentation model: single-channel (binary) lung mask.
unet = UNet(in_channels=3, num_classes=1).to(device)
unet.load_state_dict(torch.load(UNET_PATH, map_location=device))
unet.eval()


# Two independent binary classifiers, each Normal-vs-pneumonia-subtype.
model_bact = build_efficientnet_b3(num_output=2).to(device)
model_viral = build_efficientnet_b3(num_output=2).to(device)


model_bact.load_state_dict(torch.load(MODEL_BACT_PATH, map_location=device))
model_viral.load_state_dict(torch.load(MODEL_VIRAL_PATH, map_location=device))


# eval() disables dropout and switches BatchNorm to running statistics.
model_bact.eval()
model_viral.eval()
|
|
|
|
# UNet input: resize to 300x300 and scale to [0, 1] — note no mean/std
# normalization here, unlike the classifier pipeline below.
preprocess_unet = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
])


# Classifier input: same 300x300 size plus ImageNet mean/std normalization
# (the statistics torchvision's EfficientNet weights were trained with).
preprocess_classifier = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
|
|
def infer_mask_and_mask_image(pil_img, threshold=0.5):
    """Segment the lungs with the UNet and apply the binary mask to the image.

    Args:
        pil_img: input PIL image; converted to RGB if it isn't already.
        threshold: probability cutoff used to binarize the sigmoid mask.

    Returns:
        masked_img_tensor: (C, H, W) classifier-normalized tensor on `device`
            with pixels outside the mask zeroed — ready for the classifiers.
        mask_np: (H, W) float numpy array of raw mask probabilities in [0, 1].
        masked_pil: PIL image of the unnormalized input with the mask applied
            (for display only).
    """
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    # Fix: the original ran preprocess_unet(pil_img) twice (inference +
    # display); compute the tensor once and reuse it.
    unet_tensor = preprocess_unet(pil_img)  # (3, 300, 300), values in [0, 1]
    inp = unet_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        logits = unet(inp)
        mask_prob = torch.sigmoid(logits)[0, 0]
    mask_np = mask_prob.cpu().numpy()

    bin_mask = (mask_np >= threshold).astype(np.uint8)

    # Classifier input: normalized image, zeroed outside the lung mask.
    img_tensor = preprocess_classifier(pil_img).to(device)
    mask_tensor = torch.from_numpy(bin_mask).unsqueeze(0).to(device).float()
    masked_img_tensor = img_tensor * mask_tensor

    # Display version, built from the unnormalized tensor.
    img_for_display = unet_tensor.cpu().numpy().transpose(1, 2, 0)
    masked_display = img_for_display * bin_mask[..., None]
    masked_display = np.clip(masked_display * 255, 0, 255).astype(np.uint8)
    masked_pil = Image.fromarray(masked_display)
    return masked_img_tensor, mask_np, masked_pil
|
|
def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
    """Run both binary classifiers on a masked image and combine the results.

    Args:
        masked_img_tensor: (C, H, W) tensor on `device`, already normalized
            for the classifiers.
        thresh_b: decision threshold for the bacterial classifier.
        thresh_v: decision threshold for the viral classifier.

    Returns:
        (pb, pv, label) where pb is the bacterial model's pneumonia
        probability, pv the viral model's, and label the combined verdict.
    """
    batch = masked_img_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        logits_b = model_bact(batch)
        logits_v = model_viral(batch)

    # Index 1 is the pneumonia class in each binary head.
    pb = torch.softmax(logits_b, dim=1)[0, 1].item()
    pv = torch.softmax(logits_v, dim=1)[0, 1].item()

    bact_fires = pb >= thresh_b
    viral_fires = pv >= thresh_v

    if not bact_fires and not viral_fires:
        label = "NORMAL"
    elif bact_fires and not viral_fires:
        label = "BACTERIAL PNEUMONIA"
    elif viral_fires and not bact_fires:
        label = "VIRAL PNEUMONIA"
    else:
        # Both fired: fall back to whichever probability is stronger.
        label = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
        label += " (fallback-high-confidence-overlap)"

    return pb, pv, label
|
|
|
|
|
|
def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
    """Full pipeline: UNet lung mask -> masked image -> two binary classifiers.

    Args:
        img: numpy image from Gradio — (H, W, 3) RGB or (H, W) grayscale.
        thresh_b: decision threshold for the bacterial classifier.
        thresh_v: decision threshold for the viral classifier.
        seg_thresh: probability threshold used to binarize the UNet mask.

    Returns:
        (label, bacterial_prob, viral_prob, masked_image PIL, overlay PIL)
    """
    # Fix: the old code forced mode 'RGB' on the raw array, which crashes on
    # (H, W) grayscale uploads even though the UI label advertises them.
    # Let PIL infer the mode, then convert to RGB.
    pil = Image.fromarray(np.asarray(img).astype(np.uint8))
    if pil.mode != "RGB":
        pil = pil.convert("RGB")

    masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
        pil, threshold=seg_thresh
    )

    pb, pv, pred_label = classify_masked_tensor(
        masked_tensor,
        thresh_b=thresh_b,
        thresh_v=thresh_v
    )

    # Grayscale visualization of the raw (pre-threshold) mask probabilities.
    # (The original also built an unused `mask_pil` here — removed.)
    mask_vis = (mask_np * 255).astype(np.uint8)

    # Resize the original to the mask's 300x300 resolution for the overlay.
    display_orig = pil.resize((300, 300))

    # Red overlay: mask probability drives both the red channel and alpha.
    # (Removed the dead np.zeros allocation that np.stack overwrote.)
    red_mask = np.stack(
        [mask_vis, np.zeros_like(mask_vis), np.zeros_like(mask_vis)], axis=2
    )
    red_mask = Image.fromarray(red_mask).convert("RGBA")

    alpha = (mask_np * 120).astype(np.uint8)
    red_mask.putalpha(Image.fromarray(alpha))

    blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)

    return (
        pred_label,
        float(pb),
        float(pv),
        masked_pil,
        blended
    )
|
|
|
|
# UI copy shown at the top of the Gradio page.
title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
desc = "Pipeline: UNet -> mask lungs -> two binary classifiers (Normal vs Bacterial, Normal vs Viral). " \
       "If both classifiers fire, the stronger probability is chosen (fallback). Thresholds adjustable."


# Gradio wiring: the three sliders map positionally onto inference_pipeline's
# thresh_b / thresh_v / seg_thresh parameters, and the five outputs onto its
# five return values in order.
# NOTE(review): `allow_flagging` is deprecated in Gradio 4+ in favor of
# `flagging_mode` — confirm against the installed Gradio version.
iface = gr.Interface(
    fn=inference_pipeline,
    inputs=[
        gr.Image(type="numpy", label="Upload chest X-ray (RGB or grayscale)"),
        gr.Slider(minimum=0.1, maximum=0.9, step=0.01, value=0.5, label="Bacterial threshold (thresh_b)"),
        gr.Slider(minimum=0.1, maximum=0.9, step=0.01, value=0.5, label="Viral threshold (thresh_v)"),
        gr.Slider(minimum=0.1, maximum=0.9, step=0.01, value=0.5, label="Segmentation mask threshold (seg_thresh)")
    ],
    outputs=[
        gr.Label(num_top_classes=1, label="Prediction"),
        gr.Number(label="Bacterial Probability"),
        gr.Number(label="Viral Probability"),
        gr.Image(type="pil", label="Masked Image (input × mask)"),
        gr.Image(type="pil", label="Segmentation Overlay (red mask)")
    ],
    title=title,
    description=desc,
    allow_flagging="never"
)


if __name__ == "__main__":
    iface.launch()
|
|