Spaces:
Build error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import requests | |
| from PIL import Image | |
| import torch | |
| from inference import predict, random_sample, overlay_images, post_process | |
def main():
    """Build and launch the SFM inference Gradio demo.

    The UI lets the user:

    * pick a task ('Fault' or 'Facies');
    * load a random sample via ``random_sample``;
    * run ``predict`` on the loaded data;
    * overlay the prediction on the input with ``overlay_images``;
    * choose a post-processing method, whose parameter slider is shown
      on demand.

    After the layout and event handlers are defined, ``demo.launch()`` is
    called.
    """
    # Post-processing methods. The order must match the order of the sliders
    # passed as ``outputs`` to ``update_slider`` below, and "None" must stay
    # first because it is the only method with no visible slider.
    pp_options = [
        "None",
        "Thresholding",
        "Closing",
        "Opening",
        "Canny Edge",
        "Gaussian Smoothing",
        "Hysteresis",
    ]

    def update_slider(method):
        """Return one ``gr.update`` per slider, in ``pp_options`` order.

        Only the slider paired with ``method`` is made visible. Selecting
        "None" hides every slider.

        Raises:
            ValueError: if ``method`` is not one of ``pp_options``.
        """
        # Explicit check instead of ``assert`` so validation survives ``-O``.
        if method not in pp_options:
            raise ValueError(f"Unknown post-processing method: {method!r}")
        return [
            gr.update(visible=(option == method and option != pp_options[0]))
            for option in pp_options
        ]

    with gr.Blocks() as demo:
        # Per-session state shared between event handlers.
        seismic_data = gr.State()
        prediction_data = gr.State()
        # NOTE(review): no handler writes this yet. Presumably it is reserved
        # for the post-processed prediction.
        processed_prediction_data = gr.State()

        gr.Markdown("## SFM Inference Demo")
        gr.Markdown("### Select a task and run inference on seismic data")
        with gr.Row():
            task = gr.Radio(choices=['Fault', 'Facies'], label="Select Task", value='Fault')

        gr.Markdown("### Upload your seismic data or sample from dataset")
        with gr.Row():
            seismic_image = gr.Image(label="Seismic Data")
            prediction_image = gr.Image(label="Prediction Result")

        with gr.Row():
            random_sample_button = gr.Button("Upload Random Sample", elem_id="random-sample-button")
            random_sample_button.click(
                fn=random_sample,
                inputs=[task],
                outputs=[seismic_image, seismic_data],
            )

        with gr.Row():
            predict_button = gr.Button("Run Inference", elem_id="predict-button")
            # NOTE(review): seismic_data is only written by random_sample, so an
            # image uploaded directly into seismic_image never reaches predict.
            # Confirm whether an upload handler is intended.
            predict_button.click(
                fn=predict,
                inputs=[seismic_data, task],
                outputs=[prediction_image, prediction_data],
            )

        with gr.Row():
            overlay_image = gr.Image(label="Overlay Result")
            with gr.Column():
                gr.Markdown("### Overlay Seismic Data with Prediction Result")
                overlay_button = gr.Button("Overlay Result", elem_id="overlay-button")
                overlay_button.click(
                    fn=overlay_images,
                    inputs=[seismic_image, prediction_image],
                    outputs=[overlay_image],
                )

        gr.Markdown("### Post Processing")
        with gr.Row():
            # Named pp_method (not post_process) so the imported
            # inference.post_process helper is not shadowed inside main().
            pp_method = gr.Radio(
                choices=pp_options,
                value='None',
                elem_id="post-processing",
                label="Post Processing Method",
            )
            # One slider per entry of pp_options, in the same order.
            slider_none = gr.Slider(minimum=0, maximum=255, value=128, label="None Value", visible=False)
            slider_thresh = gr.Slider(minimum=0, maximum=255, value=128, label="Threshold Value", visible=False)
            slider_close = gr.Slider(minimum=0, maximum=64, value=32, label="Closing Value", visible=False)
            slider_open = gr.Slider(minimum=0, maximum=64, value=32, label="Opening Value", visible=False)
            slider_canny = gr.Slider(minimum=0, maximum=255, value=128, label="Canny Edge Value", visible=False)
            slider_gauss = gr.Slider(minimum=0, maximum=255, value=128, label="Sigma", visible=False)
            slider_hyst = gr.Slider(minimum=0, maximum=255, value=128, label="Hysteresis Min Value", visible=False)
            # TODO(review): the sliders only toggle visibility. Nothing here
            # calls the imported post_process helper with their values yet.
            pp_method.change(
                fn=update_slider,
                inputs=[pp_method],
                outputs=[slider_none, slider_thresh, slider_close, slider_open,
                         slider_canny, slider_gauss, slider_hyst],
            )
            # TODO(review): this button has no click handler attached.
            gr.Button("Download Processed Image", elem_id="download-processed-button")

    demo.launch()
if __name__ == "__main__":
    # Start the demo only when this file is executed directly, not on import.
    main()