Spaces:
Runtime error
Runtime error
Initial Commit
Browse files- app.py +3 -1
- preprocess.py +0 -23
app.py
CHANGED
|
@@ -3,7 +3,6 @@ import matplotlib.pyplot as plt
|
|
| 3 |
import cv2
|
| 4 |
import torch
|
| 5 |
from segment_anything import sam_model_registry, SamPredictor
|
| 6 |
-
from preprocess import show_mask, show_points, show_box
|
| 7 |
import gradio as gr
|
| 8 |
|
| 9 |
sam_checkpoint = {
|
|
@@ -37,6 +36,9 @@ def inference(image, input_label, model_choice):
|
|
| 37 |
mask = masks[0]
|
| 38 |
image2 = image.copy()
|
| 39 |
image2[mask, 0] = 255
|
|
|
|
|
|
|
|
|
|
| 40 |
return image2
|
| 41 |
|
| 42 |
|
|
|
|
| 3 |
import cv2
|
| 4 |
import torch
|
| 5 |
from segment_anything import sam_model_registry, SamPredictor
|
|
|
|
| 6 |
import gradio as gr
|
| 7 |
|
| 8 |
sam_checkpoint = {
|
|
|
|
| 36 |
mask = masks[0]
|
| 37 |
image2 = image.copy()
|
| 38 |
image2[mask, 0] = 255
|
| 39 |
+
image2[int(input_label['label'].split(',')[0])-10:int(input_label['label'].split(',')[0])+10,
|
| 40 |
+
int(input_label['label'].split(',')[1])-10:int(input_label['label'].split(',')[1])+10,
|
| 41 |
+
2] = 255
|
| 42 |
return image2
|
| 43 |
|
| 44 |
|
preprocess.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import matplotlib.pyplot as plt
|
| 3 |
-
import cv2
|
| 4 |
-
|
| 5 |
-
def show_mask(mask, ax, random_color=False):
|
| 6 |
-
if random_color:
|
| 7 |
-
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
| 8 |
-
else:
|
| 9 |
-
color = np.array([30/255, 144/255, 255/255, 0.6])
|
| 10 |
-
h, w = mask.shape[-2:]
|
| 11 |
-
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
| 12 |
-
ax.imshow(mask_image)
|
| 13 |
-
|
| 14 |
-
def show_points(coords, labels, ax, marker_size=375):
|
| 15 |
-
pos_points = coords[labels==1]
|
| 16 |
-
neg_points = coords[labels==0]
|
| 17 |
-
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
| 18 |
-
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
| 19 |
-
|
| 20 |
-
def show_box(box, ax):
|
| 21 |
-
x0, y0 = box[0], box[1]
|
| 22 |
-
w, h = box[2] - box[0], box[3] - box[1]
|
| 23 |
-
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|