Update app.py
Browse files
app.py
CHANGED
|
@@ -171,33 +171,48 @@ def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
|
|
| 171 |
|
| 172 |
def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
|
| 173 |
"""
|
| 174 |
-
Returns: label, bacterial_prob, viral_prob, masked_image (PIL),
|
| 175 |
"""
|
|
|
|
| 176 |
pil = Image.fromarray(img.astype('uint8'), 'RGB')
|
| 177 |
-
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
mask_vis = (mask_np * 255).astype(np.uint8)
|
| 180 |
mask_pil = Image.fromarray(mask_vis).convert("L")
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
red_mask =
|
| 187 |
-
red_mask =
|
| 188 |
-
|
|
|
|
| 189 |
alpha = (mask_np * 120).astype(np.uint8)
|
| 190 |
red_mask.putalpha(Image.fromarray(alpha))
|
|
|
|
| 191 |
blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)
|
| 192 |
-
|
|
|
|
| 193 |
return (
|
| 194 |
-
pred_label,
|
| 195 |
-
float(
|
| 196 |
-
float(
|
| 197 |
-
|
| 198 |
-
|
| 199 |
)
|
| 200 |
|
|
|
|
| 201 |
title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
|
| 202 |
desc = "Pipeline: UNet -> mask lungs -> two binary classifiers (Normal vs Bacterial, Normal vs Viral). " \
|
| 203 |
"If both classifiers fire, the stronger probability is chosen (fallback). Thresholds adjustable."
|
|
|
|
| 171 |
|
| 172 |
def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
|
| 173 |
"""
|
| 174 |
+
Returns: label, bacterial_prob, viral_prob, masked_image (PIL), mask_overlay (PIL)
|
| 175 |
"""
|
| 176 |
+
|
| 177 |
pil = Image.fromarray(img.astype('uint8'), 'RGB')
|
| 178 |
+
|
| 179 |
+
masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
|
| 180 |
+
pil, threshold=seg_thresh
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
pb, pv, pred_label = classify_masked_tensor(
|
| 184 |
+
masked_tensor,
|
| 185 |
+
thresh_b=thresh_b,
|
| 186 |
+
thresh_v=thresh_v
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# Convert mask to PIL
|
| 190 |
mask_vis = (mask_np * 255).astype(np.uint8)
|
| 191 |
mask_pil = Image.fromarray(mask_vis).convert("L")
|
| 192 |
+
|
| 193 |
+
# Resize original for overlay
|
| 194 |
+
display_orig = pil.resize((300, 300))
|
| 195 |
+
|
| 196 |
+
# Create red mask overlay
|
| 197 |
+
red_mask = np.zeros((300, 300, 3), dtype=np.uint8)
|
| 198 |
+
red_mask = np.stack([mask_vis, np.zeros_like(mask_vis), np.zeros_like(mask_vis)], axis=2)
|
| 199 |
+
red_mask = Image.fromarray(red_mask).convert("RGBA")
|
| 200 |
+
|
| 201 |
alpha = (mask_np * 120).astype(np.uint8)
|
| 202 |
red_mask.putalpha(Image.fromarray(alpha))
|
| 203 |
+
|
| 204 |
blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
return (
|
| 208 |
+
pred_label,
|
| 209 |
+
float(pb),
|
| 210 |
+
float(pv),
|
| 211 |
+
masked_pil,
|
| 212 |
+
blended
|
| 213 |
)
|
| 214 |
|
| 215 |
+
|
| 216 |
title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
|
| 217 |
desc = "Pipeline: UNet -> mask lungs -> two binary classifiers (Normal vs Bacterial, Normal vs Viral). " \
|
| 218 |
"If both classifiers fire, the stronger probability is chosen (fallback). Thresholds adjustable."
|