rnmee commited on
Commit
d6c353f
·
verified ·
1 Parent(s): 171a43f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -138,9 +138,7 @@ def get_segmenter_transform():
138
  def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
139
  probs = torch.softmax(output, dim=1).cpu().numpy().squeeze()
140
  pred_class = np.argmax(probs, axis=0)
141
- final_mask = np.zeros_like(pred_class, dtype=np.uint8)
142
- final_mask[(pred_class == 1) | (pred_class == 4)] = 1
143
- final_mask[(pred_class == 2) | (pred_class == 3)] = 2
144
  return final_mask, probs
145
 
146
  # ====================== VISUALIZATION ======================
@@ -219,10 +217,8 @@ def main():
219
  with st.spinner("Detecting lesions..."):
220
  seg_results = segment_image(original_image, segmenter)
221
  overlay = create_lesion_overlay(original_image, seg_results['mask'])
222
- heat_bright = create_heatmap(seg_results['probs'][1] + seg_results['probs'][4],
223
- original_image.size)
224
- heat_red = create_heatmap(seg_results['probs'][2] + seg_results['probs'][3],
225
- original_image.size)
226
 
227
  with col2:
228
  st.image(overlay, caption="Lesion Overlay", use_container_width=True)
 
138
  def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
139
  probs = torch.softmax(output, dim=1).cpu().numpy().squeeze()
140
  pred_class = np.argmax(probs, axis=0)
141
+ final_mask = pred_class.astype(np.uint8) # Already 0=bg, 1=bright, 2=red
 
 
142
  return final_mask, probs
143
 
144
  # ====================== VISUALIZATION ======================
 
217
  with st.spinner("Detecting lesions..."):
218
  seg_results = segment_image(original_image, segmenter)
219
  overlay = create_lesion_overlay(original_image, seg_results['mask'])
220
+ heat_bright = create_heatmap(seg_results['probs'][1], original_image.size)
221
+ heat_red = create_heatmap(seg_results['probs'][2], original_image.size) # Only class 2
 
 
222
 
223
  with col2:
224
  st.image(overlay, caption="Lesion Overlay", use_container_width=True)