File size: 4,072 Bytes
6e87dc7
f129f93
 
 
 
 
b7ab39c
6e87dc7
1316be2
b7ab39c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fc8c87
 
 
 
a81555c
 
 
1fc8c87
a81555c
1fc8c87
 
 
 
 
a81555c
1fc8c87
 
 
 
 
 
 
 
a81555c
 
 
 
 
 
 
 
 
 
b7ab39c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a81555c
b7ab39c
a81555c
 
1fc8c87
a81555c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import torch 
import numpy as np 
import requests
from PIL import Image
import torch 
from inference import predict, random_sample, overlay_images, post_process

def main():
    """Build the SFM inference Gradio demo and launch it.

    Layout, top to bottom:
      * a task selector ("Fault" / "Facies");
      * the input seismic image and the prediction image side by side;
      * a button that loads a sample via ``random_sample`` and a button
        that runs ``predict`` on it;
      * an overlay panel (``overlay_images``) and a post-processing panel
        whose parameter slider follows the selected method.

    The assembled UI is served with ``demo.launch()``.
    """
    # Order matters: update_slider returns one update per entry, and the
    # slider outputs wired to it further down are listed in this same order.
    pp_options = [
        "None",
        "Thresholding",
        "Closing",
        "Opening",
        "Canny Edge",
        "Gaussian Smoothing",
        "Hysteresis",
    ]

    def update_slider(selected):
        """Return one ``gr.update`` per entry of ``pp_options``.

        Only the slider that matches ``selected`` becomes visible; choosing
        "None" (``pp_options[0]``) hides every slider.

        Raises:
            ValueError: if ``selected`` is not one of ``pp_options``.
        """
        if selected not in pp_options:
            # Explicit check rather than ``assert``, which ``python -O`` strips.
            raise ValueError(f"Unknown post-processing method: {selected!r}")
        return [
            gr.update(visible=(option == selected and option != pp_options[0]))
            for option in pp_options
        ]

    with gr.Blocks() as demo:
        # Session state shared between event handlers.
        seismic_data = gr.State()
        prediction_data = gr.State()
        # TODO: not read or written by any handler yet; presumably meant to
        # hold the post-processed prediction.
        processed_prediction_data = gr.State()

        gr.Markdown("## SFM Inference Demo")
        gr.Markdown("### Select a task and run inference on seismic data")
        with gr.Row():
            task = gr.Radio(choices=['Fault', 'Facies'], label="Select Task", value='Fault')
        gr.Markdown("### Upload your seismic data or sample from dataset")

        with gr.Row():
            seismic_image = gr.Image(label="Seismic Data")
            prediction_image = gr.Image(label="Prediction Result")

        with gr.Row():
            random_sample_button = gr.Button("Upload Random Sample", elem_id="random-sample-button")
            # Fills both the displayed image and the raw data kept in state.
            random_sample_button.click(
                fn=random_sample,
                inputs=[task],
                outputs=[seismic_image, seismic_data],
            )

        with gr.Row():
            predict_button = gr.Button("Run Inference", elem_id="predict-button")
            # Inference reads the raw state (not the displayed image) plus the task.
            predict_button.click(
                fn=predict,
                inputs=[seismic_data, task],
                outputs=[prediction_image, prediction_data],
            )

        with gr.Row():
            overlay_image = gr.Image(label="Overlay Result")
            with gr.Column():
                gr.Markdown("### Overlay Seismic Data with Prediction Result")
                overlay_button = gr.Button("Overlay Result", elem_id="overlay-button")
                overlay_button.click(
                    fn=overlay_images,
                    inputs=[seismic_image, prediction_image],
                    outputs=[overlay_image],
                )

                gr.Markdown("### Post Processing")
                with gr.Row():
                    # Named *_radio so it does not shadow the `post_process`
                    # function imported from `inference` at module level.
                    post_process_radio = gr.Radio(
                        choices=pp_options,
                        value='None',
                        elem_id="post-processing",
                        label="Post Processing Method",
                    )
                slider_none = gr.Slider(minimum=0, maximum=255, value=128, label="None Value", visible=False)
                slider_thresh = gr.Slider(minimum=0, maximum=255, value=128, label="Threshold Value", visible=False)
                slider_close = gr.Slider(minimum=0, maximum=64, value=32, label="Closing Value", visible=False)
                slider_open = gr.Slider(minimum=0, maximum=64, value=32, label="Opening Value", visible=False)
                slider_canny = gr.Slider(minimum=0, maximum=255, value=128, label="Canny Edge Value", visible=False)
                slider_gauss = gr.Slider(minimum=0, maximum=255, value=128, label="Sigma", visible=False)
                slider_hyst = gr.Slider(minimum=0, maximum=255, value=128, label="Hysteresis Min Value", visible=False)
                # Outputs must stay in the same order as pp_options.
                post_process_radio.change(
                    fn=update_slider,
                    inputs=[post_process_radio],
                    outputs=[slider_none, slider_thresh, slider_close, slider_open,
                             slider_canny, slider_gauss, slider_hyst],
                )
                # TODO: no click handler is attached, so this button does nothing yet.
                gr.Button("Download Processed Image", elem_id="download-processed-button")

    demo.launch()


# Launch the demo only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()