KurtLin commited on
Commit
c1521d5
·
1 Parent(s): 00a3d3d

Initial Commit

Browse files
Files changed (2) hide show
  1. app.py +3 -1
  2. preprocess.py +0 -23
app.py CHANGED
@@ -3,7 +3,6 @@ import matplotlib.pyplot as plt
3
  import cv2
4
  import torch
5
  from segment_anything import sam_model_registry, SamPredictor
6
- from preprocess import show_mask, show_points, show_box
7
  import gradio as gr
8
 
9
  sam_checkpoint = {
@@ -37,6 +36,9 @@ def inference(image, input_label, model_choice):
37
  mask = masks[0]
38
  image2 = image.copy()
39
  image2[mask, 0] = 255
 
 
 
40
  return image2
41
 
42
 
 
3
  import cv2
4
  import torch
5
  from segment_anything import sam_model_registry, SamPredictor
 
6
  import gradio as gr
7
 
8
  sam_checkpoint = {
 
36
  mask = masks[0]
37
  image2 = image.copy()
38
  image2[mask, 0] = 255
39
+ image2[int(input_label['label'].split(',')[0])-10:int(input_label['label'].split(',')[0])+10,
40
+ int(input_label['label'].split(',')[1])-10:int(input_label['label'].split(',')[1])+10,
41
+ 2] = 255
42
  return image2
43
 
44
 
preprocess.py DELETED
@@ -1,23 +0,0 @@
1
- import numpy as np
2
- import matplotlib.pyplot as plt
3
- import cv2
4
-
5
- def show_mask(mask, ax, random_color=False):
6
- if random_color:
7
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
8
- else:
9
- color = np.array([30/255, 144/255, 255/255, 0.6])
10
- h, w = mask.shape[-2:]
11
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
12
- ax.imshow(mask_image)
13
-
14
- def show_points(coords, labels, ax, marker_size=375):
15
- pos_points = coords[labels==1]
16
- neg_points = coords[labels==0]
17
- ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
18
- ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
19
-
20
- def show_box(box, ax):
21
- x0, y0 = box[0], box[1]
22
- w, h = box[2] - box[0], box[3] - box[1]
23
- ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))