KurtLin committed on
Commit
8896ee3
·
1 Parent(s): dec9639

Initial Commit

Browse files
Files changed (1) hide show
  1. app.py +36 -19
app.py CHANGED
@@ -6,15 +6,30 @@ from segment_anything import sam_model_registry, SamPredictor
6
  from preprocess import show_mask, show_points, show_box
7
  import gradio as gr
8
 
9
- def get_coord_infer(evt: gr.SelectData):
10
- return [evt.index[0], evt.index[1]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # sam_checkpoint = "weights/sam_vit_b_01ec64.pth"
13
- # model_type = "vit_b"
14
- # device = "cuda" if torch.cuda.is_available() else "cpu"
15
- # sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
16
- # sam.to(device=device)
17
- # predictor = SamPredictor(sam)
18
 
19
  my_app = gr.Blocks()
20
  with my_app:
@@ -24,19 +39,21 @@ with my_app:
24
  with gr.Row():
25
  with gr.Column():
26
  img_source = gr.Image(label="Please select picture.", value='./images/truck.jpg', shape=(768, 768))
27
- coords = gr.Label(label="Image Coordinate.")
 
28
  with gr.Column():
29
  img_output = gr.Image(label="Output Mask")
30
 
31
- img_source.select(get_coord_infer, [], coords)
32
- # set_point.click(
33
- # img_source.select(get_coord),
34
- # [
35
- # img_source
36
- # ],
37
- # [
38
- # coords
39
- # ]
40
- # )
 
41
 
42
  my_app.launch(debug=True)
 
6
  from preprocess import show_mask, show_points, show_box
7
  import gradio as gr
8
 
9
# --- SAM model setup (executed once at import time) ---
# Path to a pre-downloaded ViT-B Segment Anything checkpoint; must exist on disk.
sam_checkpoint = "weights/sam_vit_b_01ec64.pth"
model_type = "vit_b"
# Prefer GPU when available; CPU inference works but is slow.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# Module-level predictor shared by the click-to-segment handler.
predictor = SamPredictor(sam)
15
+
16
def get_coords(evt: gr.SelectData):
    """Format the pixel position of a click on the source image.

    Gradio injects the click event through the ``gr.SelectData``
    annotation; ``evt.index`` holds the [x, y] pixel coordinates.
    Returns the string ``"x, y"`` for display in the coordinate label.
    """
    x, y = evt.index[0], evt.index[1]
    return f"{x}, {y}"
18
+
19
def inference(image, input_label):
    """Run SAM point-prompted segmentation at a clicked coordinate.

    Args:
        image: RGB image array supplied by the gradio Image component.
        input_label: coordinate string ``"x, y"`` as produced by
            get_coords() (presumably routed through the coords Label —
            assumed to arrive as a plain string; verify against caller).

    Returns:
        A copy of ``image`` with the red channel of the predicted mask
        region set to 255, highlighting the selected object.
    """
    predictor.set_image(image)
    # BUG FIX: the y coordinate was written as split(',')[] — a syntax
    # error; it must be index 1. Split once instead of twice; int()
    # tolerates the surrounding whitespace in "x, y".
    x_str, y_str = input_label.split(',')[:2]
    input_point = np.array([[int(x_str), int(y_str)]])
    point_labels = np.array([1])  # 1 = foreground point prompt
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=point_labels,
        multimask_output=True,
    )
    # Take the first of the candidate masks returned by SAM.
    mask = masks[0]
    highlighted = image.copy()
    highlighted[mask, 0] = 255  # paint masked pixels red via boolean indexing
    return highlighted
32
 
 
 
 
 
 
 
33
 
34
  my_app = gr.Blocks()
35
  with my_app:
 
39
  with gr.Row():
40
  with gr.Column():
41
  img_source = gr.Image(label="Please select picture.", value='./images/truck.jpg', shape=(768, 768))
42
+ coords = gr.Label(label="Image Coordinate")
43
+ infer = gr.Button(label="Segment")
44
  with gr.Column():
45
  img_output = gr.Image(label="Output Mask")
46
 
47
+ img_source.select(get_coords, [], coords)
48
+ infer.click(
49
+ inference,
50
+ [
51
+ img_source,
52
+ coords
53
+ ],
54
+ [
55
+ img_output
56
+ ]
57
+ )
58
 
59
  my_app.launch(debug=True)