File size: 2,961 Bytes
a9df255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcc2554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a15ce4
a9df255
 
 
d4cb905
8a15ce4
a9df255
dcc2554
 
 
a9df255
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import sys
import os
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import glob
import gradio as gr
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
import logging
from huggingface_hub import hf_hub_download

# Fetch the helper module `utils.py` from a Hugging Face dataset repo at
# start-up. It presumably supplies the UI text and the callbacks used by the
# Blocks UI (title, description, auto_seg, interactive_seg, ...) -- confirm in
# the dataset repo. HUB_TOKEN must be set in the environment (KeyError otherwise).
token = os.environ['HUB_TOKEN']
loc = hf_hub_download(repo_id="JunchuanYu/files_for_segmentRS", filename="utils.py", repo_type="dataset", local_dir='.', token=token)
# hf_hub_download returns the path of the downloaded *file*, but sys.path
# entries must be directories. Put the file's containing directory (as an
# absolute path, so a later chdir cannot break it) at the front of the search
# path, so that this downloaded utils.py wins over any other `utils` module.
sys.path.insert(0, os.path.dirname(os.path.abspath(loc)))
from utils import *

with gr.Blocks(theme='gradio/soft') as demo:
    # Page header. `title`, `description` and `descriptionend` are not defined
    # in this file; presumably they come from the wildcard-imported `utils`.
    gr.Markdown(title)
    with gr.Accordion("Instructions For User 👉", open=False):
        gr.Markdown(description)

    # Per-session state lists for the clicked prompt points: x coords, y coords
    # and their labels. get_select_coords updates them; clear_point and
    # reset_state write them back (presumably emptying them -- confirm in utils).
    x = gr.State(value=[])
    y = gr.State(value=[])
    label = gr.State(value=[])

    with gr.Row():
        with gr.Column(scale=13):
            with gr.Row():
                with gr.Column():
                    # gr.inputs.* was deprecated in Gradio 3 and removed in 4;
                    # gr.Radio(value=...) is the equivalent component.
                    mode = gr.Radio(['Positive', 'Negative'], type="value", value='Positive', label='Types of sampling methods')
                with gr.Column():
                    clear_bn = gr.Button("Clear Selection")
                    interseg_button = gr.Button("Interactive Segment", variant='primary')
            with gr.Row():
                input_img = gr.Image(label="Input")
                gallery = gr.Image(label="Points")

            # Clicking on the input image passes the image, the current mode and
            # the point state to get_select_coords, which returns the updated
            # points preview and state (presumably appending the clicked point).
            input_img.select(get_select_coords, [input_img, mode, x, y, label], [gallery, x, y, label])

            with gr.Row():
                output_img = gr.Image(label="Result")
                mask_img = gr.Image(label="Mask")
            with gr.Row():
                with gr.Column():
                    thresh = gr.Slider(minimum=0.8, maximum=1, value=0.90, step=0.01, interactive=True, label="Threshold")
                with gr.Column():
                    points = gr.Slider(minimum=16, maximum=96, value=32, step=16, interactive=True, label="Points/Side")

        with gr.Column(scale=2, min_width=8):
            # Every image under ./images is offered as an example, using the
            # slider defaults (0.9, 32). Sorted because glob order is
            # filesystem-dependent and the gallery order should be stable.
            example = gr.Examples(
                examples=[[s, 0.9, 32] for s in sorted(glob.glob('./images/*'))],
                fn=auto_seg,
                inputs=[input_img, thresh, points],
                outputs=[output_img],
                cache_examples=False, examples_per_page=5)

    # NOTE(review): model_type is not referenced anywhere in this file; confirm
    # whether it is still needed.
    model_type = 'vit_b'
    autoseg_button = gr.Button("Auto Segment", variant="primary")
    emptyBtn = gr.Button("Restart", variant="secondary")

    # Segmentation actions: point-prompted segmentation fills both result and
    # mask; automatic segmentation (threshold + points-per-side) fills the mask.
    interseg_button.click(interactive_seg, inputs=[input_img, x, y, label], outputs=[output_img, mask_img])
    autoseg_button.click(auto_seg, inputs=[input_img, thresh, points], outputs=[mask_img])

    # Reset actions: clear only the point selection, or reset every component.
    clear_bn.click(clear_point, outputs=[gallery, mode, x, y, label], show_progress=True)
    emptyBtn.click(reset_state, outputs=[input_img, gallery, output_img, mask_img, thresh, points, mode, x, y, label], show_progress=True)

    gr.Markdown(descriptionend)
if __name__ == "__main__":
    demo.launch(debug=False,show_api=False)