Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import torch
|
|
|
|
| 2 |
import torch.nn.functional as F
|
| 3 |
import timm
|
| 4 |
from torchvision import transforms
|
|
@@ -6,14 +7,15 @@ from PIL import Image
|
|
| 6 |
import numpy as np
|
| 7 |
import cv2
|
| 8 |
import gradio as gr
|
|
|
|
| 9 |
|
| 10 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
MODEL_PATH = "best_timm_pneumonia.pt"
|
|
|
|
| 12 |
IMG_SIZE = 384
|
| 13 |
|
| 14 |
-
|
| 15 |
# ---------------------------------------------------------
|
| 16 |
-
#
|
| 17 |
# ---------------------------------------------------------
|
| 18 |
val_tf = transforms.Compose([
|
| 19 |
transforms.Grayscale(num_output_channels=3),
|
|
@@ -23,9 +25,8 @@ val_tf = transforms.Compose([
|
|
| 23 |
transforms.Normalize([0.485]*3, [0.229]*3),
|
| 24 |
])
|
| 25 |
|
| 26 |
-
|
| 27 |
# ---------------------------------------------------------
|
| 28 |
-
#
|
| 29 |
# ---------------------------------------------------------
|
| 30 |
def build_model():
|
| 31 |
return timm.create_model("tf_efficientnet_b0_ns", pretrained=False, num_classes=1)
|
|
@@ -34,34 +35,106 @@ model = build_model().to(DEVICE)
|
|
| 34 |
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
|
| 35 |
model.eval()
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
# ---------------------------------------------------------
|
| 39 |
-
#
|
| 40 |
# ---------------------------------------------------------
|
| 41 |
def autocrop_chest(pil_img, th=10):
|
| 42 |
gray = np.array(pil_img.convert("L"))
|
| 43 |
mask = gray > th
|
| 44 |
-
if mask.sum() < 10:
|
| 45 |
-
return pil_img
|
| 46 |
ys, xs = np.where(mask)
|
| 47 |
return Image.fromarray(gray[ys.min():ys.max(), xs.min():xs.max()]).convert("RGB")
|
| 48 |
|
| 49 |
-
|
| 50 |
# ---------------------------------------------------------
|
| 51 |
-
#
|
| 52 |
# ---------------------------------------------------------
|
| 53 |
-
def
|
| 54 |
img = np.array(pil_img.convert("L"))
|
| 55 |
-
|
| 56 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 57 |
g = clahe.apply(img)
|
| 58 |
-
|
| 59 |
g = cv2.GaussianBlur(g, (5, 5), 0)
|
| 60 |
_, th = cv2.threshold(g, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
|
| 61 |
-
|
| 62 |
th = cv2.morphologyEx(th, cv2.MORPH_OPEN, np.ones((5, 5), np.uint8), iterations=1)
|
| 63 |
th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, np.ones((15, 15), np.uint8), iterations=1)
|
| 64 |
-
|
| 65 |
num, labels, stats, _ = cv2.connectedComponentsWithStats(th, connectivity=8)
|
| 66 |
if num > 1:
|
| 67 |
areas = [(i, stats[i, cv2.CC_STAT_AREA]) for i in range(1, num)]
|
|
@@ -70,79 +143,76 @@ def get_lung_mask(pil_img):
|
|
| 70 |
mask = np.where(np.isin(labels, list(keep)), 1.0, 0.0).astype(np.float32)
|
| 71 |
else:
|
| 72 |
mask = (th > 0).astype(np.float32)
|
| 73 |
-
|
| 74 |
mask = cv2.dilate(mask, np.ones((11, 11), np.uint8), iterations=1)
|
| 75 |
mask = cv2.GaussianBlur(mask, (21, 21), 0)
|
| 76 |
-
|
| 77 |
h, w = mask.shape
|
| 78 |
b = int(0.06 * w)
|
| 79 |
mask[:b, :] = mask[-b:, :] = mask[:, :b] = mask[:, -b:] = 0
|
|
|
|
|
|
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
|
| 82 |
return mask
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# ---------------------------------------------------------
|
| 86 |
-
#
|
| 87 |
# ---------------------------------------------------------
|
| 88 |
def get_gradcam(model, img_tensor):
|
| 89 |
target_layer = model.blocks[-1]
|
| 90 |
-
|
| 91 |
activ, grads = [], []
|
| 92 |
-
|
| 93 |
-
def
|
| 94 |
-
activ.append(o)
|
| 95 |
-
|
| 96 |
-
def bwd(m, gi, go):
|
| 97 |
-
grads.append(go[0])
|
| 98 |
-
|
| 99 |
h1 = target_layer.register_forward_hook(fwd)
|
| 100 |
h2 = target_layer.register_backward_hook(bwd)
|
| 101 |
-
|
| 102 |
logits = model(img_tensor)
|
| 103 |
score = torch.sigmoid(logits)[0]
|
| 104 |
-
model.zero_grad()
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
h1.remove()
|
| 108 |
-
h2.remove()
|
| 109 |
-
|
| 110 |
A = activ[0][0].detach().cpu().numpy()
|
| 111 |
G = grads[0].detach().cpu().numpy()
|
| 112 |
-
|
| 113 |
weights = G.mean(axis=(1, 2))
|
| 114 |
cam = np.zeros(A.shape[1:], dtype=np.float32)
|
| 115 |
-
for c, w in enumerate(weights):
|
| 116 |
-
cam += w * A[c]
|
| 117 |
-
|
| 118 |
cam = np.maximum(cam, 0)
|
| 119 |
-
cam -= cam.min()
|
| 120 |
-
cam /= (cam.max() + 1e-8)
|
| 121 |
-
|
| 122 |
cam = cv2.GaussianBlur(cam, (7, 7), 0)
|
| 123 |
return cam
|
| 124 |
|
| 125 |
-
|
| 126 |
# ---------------------------------------------------------
|
| 127 |
-
#
|
| 128 |
# ---------------------------------------------------------
|
| 129 |
def enhance_cam(cam, mode):
|
| 130 |
-
if mode == "Weak":
|
| 131 |
-
|
| 132 |
-
elif mode == "Strong":
|
| 133 |
-
return cam ** 0.6
|
| 134 |
return cam ** 0.9 # Medium
|
| 135 |
|
| 136 |
-
|
| 137 |
# ---------------------------------------------------------
|
| 138 |
-
#
|
| 139 |
# ---------------------------------------------------------
|
| 140 |
def predict(img, mode, intensity, threshold):
|
| 141 |
-
|
| 142 |
-
# Preprocess
|
| 143 |
cropped = autocrop_chest(img)
|
| 144 |
resized = cropped.resize((IMG_SIZE, IMG_SIZE))
|
| 145 |
-
|
| 146 |
x = val_tf(resized).unsqueeze(0).to(DEVICE)
|
| 147 |
|
| 148 |
with torch.no_grad():
|
|
@@ -154,42 +224,35 @@ def predict(img, mode, intensity, threshold):
|
|
| 154 |
# Grad-CAM
|
| 155 |
cam = get_gradcam(model, x)
|
| 156 |
cam = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
|
| 157 |
-
|
| 158 |
-
# Soft threshold
|
| 159 |
cam = np.where(cam >= threshold, cam, cam * 0.3)
|
| 160 |
cam = enhance_cam(cam, mode)
|
| 161 |
-
cam
|
| 162 |
|
| 163 |
-
# Heatmap
|
| 164 |
heat = cv2.applyColorMap((cam * 255).astype(np.uint8), cv2.COLORMAP_JET)[..., ::-1] / 255.0
|
| 165 |
-
|
| 166 |
-
# Convert both to float32 to avoid OpenCV errors
|
| 167 |
base = np.array(resized).astype(np.float32) / 255.0
|
| 168 |
base_f = base.astype(np.float32)
|
| 169 |
heat_f = heat.astype(np.float32)
|
| 170 |
|
| 171 |
-
#
|
| 172 |
overlay = cv2.addWeighted(base_f, 1 - intensity, heat_f, intensity, 0)
|
| 173 |
overlay = np.clip(overlay * 255, 0, 255).astype(np.uint8)
|
| 174 |
|
| 175 |
# Lung-masked CAM
|
| 176 |
lung_mask = get_lung_mask(resized)
|
| 177 |
-
masked = heat_f * lung_mask[..., None]
|
| 178 |
masked = np.clip(masked * 255, 0, 255).astype(np.uint8)
|
| 179 |
|
| 180 |
-
# Side-by-side results
|
| 181 |
combined = np.hstack([
|
| 182 |
(base * 255).astype(np.uint8),
|
| 183 |
overlay,
|
| 184 |
masked
|
| 185 |
])
|
| 186 |
-
|
| 187 |
-
text = f"Prediction: {label}\nP(PNEUMONIA)={prob_p:.3f}\nP(NORMAL)={prob_n:.3f}"
|
| 188 |
return text, combined
|
| 189 |
|
| 190 |
-
|
| 191 |
# ---------------------------------------------------------
|
| 192 |
-
#
|
| 193 |
# ---------------------------------------------------------
|
| 194 |
demo = gr.Interface(
|
| 195 |
fn=predict,
|
|
@@ -197,16 +260,14 @@ demo = gr.Interface(
|
|
| 197 |
gr.Image(type="pil", label="Upload Chest X-Ray"),
|
| 198 |
gr.Radio(["Weak", "Medium", "Strong"], value="Medium", label="Heatmap Mode"),
|
| 199 |
gr.Slider(0.1, 1.0, value=0.70, step=0.05, label="Heatmap Intensity"),
|
| 200 |
-
gr.Slider(0.0, 1.0, value=0.40, step=0.05, label="Heatmap Threshold")
|
| 201 |
],
|
| 202 |
outputs=[
|
| 203 |
-
gr.Text(label="Prediction"),
|
| 204 |
-
gr.Image(label="Original | Heatmap | Lung-Masked CAM")
|
| 205 |
],
|
| 206 |
-
title="Pneumonia
|
| 207 |
-
description=
|
| 208 |
-
"Upload an X-ray to view pneumonia predictions with heatmaps and lung-masked Grad-CAM."
|
| 209 |
-
)
|
| 210 |
)
|
| 211 |
|
| 212 |
demo.launch()
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
import timm
|
| 5 |
from torchvision import transforms
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
import cv2
|
| 9 |
import gradio as gr
|
| 10 |
+
import os
|
| 11 |
|
| 12 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 13 |
MODEL_PATH = "best_timm_pneumonia.pt"
|
| 14 |
+
SEG_PATHS = ["lung_unet_lite.pt", "lung_unet_lite.ts"] # accepted filenames
|
| 15 |
IMG_SIZE = 384
|
| 16 |
|
|
|
|
| 17 |
# ---------------------------------------------------------
|
| 18 |
+
# Preprocessing
|
| 19 |
# ---------------------------------------------------------
|
| 20 |
val_tf = transforms.Compose([
|
| 21 |
transforms.Grayscale(num_output_channels=3),
|
|
|
|
| 25 |
transforms.Normalize([0.485]*3, [0.229]*3),
|
| 26 |
])
|
| 27 |
|
|
|
|
| 28 |
# ---------------------------------------------------------
|
| 29 |
+
# Classifier (EfficientNet-B0)
|
| 30 |
# ---------------------------------------------------------
|
| 31 |
def build_model():
    """Instantiate the EfficientNet-B0 classifier with a single pneumonia logit."""
    return timm.create_model(
        "tf_efficientnet_b0_ns",
        pretrained=False,
        num_classes=1,
    )
|
|
|
|
| 35 |
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
|
| 36 |
model.eval()
|
| 37 |
|
| 38 |
+
# ---------------------------------------------------------
|
| 39 |
+
# Lightweight Lung U-Net (architecture only; ~5MB weights expected)
|
| 40 |
+
# ---------------------------------------------------------
|
| 41 |
+
class DSConv(nn.Module):
    """Depthwise-separable conv: depthwise 3x3 -> pointwise 1x1 -> BN -> ReLU."""

    def __init__(self, c_in, c_out):
        super().__init__()
        # Depthwise stage: one 3x3 filter per input channel (groups=c_in).
        self.dw = nn.Conv2d(c_in, c_in, 3, padding=1, groups=c_in, bias=False)
        # Pointwise stage mixes channels and projects to c_out.
        self.pw = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.dw(x)
        out = self.pw(out)
        out = self.bn(out)
        return self.act(out)
|
| 50 |
+
|
| 51 |
+
class Block(nn.Module):
    """Two stacked depthwise-separable convs at a fixed spatial size."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.c1 = DSConv(c_in, c_out)
        self.c2 = DSConv(c_out, c_out)

    def forward(self, x):
        return self.c2(self.c1(x))
|
| 57 |
+
|
| 58 |
+
class Up(nn.Module):
    """Decoder step: transpose-conv upsample, pad to the skip's size, concat, conv.

    The conv stage is Block(c_in, c_out), so the channel count after the
    concatenation (skip channels + c_in // 2) is expected to equal c_in.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.up = nn.ConvTranspose2d(c_in, c_in // 2, 2, stride=2)
        self.conv = Block(c_in, c_out)

    def forward(self, x, skip):
        x = self.up(x)
        # Zero-pad so the upsampled map matches the skip connection exactly
        # (handles odd input sizes where 2x upsampling misses by one pixel).
        dh = skip.shape[2] - x.shape[2]
        dw = skip.shape[3] - x.shape[3]
        x = F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        merged = torch.cat([skip, x], dim=1)
        return self.conv(merged)
|
| 70 |
+
|
| 71 |
+
class LungUNetLite(nn.Module):
    """Small U-Net (depthwise-separable convs) producing a 1-channel lung logit map.

    Encoder widens 1 -> 32 -> 64 -> 128 -> 256 with 2x max-pooling between
    stages; the decoder mirrors it with transpose-conv upsampling and skip
    connections, ending in a 1x1 conv to a single logit channel.
    """

    def __init__(self):
        super().__init__()
        self.enc1 = Block(1, 32)
        self.enc2 = Block(32, 64)
        self.enc3 = Block(64, 128)
        self.enc4 = Block(128, 256)
        self.pool = nn.MaxPool2d(2)
        self.bott = Block(256, 256)
        self.up3 = Up(256, 128)
        self.up2 = Up(128, 64)
        self.up1 = Up(64, 32)
        self.outc = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        e1 = self.enc1(x)               # 1/1, 32 ch
        e2 = self.enc2(self.pool(e1))   # 1/2, 64 ch
        e3 = self.enc3(self.pool(e2))   # 1/4, 128 ch
        e4 = self.enc4(self.pool(e3))   # 1/8, 256 ch
        # Fix: the original pooled once more and paired each Up with a skip
        # one level too deep (up3(b, e4) feeds 256+128=384 channels into a
        # Block declared for 256, so the grouped conv raises at runtime).
        # Keeping the bottleneck at 1/8 and using the matching-width skips
        # makes every concat line up with the declared module widths, and
        # leaves the state_dict layout unchanged.
        b = self.bott(e4)               # 1/8, 256 ch
        x = self.up3(b, e3)             # 1/4, 128 ch
        x = self.up2(x, e2)             # 1/2, 64 ch
        x = self.up1(x, e1)             # 1/1, 32 ch
        return self.outc(x)             # 1/1, 1 ch logits
|
| 96 |
+
|
| 97 |
+
# Try to load a tiny lung segmentation model if the user uploaded it.
# Falls through to the classical mask (and says so) when no file is found.
lung_net = None
seg_loaded_msg = ""
for p in SEG_PATHS:
    if not os.path.exists(p):
        continue
    try:
        if p.endswith(".ts"):  # TorchScript archive
            lung_net = torch.jit.load(p, map_location=DEVICE)
        else:  # plain state_dict for LungUNetLite
            lung_net = LungUNetLite().to(DEVICE)
            lung_net.load_state_dict(torch.load(p, map_location=DEVICE))
        lung_net.eval()
        seg_loaded_msg = f"✅ Lung model loaded: {p}"
        break
    except Exception as e:
        # Keep trying remaining candidate files; remember the last failure.
        seg_loaded_msg = f"⚠️ Failed to load {p}: {e}"

if lung_net is None and seg_loaded_msg == "":
    seg_loaded_msg = "ℹ️ Using classical (non-ML) lung mask; upload 'lung_unet_lite.pt' to upgrade."
|
| 116 |
|
| 117 |
# ---------------------------------------------------------
|
| 118 |
+
# Auto-crop borders
|
| 119 |
# ---------------------------------------------------------
|
| 120 |
def autocrop_chest(pil_img, th=10):
    """Crop away dark borders around the chest region.

    Pixels brighter than ``th`` on the grayscale image are treated as
    content; the crop is the tight bounding box of those pixels. Images
    with fewer than 10 bright pixels are returned unchanged.
    """
    gray = np.array(pil_img.convert("L"))
    mask = gray > th
    if mask.sum() < 10:  # nearly-black image: nothing reliable to crop against
        return pil_img
    ys, xs = np.where(mask)
    # Fix: +1 so the box includes the last bright row/column
    # (numpy slice stops are exclusive; the original dropped one row/col).
    box = gray[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    return Image.fromarray(box).convert("RGB")
|
| 126 |
|
|
|
|
| 127 |
# ---------------------------------------------------------
|
| 128 |
+
# Classical lung mask (fallback, no ML)
|
| 129 |
# ---------------------------------------------------------
|
| 130 |
+
def classical_lung_mask(pil_img):
|
| 131 |
img = np.array(pil_img.convert("L"))
|
|
|
|
| 132 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 133 |
g = clahe.apply(img)
|
|
|
|
| 134 |
g = cv2.GaussianBlur(g, (5, 5), 0)
|
| 135 |
_, th = cv2.threshold(g, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
|
|
|
|
| 136 |
th = cv2.morphologyEx(th, cv2.MORPH_OPEN, np.ones((5, 5), np.uint8), iterations=1)
|
| 137 |
th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, np.ones((15, 15), np.uint8), iterations=1)
|
|
|
|
| 138 |
num, labels, stats, _ = cv2.connectedComponentsWithStats(th, connectivity=8)
|
| 139 |
if num > 1:
|
| 140 |
areas = [(i, stats[i, cv2.CC_STAT_AREA]) for i in range(1, num)]
|
|
|
|
| 143 |
mask = np.where(np.isin(labels, list(keep)), 1.0, 0.0).astype(np.float32)
|
| 144 |
else:
|
| 145 |
mask = (th > 0).astype(np.float32)
|
|
|
|
| 146 |
mask = cv2.dilate(mask, np.ones((11, 11), np.uint8), iterations=1)
|
| 147 |
mask = cv2.GaussianBlur(mask, (21, 21), 0)
|
|
|
|
| 148 |
h, w = mask.shape
|
| 149 |
b = int(0.06 * w)
|
| 150 |
mask[:b, :] = mask[-b:, :] = mask[:, :b] = mask[:, -b:] = 0
|
| 151 |
+
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
|
| 152 |
+
return mask
|
| 153 |
|
| 154 |
+
# ---------------------------------------------------------
|
| 155 |
+
# ML lung mask (if weights provided)
|
| 156 |
+
# ---------------------------------------------------------
|
| 157 |
+
def ml_lung_mask(pil_img):
    """Predict a soft lung-probability mask with the loaded U-Net.

    The image is downscaled to 256x256 for a fast forward pass; the
    predicted probability map is resized back to IMG_SIZE, lightly
    blurred, and min-max normalized to [0, 1].
    """
    gray = np.array(pil_img.convert("L"))
    h, w = gray.shape
    side = 256
    small = cv2.resize(gray, (side, side))
    batch = torch.from_numpy(small[None, None].astype(np.float32) / 255.0).to(DEVICE)
    with torch.no_grad():
        prob = torch.sigmoid(lung_net(batch))[0, 0].cpu().numpy()
    mask = cv2.resize(prob, (IMG_SIZE, IMG_SIZE))
    # Soft refine: smooth the edges before normalization.
    mask = cv2.GaussianBlur(mask, (15, 15), 0)
    lo = mask.min()
    span = mask.max() - lo
    return (mask - lo) / (span + 1e-8)
|
| 172 |
|
| 173 |
+
def get_lung_mask(pil_img):
    """Return a float lung mask, preferring the ML model when one is loaded."""
    return classical_lung_mask(pil_img) if lung_net is None else ml_lung_mask(pil_img)
|
| 177 |
|
| 178 |
# ---------------------------------------------------------
|
| 179 |
+
# Soft Grad-CAM (correct layer)
|
| 180 |
# ---------------------------------------------------------
|
| 181 |
def get_gradcam(model, img_tensor):
    """Compute a Grad-CAM heatmap for the (single) pneumonia logit.

    Hooks the last EfficientNet stage to capture its activations and the
    gradient of the sigmoid score w.r.t. them, weights each activation
    channel by its spatially-averaged gradient, and returns a blurred,
    [0, 1]-normalized CAM at the feature-map resolution.
    """
    target_layer = model.blocks[-1]
    activ, grads = [], []

    def fwd(m, i, o):
        activ.append(o)

    def bwd(m, gi, go):
        grads.append(go[0])

    h1 = target_layer.register_forward_hook(fwd)
    # Fix: use register_full_backward_hook — the plain register_backward_hook
    # is deprecated and can report wrong grad_output on multi-op modules.
    h2 = target_layer.register_full_backward_hook(bwd)

    logits = model(img_tensor)
    score = torch.sigmoid(logits)[0]
    model.zero_grad()
    score.backward()

    h1.remove()
    h2.remove()

    A = activ[0][0].detach().cpu().numpy()   # (C, H, W) activations
    # Fix: index out the batch dim ([0]) — the original kept shape
    # (1, C, H, W), so mean(axis=(1, 2)) averaged the wrong axes and the
    # channel loop below paired width indices with activation channels.
    G = grads[0][0].detach().cpu().numpy()   # (C, H, W) gradients

    weights = G.mean(axis=(1, 2))            # one importance weight per channel
    cam = np.zeros(A.shape[1:], dtype=np.float32)
    for c, w in enumerate(weights):
        cam += w * A[c]

    cam = np.maximum(cam, 0)                 # keep positive evidence only
    cam -= cam.min()
    cam /= (cam.max() + 1e-8)

    cam = cv2.GaussianBlur(cam, (7, 7), 0)   # soften blocky feature-map cells
    return cam
|
| 201 |
|
|
|
|
| 202 |
# ---------------------------------------------------------
|
| 203 |
+
# Mode enhancement
|
| 204 |
# ---------------------------------------------------------
|
| 205 |
def enhance_cam(cam, mode):
    """Apply a gamma curve to the CAM: 'Weak' suppresses, 'Strong' amplifies.

    Any other mode (the UI's 'Medium') uses a near-identity exponent.
    """
    gamma = {"Weak": 1.4, "Strong": 0.6}.get(mode, 0.9)
    return cam ** gamma
|
| 209 |
|
|
|
|
| 210 |
# ---------------------------------------------------------
|
| 211 |
+
# Prediction + Visualization
|
| 212 |
# ---------------------------------------------------------
|
| 213 |
def predict(img, mode, intensity, threshold):
|
|
|
|
|
|
|
| 214 |
cropped = autocrop_chest(img)
|
| 215 |
resized = cropped.resize((IMG_SIZE, IMG_SIZE))
|
|
|
|
| 216 |
x = val_tf(resized).unsqueeze(0).to(DEVICE)
|
| 217 |
|
| 218 |
with torch.no_grad():
|
|
|
|
| 224 |
# Grad-CAM
|
| 225 |
cam = get_gradcam(model, x)
|
| 226 |
cam = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
|
|
|
|
|
|
|
| 227 |
cam = np.where(cam >= threshold, cam, cam * 0.3)
|
| 228 |
cam = enhance_cam(cam, mode)
|
| 229 |
+
cam = cam / (cam.max() + 1e-8)
|
| 230 |
|
| 231 |
+
# Heatmap + base (both float32 [0,1])
|
| 232 |
heat = cv2.applyColorMap((cam * 255).astype(np.uint8), cv2.COLORMAP_JET)[..., ::-1] / 255.0
|
|
|
|
|
|
|
| 233 |
base = np.array(resized).astype(np.float32) / 255.0
|
| 234 |
base_f = base.astype(np.float32)
|
| 235 |
heat_f = heat.astype(np.float32)
|
| 236 |
|
| 237 |
+
# Overlay
|
| 238 |
overlay = cv2.addWeighted(base_f, 1 - intensity, heat_f, intensity, 0)
|
| 239 |
overlay = np.clip(overlay * 255, 0, 255).astype(np.uint8)
|
| 240 |
|
| 241 |
# Lung-masked CAM
|
| 242 |
lung_mask = get_lung_mask(resized)
|
| 243 |
+
masked = (heat_f * lung_mask[..., None])
|
| 244 |
masked = np.clip(masked * 255, 0, 255).astype(np.uint8)
|
| 245 |
|
|
|
|
| 246 |
combined = np.hstack([
|
| 247 |
(base * 255).astype(np.uint8),
|
| 248 |
overlay,
|
| 249 |
masked
|
| 250 |
])
|
| 251 |
+
text = f"{seg_loaded_msg}\nPrediction: {label}\nP(PNEUMONIA)={prob_p:.3f} | P(NORMAL)={prob_n:.3f}"
|
|
|
|
| 252 |
return text, combined
|
| 253 |
|
|
|
|
| 254 |
# ---------------------------------------------------------
|
| 255 |
+
# UI
|
| 256 |
# ---------------------------------------------------------
|
| 257 |
demo = gr.Interface(
|
| 258 |
fn=predict,
|
|
|
|
| 260 |
gr.Image(type="pil", label="Upload Chest X-Ray"),
|
| 261 |
gr.Radio(["Weak", "Medium", "Strong"], value="Medium", label="Heatmap Mode"),
|
| 262 |
gr.Slider(0.1, 1.0, value=0.70, step=0.05, label="Heatmap Intensity"),
|
| 263 |
+
gr.Slider(0.0, 1.0, value=0.40, step=0.05, label="Heatmap Threshold"),
|
| 264 |
],
|
| 265 |
outputs=[
|
| 266 |
+
gr.Text(label="Status + Prediction"),
|
| 267 |
+
gr.Image(label="Original | Heatmap | Lung-Masked CAM"),
|
| 268 |
],
|
| 269 |
+
title="Pneumonia Detector (Soft Grad-CAM + Lung Mask, Lite)",
|
| 270 |
+
description="Upload an X-ray. For lung-only CAM, upload 'lung_unet_lite.pt' or 'lung_unet_lite.ts' to this Space."
|
|
|
|
|
|
|
| 271 |
)
|
| 272 |
|
| 273 |
demo.launch()
|