Clocksp commited on
Commit
872bf8e
·
verified ·
1 Parent(s): 681f493

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -17
app.py CHANGED
@@ -171,33 +171,48 @@ def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
171
 
172
  def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
173
  """
174
- Returns: label, bacterial_prob, viral_prob, masked_image (PIL), mask (PIL)
175
  """
 
176
  pil = Image.fromarray(img.astype('uint8'), 'RGB')
177
- masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(pil, threshold=seg_thresh)
178
- pb, pv, label = classify_masked_tensor(masked_tensor, thresh_b=thresh_b, thresh_v=thresh_v)
 
 
 
 
 
 
 
 
 
 
179
  mask_vis = (mask_np * 255).astype(np.uint8)
180
  mask_pil = Image.fromarray(mask_vis).convert("L")
181
- display_orig = pil.resize((300,300))
182
- overlay = Image.new("RGBA", display_orig.size)
183
- overlay.paste(display_orig.convert("RGBA"))
184
- # red mask with alpha
185
- red_mask = Image.fromarray(np.zeros((300,300,3), dtype=np.uint8))
186
- red_mask = Image.fromarray(np.stack([mask_vis, np.zeros_like(mask_vis), np.zeros_like(mask_vis)], axis=2))
187
- red_mask = red_mask.convert("RGBA")
188
- # apply alpha where mask is 1
 
189
  alpha = (mask_np * 120).astype(np.uint8)
190
  red_mask.putalpha(Image.fromarray(alpha))
 
191
  blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)
192
- # return values
 
193
  return (
194
- pred_label,
195
- float(prob_bact),
196
- float(prob_viral),
197
- masked_image, # PIL image
198
- overlay_image # PIL image
199
  )
200
 
 
201
  title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
202
  desc = "Pipeline: UNet -> mask lungs -> two binary classifiers (Normal vs Bacterial, Normal vs Viral). " \
203
  "If both classifiers fire, the stronger probability is chosen (fallback). Thresholds adjustable."
 
171
 
172
  def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
173
  """
174
+ Returns: label, bacterial_prob, viral_prob, masked_image (PIL), mask_overlay (PIL)
175
  """
176
+
177
  pil = Image.fromarray(img.astype('uint8'), 'RGB')
178
+
179
+ masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
180
+ pil, threshold=seg_thresh
181
+ )
182
+
183
+ pb, pv, pred_label = classify_masked_tensor(
184
+ masked_tensor,
185
+ thresh_b=thresh_b,
186
+ thresh_v=thresh_v
187
+ )
188
+
189
+ # Convert mask to PIL
190
  mask_vis = (mask_np * 255).astype(np.uint8)
191
  mask_pil = Image.fromarray(mask_vis).convert("L")
192
+
193
+ # Resize original for overlay
194
+ display_orig = pil.resize((300, 300))
195
+
196
+ # Create red mask overlay
197
+ red_mask = np.zeros((300, 300, 3), dtype=np.uint8)
198
+ red_mask = np.stack([mask_vis, np.zeros_like(mask_vis), np.zeros_like(mask_vis)], axis=2)
199
+ red_mask = Image.fromarray(red_mask).convert("RGBA")
200
+
201
  alpha = (mask_np * 120).astype(np.uint8)
202
  red_mask.putalpha(Image.fromarray(alpha))
203
+
204
  blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)
205
+
206
+
207
  return (
208
+ pred_label,
209
+ float(pb),
210
+ float(pv),
211
+ masked_pil,
212
+ blended
213
  )
214
 
215
+
216
  title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
217
  desc = "Pipeline: UNet -> mask lungs -> two binary classifiers (Normal vs Bacterial, Normal vs Viral). " \
218
  "If both classifiers fire, the stronger probability is chosen (fallback). Thresholds adjustable."