Yao Zhang commited on
Commit
6404238
·
1 Parent(s): d83bbe5
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -26,12 +26,15 @@ def visualize_instance_seg_mask(mask):
26
  labels = np.unique(mask)
27
  label2color = {label: (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels if label > 0}
28
  label2color[0] = (0, 0, 0)
29
- # for label in labels:
30
- # image[mask==label, :]
31
- for i in range(image.shape[0]):
32
- for j in range(image.shape[1]):
33
- image[i, j, :] = label2color[mask[i, j]]
34
- image = image / 255
 
 
 
35
  return image
36
 
37
 
@@ -141,11 +144,12 @@ def predict(img, threshold=0.5):
141
  else:
142
  img_data = io.imread(img_name)
143
  seg_labels = get_seg(preprocess(img_data), 'swinunetr', './best_Dice_model.pth', float(threshold))
144
- # seg_rgb = visualize_instance_seg_mask(seg_labels)
145
- seg_rgb = seg_labels
146
 
147
  tif.imwrite(join(os.getcwd(), 'segmentation.tiff'), seg_labels, compression='zlib')
148
 
 
 
149
  return img_data, seg_rgb, join(os.getcwd(), 'segmentation.tiff')
150
 
151
 
 
26
  labels = np.unique(mask)
27
  label2color = {label: (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels if label > 0}
28
  label2color[0] = (0, 0, 0)
29
+ for label in labels:
30
+ image[mask==label, :] = label2color[label]
31
+ # for i in range(image.shape[0]):
32
+ # for j in range(image.shape[1]):
33
+ # if np.max(label2color[mask[i, j]]) > 0:
34
+ # print('####', np.max(label2color[mask[i, j]]), np.min(label2color[mask[i, j]]))
35
+ # image[i, j, :] = label2color[mask[i, j]]
36
+ # image = image / 255
37
+ image = image.astype(np.uint8)
38
  return image
39
 
40
 
 
144
  else:
145
  img_data = io.imread(img_name)
146
  seg_labels = get_seg(preprocess(img_data), 'swinunetr', './best_Dice_model.pth', float(threshold))
147
+ seg_rgb = visualize_instance_seg_mask(seg_labels)
 
148
 
149
  tif.imwrite(join(os.getcwd(), 'segmentation.tiff'), seg_labels, compression='zlib')
150
 
151
+ print(np.max(img_data), np.min(img_data))
152
+ print(np.max(seg_rgb), np.min(seg_rgb))
153
  return img_data, seg_rgb, join(os.getcwd(), 'segmentation.tiff')
154
 
155