KurtLin commited on
Commit
712d80d
·
1 Parent(s): cec0a1c

Initial Commit

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -1,14 +1,21 @@
1
  import numpy as np
2
  import matplotlib.pyplot as plt
3
  import cv2
4
- # from segment_anything import sam_model_registry, SamPredictor
 
5
  from preprocess import show_mask, show_points, show_box
6
  import gradio as gr
7
 
8
- def get_select_coords(evt: gr.SelectData):
9
- position = f"[{evt.index[0]}, {evt.index[1]}]"
10
- return position
11
-
 
 
 
 
 
 
12
  my_app = gr.Blocks()
13
  with my_app:
14
  gr.Markdown("Segment Anything Testing")
@@ -17,11 +24,11 @@ with my_app:
17
  with gr.Row():
18
  with gr.Column():
19
  img_source = gr.Image(label="Please select picture.", value='./images/truck.jpg', shape=(768, 768))
20
- with gr.Column():
21
- # img_output = gr.Image(label="Output Mask")
22
  coords = gr.Label(label="Image Coordinate.")
 
 
23
 
24
- img_source.select(get_select_coords, [], coords)
25
  # set_point.click(
26
  # img_source.select(get_coord),
27
  # [
 
1
  import numpy as np
2
  import matplotlib.pyplot as plt
3
  import cv2
4
+ import torch
5
+ from segment_anything import sam_model_registry, SamPredictor
6
  from preprocess import show_mask, show_points, show_box
7
  import gradio as gr
8
 
9
+ def get_coord_infer(evt: gr.SelectData):
10
+ return [evt.index[0], evt.index[1]]
11
+
12
+ # sam_checkpoint = "weights/sam_vit_b_01ec64.pth"
13
+ # model_type = "vit_b"
14
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ # sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
16
+ # sam.to(device=device)
17
+ # predictor = SamPredictor(sam)
18
+
19
  my_app = gr.Blocks()
20
  with my_app:
21
  gr.Markdown("Segment Anything Testing")
 
24
  with gr.Row():
25
  with gr.Column():
26
  img_source = gr.Image(label="Please select picture.", value='./images/truck.jpg', shape=(768, 768))
 
 
27
  coords = gr.Label(label="Image Coordinate.")
28
+ with gr.Column():
29
+ img_output = gr.Image(label="Output Mask")
30
 
31
+ img_source.select(get_coord_infer, [], coords)
32
  # set_point.click(
33
  # img_source.select(get_coord),
34
  # [