phoenix6238 committed on
Commit
55d1d28
·
verified ·
1 Parent(s): 035b08a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -70
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  import torch.nn.functional as F
3
  import timm
4
  from torchvision import transforms
@@ -6,14 +7,15 @@ from PIL import Image
6
  import numpy as np
7
  import cv2
8
  import gradio as gr
 
9
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  MODEL_PATH = "best_timm_pneumonia.pt"
 
12
  IMG_SIZE = 384
13
 
14
-
15
  # ---------------------------------------------------------
16
- # Preprocessing
17
  # ---------------------------------------------------------
18
  val_tf = transforms.Compose([
19
  transforms.Grayscale(num_output_channels=3),
@@ -23,9 +25,8 @@ val_tf = transforms.Compose([
23
  transforms.Normalize([0.485]*3, [0.229]*3),
24
  ])
25
 
26
-
27
  # ---------------------------------------------------------
28
- # Load EfficientNet Model
29
  # ---------------------------------------------------------
30
  def build_model():
31
  return timm.create_model("tf_efficientnet_b0_ns", pretrained=False, num_classes=1)
@@ -34,34 +35,106 @@ model = build_model().to(DEVICE)
34
  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
35
  model.eval()
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # ---------------------------------------------------------
39
- # Auto-Crop Dark Borders
40
  # ---------------------------------------------------------
41
  def autocrop_chest(pil_img, th=10):
42
  gray = np.array(pil_img.convert("L"))
43
  mask = gray > th
44
- if mask.sum() < 10:
45
- return pil_img
46
  ys, xs = np.where(mask)
47
  return Image.fromarray(gray[ys.min():ys.max(), xs.min():xs.max()]).convert("RGB")
48
 
49
-
50
  # ---------------------------------------------------------
51
- # Classical Lung Mask (no models required)
52
  # ---------------------------------------------------------
53
- def get_lung_mask(pil_img):
54
  img = np.array(pil_img.convert("L"))
55
-
56
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
57
  g = clahe.apply(img)
58
-
59
  g = cv2.GaussianBlur(g, (5, 5), 0)
60
  _, th = cv2.threshold(g, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
61
-
62
  th = cv2.morphologyEx(th, cv2.MORPH_OPEN, np.ones((5, 5), np.uint8), iterations=1)
63
  th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, np.ones((15, 15), np.uint8), iterations=1)
64
-
65
  num, labels, stats, _ = cv2.connectedComponentsWithStats(th, connectivity=8)
66
  if num > 1:
67
  areas = [(i, stats[i, cv2.CC_STAT_AREA]) for i in range(1, num)]
@@ -70,79 +143,76 @@ def get_lung_mask(pil_img):
70
  mask = np.where(np.isin(labels, list(keep)), 1.0, 0.0).astype(np.float32)
71
  else:
72
  mask = (th > 0).astype(np.float32)
73
-
74
  mask = cv2.dilate(mask, np.ones((11, 11), np.uint8), iterations=1)
75
  mask = cv2.GaussianBlur(mask, (21, 21), 0)
76
-
77
  h, w = mask.shape
78
  b = int(0.06 * w)
79
  mask[:b, :] = mask[-b:, :] = mask[:, :b] = mask[:, -b:] = 0
 
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
82
  return mask
83
 
 
 
 
 
84
 
85
  # ---------------------------------------------------------
86
- # Soft Grad-CAM
87
  # ---------------------------------------------------------
88
  def get_gradcam(model, img_tensor):
89
  target_layer = model.blocks[-1]
90
-
91
  activ, grads = [], []
92
-
93
- def fwd(m, i, o):
94
- activ.append(o)
95
-
96
- def bwd(m, gi, go):
97
- grads.append(go[0])
98
-
99
  h1 = target_layer.register_forward_hook(fwd)
100
  h2 = target_layer.register_backward_hook(bwd)
101
-
102
  logits = model(img_tensor)
103
  score = torch.sigmoid(logits)[0]
104
- model.zero_grad()
105
- score.backward()
106
-
107
- h1.remove()
108
- h2.remove()
109
-
110
  A = activ[0][0].detach().cpu().numpy()
111
  G = grads[0].detach().cpu().numpy()
112
-
113
  weights = G.mean(axis=(1, 2))
114
  cam = np.zeros(A.shape[1:], dtype=np.float32)
115
- for c, w in enumerate(weights):
116
- cam += w * A[c]
117
-
118
  cam = np.maximum(cam, 0)
119
- cam -= cam.min()
120
- cam /= (cam.max() + 1e-8)
121
-
122
  cam = cv2.GaussianBlur(cam, (7, 7), 0)
123
  return cam
124
 
125
-
126
  # ---------------------------------------------------------
127
- # Visualization Modes
128
  # ---------------------------------------------------------
129
  def enhance_cam(cam, mode):
130
- if mode == "Weak":
131
- return cam ** 1.4
132
- elif mode == "Strong":
133
- return cam ** 0.6
134
  return cam ** 0.9 # Medium
135
 
136
-
137
  # ---------------------------------------------------------
138
- # Prediction + Visualization
139
  # ---------------------------------------------------------
140
  def predict(img, mode, intensity, threshold):
141
-
142
- # Preprocess
143
  cropped = autocrop_chest(img)
144
  resized = cropped.resize((IMG_SIZE, IMG_SIZE))
145
-
146
  x = val_tf(resized).unsqueeze(0).to(DEVICE)
147
 
148
  with torch.no_grad():
@@ -154,42 +224,35 @@ def predict(img, mode, intensity, threshold):
154
  # Grad-CAM
155
  cam = get_gradcam(model, x)
156
  cam = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
157
-
158
- # Soft threshold
159
  cam = np.where(cam >= threshold, cam, cam * 0.3)
160
  cam = enhance_cam(cam, mode)
161
- cam /= (cam.max() + 1e-8)
162
 
163
- # Heatmap
164
  heat = cv2.applyColorMap((cam * 255).astype(np.uint8), cv2.COLORMAP_JET)[..., ::-1] / 255.0
165
-
166
- # Convert both to float32 to avoid OpenCV errors
167
  base = np.array(resized).astype(np.float32) / 255.0
168
  base_f = base.astype(np.float32)
169
  heat_f = heat.astype(np.float32)
170
 
171
- # Blended overlay
172
  overlay = cv2.addWeighted(base_f, 1 - intensity, heat_f, intensity, 0)
173
  overlay = np.clip(overlay * 255, 0, 255).astype(np.uint8)
174
 
175
  # Lung-masked CAM
176
  lung_mask = get_lung_mask(resized)
177
- masked = heat_f * lung_mask[..., None]
178
  masked = np.clip(masked * 255, 0, 255).astype(np.uint8)
179
 
180
- # Side-by-side results
181
  combined = np.hstack([
182
  (base * 255).astype(np.uint8),
183
  overlay,
184
  masked
185
  ])
186
-
187
- text = f"Prediction: {label}\nP(PNEUMONIA)={prob_p:.3f}\nP(NORMAL)={prob_n:.3f}"
188
  return text, combined
189
 
190
-
191
  # ---------------------------------------------------------
192
- # ✅ Gradio UI
193
  # ---------------------------------------------------------
194
  demo = gr.Interface(
195
  fn=predict,
@@ -197,16 +260,14 @@ demo = gr.Interface(
197
  gr.Image(type="pil", label="Upload Chest X-Ray"),
198
  gr.Radio(["Weak", "Medium", "Strong"], value="Medium", label="Heatmap Mode"),
199
  gr.Slider(0.1, 1.0, value=0.70, step=0.05, label="Heatmap Intensity"),
200
- gr.Slider(0.0, 1.0, value=0.40, step=0.05, label="Heatmap Threshold")
201
  ],
202
  outputs=[
203
- gr.Text(label="Prediction"),
204
- gr.Image(label="Original | Heatmap | Lung-Masked CAM")
205
  ],
206
- title="Pneumonia Detection (Soft Grad-CAM + Lung Mask)",
207
- description=(
208
- "Upload an X-ray to view pneumonia predictions with heatmaps and lung-masked Grad-CAM."
209
- )
210
  )
211
 
212
  demo.launch()
 
1
  import torch
2
+ import torch.nn as nn
3
  import torch.nn.functional as F
4
  import timm
5
  from torchvision import transforms
 
7
  import numpy as np
8
  import cv2
9
  import gradio as gr
10
+ import os
11
 
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
  MODEL_PATH = "best_timm_pneumonia.pt"
14
+ SEG_PATHS = ["lung_unet_lite.pt", "lung_unet_lite.ts"] # accepted filenames
15
  IMG_SIZE = 384
16
 
 
17
  # ---------------------------------------------------------
18
+ # Preprocessing
19
  # ---------------------------------------------------------
20
  val_tf = transforms.Compose([
21
  transforms.Grayscale(num_output_channels=3),
 
25
  transforms.Normalize([0.485]*3, [0.229]*3),
26
  ])
27
 
 
28
  # ---------------------------------------------------------
29
+ # Classifier (EfficientNet-B0)
30
  # ---------------------------------------------------------
31
def build_model():
    """Build the EfficientNet-B0 (NoisyStudent) classifier with a single-logit head."""
    backbone = timm.create_model(
        "tf_efficientnet_b0_ns",
        pretrained=False,   # weights come from MODEL_PATH, not from timm's hub
        num_classes=1,      # single logit: sigmoid gives P(pneumonia)
    )
    return backbone
 
35
  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
36
  model.eval()
37
 
38
+ # ---------------------------------------------------------
39
+ # Lightweight Lung U-Net (architecture only; ~5MB weights expected)
40
+ # ---------------------------------------------------------
41
class DSConv(nn.Module):
    """Depthwise-separable convolution: depthwise 3x3 -> pointwise 1x1 -> BN -> ReLU."""

    def __init__(self, c_in, c_out):
        super().__init__()
        # Depthwise: one 3x3 filter per input channel; bias is redundant before BN.
        self.dw = nn.Conv2d(c_in, c_in, 3, padding=1, groups=c_in, bias=False)
        # Pointwise: 1x1 conv mixes channels and maps c_in -> c_out.
        self.pw = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.pw(self.dw(x))))


class Block(nn.Module):
    """Two stacked depthwise-separable convolutions (the basic U-Net unit)."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.c1 = DSConv(c_in, c_out)
        self.c2 = DSConv(c_out, c_out)

    def forward(self, x):
        return self.c2(self.c1(x))


class Up(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, concat with skip, conv block.

    Note: the skip tensor must carry c_in - c_in//2 channels so the
    concatenation feeds exactly c_in channels into the conv block.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.up = nn.ConvTranspose2d(c_in, c_in // 2, 2, stride=2)
        self.conv = Block(c_in, c_out)

    def forward(self, x, skip):
        x = self.up(x)
        # Zero-pad when odd input sizes make the upsampled map smaller than the skip.
        dh = skip.shape[2] - x.shape[2]
        dw = skip.shape[3] - x.shape[3]
        x = F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return self.conv(torch.cat([skip, x], dim=1))


class LungUNetLite(nn.Module):
    """Small U-Net built from depthwise-separable blocks for lung segmentation.

    Input:  (B, 1, H, W) grayscale tensor scaled to [0, 1].
    Output: (B, 1, H, W) raw logits (apply sigmoid for a probability mask).
    """

    def __init__(self):
        super().__init__()
        self.enc1 = Block(1, 32)
        self.enc2 = Block(32, 64)
        self.enc3 = Block(64, 128)
        self.enc4 = Block(128, 256)
        self.pool = nn.MaxPool2d(2)
        self.bott = Block(256, 256)
        self.up3 = Up(256, 128)
        self.up2 = Up(128, 64)
        self.up1 = Up(64, 32)
        self.outc = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)                 # 1/1, 32 ch
        e2 = self.enc2(self.pool(e1))     # 1/2, 64 ch
        e3 = self.enc3(self.pool(e2))     # 1/4, 128 ch
        e4 = self.enc4(self.pool(e3))     # 1/8, 256 ch
        # Bottleneck stays at 1/8. Up(c_in, .) needs its skip to carry
        # c_in//2 channels, so the skips must be e3/e2/e1 (128/64/32 ch).
        # The original forward pooled once more and used e4 as the first
        # skip, which fed 384 channels into a 256-channel Block and crashed
        # at runtime. __init__ is unchanged, so state_dict layout is intact.
        b = self.bott(e4)                 # 1/8, 256 ch
        # Decoder back to full resolution
        d3 = self.up3(b, e3)              # 1/4, 128 ch
        d2 = self.up2(d3, e2)             # 1/2, 64 ch
        d1 = self.up1(d2, e1)             # 1/1, 32 ch
        return self.outc(d1)              # 1/1, 1 ch (logits)
96
+
97
# Try to load a tiny lung segmentation model if the user uploaded it.
# Leaves the globals `lung_net` (model or None) and `seg_loaded_msg` (status
# string shown to the user) defined in every case.
lung_net = None
seg_loaded_msg = ""
for p in SEG_PATHS:
    if not os.path.exists(p):
        continue
    try:
        if p.endswith(".ts"):  # TorchScript archive
            lung_net = torch.jit.load(p, map_location=DEVICE)
        else:  # plain state_dict for LungUNetLite
            lung_net = LungUNetLite().to(DEVICE)
            lung_net.load_state_dict(torch.load(p, map_location=DEVICE))
        lung_net.eval()
        seg_loaded_msg = f"✅ Lung model loaded: {p}"
        break
    except Exception as e:
        # Reset so a half-constructed, randomly initialized network is never
        # used as if it were trained (the original kept the broken instance,
        # and get_lung_mask would then silently prefer it).
        lung_net = None
        seg_loaded_msg = f"⚠️ Failed to load {p}: {e}"

if lung_net is None and seg_loaded_msg == "":
    seg_loaded_msg = "ℹ️ Using classical (non-ML) lung mask; upload 'lung_unet_lite.pt' to upgrade."
116
 
117
  # ---------------------------------------------------------
118
+ # Auto-crop borders
119
  # ---------------------------------------------------------
120
def autocrop_chest(pil_img, th=10):
    """Crop away near-black borders around the chest region.

    Args:
        pil_img: input PIL image (any mode; converted to grayscale internally).
        th: intensity threshold (0-255); pixels brighter than this count as content.

    Returns:
        An RGB PIL image cropped to the bounding box of pixels brighter than
        ``th``. The input is returned unchanged when fewer than 10 pixels pass
        the threshold (blank or near-black image).
    """
    gray = np.array(pil_img.convert("L"))
    mask = gray > th
    if mask.sum() < 10:
        # Almost nothing above threshold -> cropping would be meaningless.
        return pil_img
    ys, xs = np.where(mask)
    # +1 because numpy slices are half-open; the original upper bounds
    # silently dropped the last foreground row and column.
    return Image.fromarray(
        gray[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    ).convert("RGB")
126
 
 
127
  # ---------------------------------------------------------
128
+ # Classical lung mask (fallback, no ML)
129
  # ---------------------------------------------------------
130
+ def classical_lung_mask(pil_img):
131
  img = np.array(pil_img.convert("L"))
 
132
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
133
  g = clahe.apply(img)
 
134
  g = cv2.GaussianBlur(g, (5, 5), 0)
135
  _, th = cv2.threshold(g, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
 
136
  th = cv2.morphologyEx(th, cv2.MORPH_OPEN, np.ones((5, 5), np.uint8), iterations=1)
137
  th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, np.ones((15, 15), np.uint8), iterations=1)
 
138
  num, labels, stats, _ = cv2.connectedComponentsWithStats(th, connectivity=8)
139
  if num > 1:
140
  areas = [(i, stats[i, cv2.CC_STAT_AREA]) for i in range(1, num)]
 
143
  mask = np.where(np.isin(labels, list(keep)), 1.0, 0.0).astype(np.float32)
144
  else:
145
  mask = (th > 0).astype(np.float32)
 
146
  mask = cv2.dilate(mask, np.ones((11, 11), np.uint8), iterations=1)
147
  mask = cv2.GaussianBlur(mask, (21, 21), 0)
 
148
  h, w = mask.shape
149
  b = int(0.06 * w)
150
  mask[:b, :] = mask[-b:, :] = mask[:, :b] = mask[:, -b:] = 0
151
+ mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
152
+ return mask
153
 
154
+ # ---------------------------------------------------------
155
+ # ML lung mask (if weights provided)
156
+ # ---------------------------------------------------------
157
def ml_lung_mask(pil_img):
    """Predict a soft lung mask with the loaded segmentation network.

    The image is run through ``lung_net`` at 256x256 for speed; the resulting
    probability map is resized to IMG_SIZE x IMG_SIZE (the size the Grad-CAM
    overlay uses — NOT the original image size), blurred, and min-max
    normalized to [0, 1].
    """
    img = np.array(pil_img.convert("L"))
    side = 256  # inference resolution; assumes lung_net accepts (1, 1, 256, 256) input
    imr = cv2.resize(img, (side, side))
    tens = torch.from_numpy(imr[None, None].astype(np.float32) / 255.0).to(DEVICE)
    with torch.no_grad():
        logits = lung_net(tens)
        prob = torch.sigmoid(logits)[0, 0].cpu().numpy()
    mask = cv2.resize(prob, (IMG_SIZE, IMG_SIZE))
    # Soft refine: smooth, then stretch to the full [0, 1] range.
    mask = cv2.GaussianBlur(mask, (15, 15), 0)
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
    return mask
172
 
173
def get_lung_mask(pil_img):
    """Return a soft lung mask in [0, 1], preferring the ML model when one is loaded."""
    use_ml = lung_net is not None
    return ml_lung_mask(pil_img) if use_ml else classical_lung_mask(pil_img)
177
 
178
  # ---------------------------------------------------------
179
+ # Soft Grad-CAM (correct layer)
180
  # ---------------------------------------------------------
181
  def get_gradcam(model, img_tensor):
182
  target_layer = model.blocks[-1]
 
183
  activ, grads = [], []
184
+ def fwd(m, i, o): activ.append(o)
185
+ def bwd(m, gi, go): grads.append(go[0])
 
 
 
 
 
186
  h1 = target_layer.register_forward_hook(fwd)
187
  h2 = target_layer.register_backward_hook(bwd)
 
188
  logits = model(img_tensor)
189
  score = torch.sigmoid(logits)[0]
190
+ model.zero_grad(); score.backward()
191
+ h1.remove(); h2.remove()
 
 
 
 
192
  A = activ[0][0].detach().cpu().numpy()
193
  G = grads[0].detach().cpu().numpy()
 
194
  weights = G.mean(axis=(1, 2))
195
  cam = np.zeros(A.shape[1:], dtype=np.float32)
196
+ for c, w in enumerate(weights): cam += w * A[c]
 
 
197
  cam = np.maximum(cam, 0)
198
+ cam -= cam.min(); cam /= (cam.max() + 1e-8)
 
 
199
  cam = cv2.GaussianBlur(cam, (7, 7), 0)
200
  return cam
201
 
 
202
  # ---------------------------------------------------------
203
+ # Mode enhancement
204
  # ---------------------------------------------------------
205
def enhance_cam(cam, mode):
    """Apply a gamma curve to the CAM: exponent > 1 suppresses, < 1 amplifies."""
    gamma_by_mode = {"Weak": 1.4, "Strong": 0.6}
    exponent = gamma_by_mode.get(mode, 0.9)  # any other mode (e.g. "Medium") -> 0.9
    return cam ** exponent
209
 
 
210
  # ---------------------------------------------------------
211
+ # Prediction + Visualization
212
  # ---------------------------------------------------------
213
  def predict(img, mode, intensity, threshold):
 
 
214
  cropped = autocrop_chest(img)
215
  resized = cropped.resize((IMG_SIZE, IMG_SIZE))
 
216
  x = val_tf(resized).unsqueeze(0).to(DEVICE)
217
 
218
  with torch.no_grad():
 
224
  # Grad-CAM
225
  cam = get_gradcam(model, x)
226
  cam = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
 
 
227
  cam = np.where(cam >= threshold, cam, cam * 0.3)
228
  cam = enhance_cam(cam, mode)
229
+ cam = cam / (cam.max() + 1e-8)
230
 
231
+ # Heatmap + base (both float32 [0,1])
232
  heat = cv2.applyColorMap((cam * 255).astype(np.uint8), cv2.COLORMAP_JET)[..., ::-1] / 255.0
 
 
233
  base = np.array(resized).astype(np.float32) / 255.0
234
  base_f = base.astype(np.float32)
235
  heat_f = heat.astype(np.float32)
236
 
237
+ # Overlay
238
  overlay = cv2.addWeighted(base_f, 1 - intensity, heat_f, intensity, 0)
239
  overlay = np.clip(overlay * 255, 0, 255).astype(np.uint8)
240
 
241
  # Lung-masked CAM
242
  lung_mask = get_lung_mask(resized)
243
+ masked = (heat_f * lung_mask[..., None])
244
  masked = np.clip(masked * 255, 0, 255).astype(np.uint8)
245
 
 
246
  combined = np.hstack([
247
  (base * 255).astype(np.uint8),
248
  overlay,
249
  masked
250
  ])
251
+ text = f"{seg_loaded_msg}\nPrediction: {label}\nP(PNEUMONIA)={prob_p:.3f} | P(NORMAL)={prob_n:.3f}"
 
252
  return text, combined
253
 
 
254
  # ---------------------------------------------------------
255
+ # UI
256
  # ---------------------------------------------------------
257
  demo = gr.Interface(
258
  fn=predict,
 
260
  gr.Image(type="pil", label="Upload Chest X-Ray"),
261
  gr.Radio(["Weak", "Medium", "Strong"], value="Medium", label="Heatmap Mode"),
262
  gr.Slider(0.1, 1.0, value=0.70, step=0.05, label="Heatmap Intensity"),
263
+ gr.Slider(0.0, 1.0, value=0.40, step=0.05, label="Heatmap Threshold"),
264
  ],
265
  outputs=[
266
+ gr.Text(label="Status + Prediction"),
267
+ gr.Image(label="Original | Heatmap | Lung-Masked CAM"),
268
  ],
269
+ title="Pneumonia Detector (Soft Grad-CAM + Lung Mask, Lite)",
270
+ description="Upload an X-ray. For lung-only CAM, upload 'lung_unet_lite.pt' or 'lung_unet_lite.ts' to this Space."
 
 
271
  )
272
 
273
  demo.launch()