File size: 3,029 Bytes
907462b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from samhi.inference.inference import Inference
from samhi.inference.app_tools import App
import gradio as gr
from gradio_image_prompter import ImagePrompter
import argparse

parser = argparse.ArgumentParser(
    description="Launch the interactive mask segmentation Gradio app."
)
parser.add_argument(
    "--model_dir",
    type=str,
    default="/vol/data/models/",
    help="Directory forwarded to the inference backend as 'model_dir'.",
)
parser.add_argument(
    "--model_name",
    type=str,
    default="model-ap0xl4l1:v19",
    help="Model identifier forwarded to the inference backend as 'model_name'.",
)
# NOTE(review): --cellvit_model is parsed but never forwarded to App or the
# inference config in this script; confirm whether it should be wired in.
parser.add_argument(
    "--cellvit_model",
    type=str,
    default="CellViT-256-x40.pth",
    help="CellViT checkpoint file name (currently unused by this script).",
)
# The app used to always call demo.launch(share=True), which opens a public
# Gradio share link. Sharing stays the default for backward compatibility,
# but it can now be disabled for local-only use.
parser.add_argument(
    "--no_share",
    dest="share",
    action="store_false",
    help="Serve only locally instead of opening a public Gradio share link.",
)

args = parser.parse_args()

# Settings for the inference backend; uses the GPU when CUDA is available.
inference_config = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "model_dir": args.model_dir,
    "model_name": args.model_name,
}
# No app-level settings are overridden here; presumably App falls back to its
# own defaults for anything it reads from this dict -- confirm against App.
app_config = {}

config = {
    "inference_config": inference_config,
    "app_config": app_config,
}
app = App(config)

with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
    # Top row: the image/prompt canvas plus the mask-editing controls.
    with gr.Row():
        whole_image = ImagePrompter(label="Whole Image")
        whole_image.upload(app.load_image, inputs=[whole_image], outputs=[whole_image])

        # Controls shown initially (outside of mask refinement).
        with gr.Column(visible=True) as initial_buttons:
            auto_segment_btn = gr.Button("Auto Segment")
            auto_segment_btn.click(
                app.segment_automatically_app, inputs=[], outputs=[whole_image]
            )
            add_mask_btn = gr.Button("Add Mask")
            add_mask_btn.click(app.add_mask, inputs=[whole_image], outputs=[whole_image])
            start_refine_mask_btn = gr.Button("Refine Mask")
            remove_mask_btn = gr.Button("Remove Mask")
            remove_mask_btn.click(
                app.remove_mask, inputs=[whole_image], outputs=[whole_image]
            )

        # Controls for the refinement mode; hidden until refinement starts.
        with gr.Column(visible=False) as refine_buttons:
            refine_mask_btn = gr.Button("Refine")
            refine_mask_btn.click(
                app.refine_mask, inputs=[whole_image], outputs=[whole_image]
            )
            finish_mask_btn = gr.Button("Finish Mask")

        # These two handlers are wired only after both columns exist, because
        # each one outputs to both columns (presumably swapping which set of
        # buttons is visible). The output order must match what the App
        # methods return.
        start_refine_mask_btn.click(
            app.start_refine_mask,
            inputs=[whole_image],
            outputs=[whole_image, refine_buttons, initial_buttons],
        )
        finish_mask_btn.click(
            app.finish_mask,
            inputs=[],
            outputs=[whole_image, initial_buttons, refine_buttons],
        )

    # Bottom row: pan and zoom controls for the displayed image.
    with gr.Row():
        # Hidden pan step size, passed as the input to every move handler.
        amount = gr.Number(value=100, label="Step Size", visible=False)
        with gr.Column():
            left_button = gr.Button(value="\U0001F814")  # leftwards arrow
            left_button.click(app.move_left, inputs=[amount], outputs=[whole_image])
        with gr.Column():
            up_button = gr.Button(value="\U0001F815")  # upwards arrow
            up_button.click(app.move_up, inputs=[amount], outputs=[whole_image])
            down_button = gr.Button(value="\U0001F817")  # downwards arrow
            down_button.click(app.move_down, inputs=[amount], outputs=[whole_image])
        with gr.Column():
            right_button = gr.Button(value="\U0001F816")  # rightwards arrow
            right_button.click(app.move_right, inputs=[amount], outputs=[whole_image])
        # Zoom is applied when the slider is released, not on every drag step.
        zoom_bar = gr.Slider(minimum=1, maximum=5, step=1, label="Zoom Factor", value=1)
        zoom_bar.release(app.zoom, inputs=[zoom_bar, whole_image], outputs=[whole_image])

# share defaults to True (public Gradio link, as before); --no_share disables it.
demo.launch(share=args.share)