Update app.py
Browse files
app.py
CHANGED
|
@@ -54,7 +54,7 @@ def perform_prediction(image):
|
|
| 54 |
mask_image = show_mask_on_image(image, predicted_mask[:, max_iou_index], return_image=True)
|
| 55 |
return mask_image
|
| 56 |
|
| 57 |
-
#Function to overlay mask on the image
|
| 58 |
def show_mask_on_image(raw_image, mask, return_image=False):
|
| 59 |
if not isinstance(mask, torch.Tensor):
|
| 60 |
mask = torch.Tensor(mask)
|
|
@@ -62,7 +62,7 @@ def show_mask_on_image(raw_image, mask, return_image=False):
|
|
| 62 |
if len(mask.shape) == 4:
|
| 63 |
mask = mask.squeeze()
|
| 64 |
|
| 65 |
-
fig, axes = plt.subplots(1,1,figsize=(15,15))
|
| 66 |
|
| 67 |
mask = mask.cpu().detach()
|
| 68 |
axes.imshow(np.array(raw_image))
|
|
@@ -73,9 +73,9 @@ def show_mask_on_image(raw_image, mask, return_image=False):
|
|
| 73 |
if return_image:
|
| 74 |
fig = plt.gcf()
|
| 75 |
fig.canvas.draw()
|
| 76 |
-
#Convert plot to image
|
| 77 |
img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
| 78 |
-
img = img.reshape(fig.canvas.get_width_height()[::-1]+(3,))
|
| 79 |
img = Image.fromarray(img)
|
| 80 |
plt.close(fig)
|
| 81 |
return img
|
|
|
|
| 54 |
mask_image = show_mask_on_image(image, predicted_mask[:, max_iou_index], return_image=True)
|
| 55 |
return mask_image
|
| 56 |
|
| 57 |
+
# Function to overlay mask on the image
|
| 58 |
def show_mask_on_image(raw_image, mask, return_image=False):
|
| 59 |
if not isinstance(mask, torch.Tensor):
|
| 60 |
mask = torch.Tensor(mask)
|
|
|
|
| 62 |
if len(mask.shape) == 4:
|
| 63 |
mask = mask.squeeze()
|
| 64 |
|
| 65 |
+
fig, axes = plt.subplots(1, 1, figsize=(15, 15))
|
| 66 |
|
| 67 |
mask = mask.cpu().detach()
|
| 68 |
axes.imshow(np.array(raw_image))
|
|
|
|
| 73 |
if return_image:
|
| 74 |
fig = plt.gcf()
|
| 75 |
fig.canvas.draw()
|
| 76 |
+
# Convert plot to image
|
| 77 |
img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
| 78 |
+
img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
| 79 |
img = Image.fromarray(img)
|
| 80 |
plt.close(fig)
|
| 81 |
return img
|