JunchuanYu commited on
Commit
a9df255
·
1 Parent(s): 3ffc2a6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +67 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import cv2
4
+ import matplotlib
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import torch
8
+ import torchvision
9
+ import glob
10
+ import gradio as gr
11
+ from PIL import Image
12
+ from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
13
+ import logging
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ token = os.environ['HUB_TOKEN']
17
+ loc =hf_hub_download(repo_id="JunchuanYu/files_for_segmentRS", filename="utils.py",repo_type="dataset",local_dir='.',token=token)
18
+ sys.path.append(loc)
19
+ from utils import *
20
+
21
+ with gr.Blocks(theme='gradio/soft') as demo:
22
+ gr.Markdown(title)
23
+ with gr.Accordion("Instructions For User 👉", open=False):
24
+ gr.Markdown(description)
25
+ x=gr.State(value=[])
26
+ y=gr.State(value=[])
27
+ label=gr.State(value=[])
28
+
29
+ with gr.Row():
30
+ with gr.Column():
31
+ mode=gr.inputs.Radio(['Positive','Negative'], type="value",default='Positive',label='Types of sampling methods')
32
+ with gr.Column():
33
+ clear_bn=gr.Button("Clear Selection")
34
+ interseg_button = gr.Button("Interactive Segment",variant='primary')
35
+ with gr.Row():
36
+ input_img = gr.Image(label="Input")
37
+ gallery = gr.Image(label="Selected Sample Points")
38
+
39
+ input_img.select(get_select_coords, [input_img, mode,x,y,label], [gallery,x,y,label])
40
+
41
+ with gr.Row():
42
+ output_img = gr.Image(label="Result")
43
+ mask_img = gr.Image(label="Mask")
44
+ with gr.Row():
45
+ with gr.Column():
46
+ pred_iou_thresh = gr.Slider(minimum=0.8, maximum=1, value=0.90, step=0.01, interactive=True, label="Prediction Thresh")
47
+ with gr.Column():
48
+ points_per_side = gr.Slider(minimum=16, maximum=96, value=32, step=16, interactive=True, label="Points Per Side")
49
+ autoseg_button = gr.Button("Auto Segment",variant="primary")
50
+ emptyBtn = gr.Button("Restart",variant="secondary")
51
+
52
+ interseg_button.click(interactive_seg, inputs=[input_img,x,y,label], outputs=[output_img,mask_img])
53
+ autoseg_button.click(auto_seg, inputs=[input_img,pred_iou_thresh,points_per_side], outputs=[mask_img])
54
+
55
+ clear_bn.click(clear_point,outputs=[gallery,x,y,label],show_progress=True)
56
+ emptyBtn.click(reset_state,outputs=[input_img,gallery,output_img,mask_img,x,y,label],show_progress=True,)
57
+
58
+ example = gr.Examples(
59
+ examples=[[s,0.88,32] for s in glob.glob('./images/*')],
60
+ fn=auto_seg,
61
+ inputs=[input_img,pred_iou_thresh,points_per_side],
62
+ outputs=[output_img],
63
+ cache_examples=False,examples_per_page=5)
64
+
65
+ gr.Markdown(descriptionend)
66
+ if __name__ == "__main__":
67
+ demo.launch(debug=False,show_api=False)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ opencv-python
2
+ matplotlib
3
+ numpy
4
+ torch
5
+ torchvision
6
+ https://files.pythonhosted.org/packages/f0/be/fd3e87763d13936186016eceee874c3ecc55a69213e8cc18e800e5fbc4b3/gradio-3.26.0-py3-none-any.whl
7
+ git+https://github.com/facebookresearch/segment-anything.git