# Author: Anirudh Bhalekar
# Commit: b7ab39c ("sliders")
import gradio as gr
import torch
import numpy as np
import requests
from PIL import Image
import torch
from inference import predict, random_sample, overlay_images, post_process
def main():
pp_options = [
"None",
"Thresholding",
"Closing",
"Opening",
"Canny Edge",
"Gaussian Smoothing",
"Hysteresis"
]
def update_slider(post_process):
visibility = [0 for option in pp_options]
if post_process ==pp_options[0]:
# None
pass
else:
# Retrieve index of post_process
assert post_process in pp_options
index = pp_options.index(post_process)
visibility[index] = 1
ret_updates = []
for vis in visibility:
if vis == 1:
ret_updates.append(gr.update(visible=True))
else:
ret_updates.append(gr.update(visible=False))
return ret_updates
with gr.Blocks() as demo:
# Button to select task
seismic_data = gr.State()
prediction_data = gr.State()
processed_prediction_data = gr.State()
gr.Markdown("## SFM Inference Demo")
gr.Markdown("### Select a task and run inference on seismic data")
with gr.Row():
task = gr.Radio(choices=['Fault', 'Facies'], label="Select Task", value='Fault')
gr.Markdown("### Upload your seismic data or sample from dataset")
with gr.Row():
seismic_image = gr.Image(label="Seismic Data")
prediction_image = gr.Image(label="Prediction Result")
with gr.Row():
random_sample_button = gr.Button("Upload Random Sample", elem_id="random-sample-button")
random_sample_button.click(fn=random_sample, inputs=[task], outputs=[seismic_image, seismic_data])
with gr.Row():
predict_button = gr.Button("Run Inference", elem_id="predict-button")
predict_button.click(fn=predict, inputs=[seismic_data, task], outputs=[prediction_image, prediction_data])
processed_prediction_data = prediction_data
with gr.Row():
overlay_image = gr.Image(label="Overlay Result")
with gr.Column():
gr.Markdown("### Overlay Seismic Data with Prediction Result")
overlay_button = gr.Button("Overlay Result", elem_id="overlay-button")
overlay_button.click(fn=overlay_images, inputs=[seismic_image, prediction_image], outputs=[overlay_image])
gr.Markdown("### Post Processing")
with gr.Row():
post_process = gr.Radio(choices=pp_options,
value='None', elem_id="post-processing", label="Post Processing Method")
slider_none = gr.Slider(minimum=0, maximum=255, value=128, label="None Value", visible=False)
slider_thresh = gr.Slider(minimum=0, maximum=255, value=128, label="Threshold Value", visible=False)
slider_close = gr.Slider(minimum=0, maximum=64, value=32, label="Closing Value", visible=False)
slider_open = gr.Slider(minimum=0, maximum=64, value=32, label="Opening Value", visible=False)
slider_canny = gr.Slider(minimum=0, maximum=255, value=128, label="Canny Edge Value", visible=False)
slider_gauss = gr.Slider(minimum=0, maximum=255, value=128, label="Sigma", visible=False)
slider_hyst = gr.Slider(minimum=0, maximum=255, value=128, label="Hysteresis Min Value", visible=False)
post_process.change(
fn=update_slider,
inputs=[post_process],
outputs=[slider_none, slider_thresh, slider_close, slider_open, slider_canny, slider_gauss, slider_hyst]
)
gr.Button("Download Processed Image", elem_id="download-processed-button")
demo.launch()
if __name__ == "__main__":
    # Start the Gradio app only when executed as a script, never on import.
    main()