Update app.py
Browse files
app.py
CHANGED
|
@@ -138,9 +138,7 @@ def get_segmenter_transform():
|
|
| 138 |
def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
|
| 139 |
probs = torch.softmax(output, dim=1).cpu().numpy().squeeze()
|
| 140 |
pred_class = np.argmax(probs, axis=0)
|
| 141 |
-
final_mask =
|
| 142 |
-
final_mask[(pred_class == 1) | (pred_class == 4)] = 1
|
| 143 |
-
final_mask[(pred_class == 2) | (pred_class == 3)] = 2
|
| 144 |
return final_mask, probs
|
| 145 |
|
| 146 |
# ====================== VISUALIZATION ======================
|
|
@@ -219,10 +217,8 @@ def main():
|
|
| 219 |
with st.spinner("Detecting lesions..."):
|
| 220 |
seg_results = segment_image(original_image, segmenter)
|
| 221 |
overlay = create_lesion_overlay(original_image, seg_results['mask'])
|
| 222 |
-
heat_bright = create_heatmap(seg_results['probs'][1]
|
| 223 |
-
|
| 224 |
-
heat_red = create_heatmap(seg_results['probs'][2] + seg_results['probs'][3],
|
| 225 |
-
original_image.size)
|
| 226 |
|
| 227 |
with col2:
|
| 228 |
st.image(overlay, caption="Lesion Overlay", use_container_width=True)
|
|
|
|
| 138 |
def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
|
| 139 |
probs = torch.softmax(output, dim=1).cpu().numpy().squeeze()
|
| 140 |
pred_class = np.argmax(probs, axis=0)
|
| 141 |
+
final_mask = pred_class.astype(np.uint8) # Already 0=bg, 1=bright, 2=red
|
|
|
|
|
|
|
| 142 |
return final_mask, probs
|
| 143 |
|
| 144 |
# ====================== VISUALIZATION ======================
|
|
|
|
| 217 |
with st.spinner("Detecting lesions..."):
|
| 218 |
seg_results = segment_image(original_image, segmenter)
|
| 219 |
overlay = create_lesion_overlay(original_image, seg_results['mask'])
|
| 220 |
+
heat_bright = create_heatmap(seg_results['probs'][1], original_image.size)
|
| 221 |
+
heat_red = create_heatmap(seg_results['probs'][2], original_image.size) # Only class 2
|
|
|
|
|
|
|
| 222 |
|
| 223 |
with col2:
|
| 224 |
st.image(overlay, caption="Lesion Overlay", use_container_width=True)
|