"""FAIR-MAST Signal Viewer.

A small Gradio app that opens level-2 MAST shot data stored as Zarr on the
STFC Echo S3 endpoint and plots a handful of predefined signals.
"""

from typing import NamedTuple, Optional

import fsspec
import gradio as gr
import matplotlib.pyplot as plt
import s3fs  # noqa: F401  (s3 backend used by fsspec; also the commented-out alternative below)
import xarray as xr
import zarr

# S3 endpoint hosting the FAIR-MAST data.
ENDPOINT = "https://s3.echo.stfc.ac.uk"

# FS = s3fs.S3FileSystem(endpoint_url=ENDPOINT, anon=True)  # For level 1 dataset
# Anonymous S3 access wrapped in fsspec's "simplecache" layer, so files fetched
# once are read from a local cache afterwards.
FS = fsspec.filesystem(
    "simplecache",
    target_protocol="s3",
    target_options=dict(anon=True, endpoint_url=ENDPOINT),
    # cache_storage="./cache_zarr",
    # same_names=True,
)

# Number of plot placeholders created up front; at most this many templates are drawn.
MAX_PLOT = 5

# Location of a shot's level-2 Zarr store in the bucket.
ZARR_URL = "s3://mast/level2/shots/{shot_id}.zarr"


class PlotTemplate(NamedTuple):
    """Where to find a predefined signal and how to title its plot."""

    group: str  # Zarr group holding the signal
    variable: str  # data variable inside that group
    title: str  # axes title shown above the plot
    channel: Optional[int] = None  # index along the "channel" dimension to select, if any


# Checkbox label -> template. Dict order is the order the choices appear in the UI.
# NOTE(review): the OMV template plots channel 0's voltage trace, not a spectrogram;
# the label may need updating.
PLOT_TEMPLATES = {
    "Plasma current (ip)": PlotTemplate("summary", "ip", "Plasma current (ip)"),
    "Mirnov coil spectrogram (OMV)": PlotTemplate(
        "magnetics", "b_field_pol_probe_omv_voltage", "OMV Mirnov coil voltage", channel=0
    ),
    "Core density (ne_core)": PlotTemplate(
        "thomson_scattering", "n_e_core", "Core electron density"
    ),
    "Core temperature (te_core)": PlotTemplate(
        "thomson_scattering", "t_e_core", "Core electron temperature"
    ),
}


def _open_store(shot_id):
    """Return the Zarr FSStore for *shot_id*; raise gr.Error if the ID is blank."""
    shot_id = (shot_id or "").strip()
    if not shot_id:
        # gr.Error is shown to the user instead of a raw traceback.
        raise gr.Error("No shot ID entered.")
    return zarr.storage.FSStore(fs=FS, url=ZARR_URL.format(shot_id=shot_id))


def list_groups(shot_id):
    """Open the shot's Zarr store and list its top-level groups.

    Returns:
        The FSStore (kept in ``gr.State`` so later callbacks can reuse it) and a
        dropdown update holding the sorted group names with no selection.
    """
    store = _open_store(shot_id)
    root = zarr.open_group(store, mode="r")
    return store, gr.update(choices=sorted(root.group_keys()), value=None, interactive=True)


def list_signals(store, group_name):
    """Return a dropdown update listing the data variables in *group_name*."""
    # Resetting the group dropdown (e.g. in list_groups) can fire this change
    # callback with no group selected; clear the list rather than open the root.
    if store is None or not group_name:
        return gr.update(choices=[], value=None)
    group = xr.open_zarr(store, group=group_name)
    return gr.update(choices=list(group.data_vars), value=None, interactive=True)


def _make_figure(data, spec):
    """Draw *data* on a new figure titled per *spec* and return the figure."""
    if spec.channel is not None:
        data = data.isel(channel=spec.channel)
    # Draw on explicit axes rather than pyplot's "current figure", so plots
    # never land on a figure created by something else.
    fig, ax = plt.subplots(figsize=(8, 4))
    data.plot(ax=ax)
    ax.set_title(spec.title)
    ax.grid(True)
    fig.tight_layout()
    # Drop it from pyplot's figure registry; Gradio renders the Figure object itself.
    plt.close(fig)
    return fig


def plot_predefined(shot_id, templates):
    """Plot each selected template into the fixed set of plot placeholders.

    Args:
        shot_id: Shot ID text from the textbox.
        templates: Selected checkbox labels (keys of PLOT_TEMPLATES).

    Returns:
        Exactly MAX_PLOT component updates: one visible figure per recognised
        template (in the order received, at most MAX_PLOT), then hidden
        placeholders for the remaining slots.

    Raises:
        gr.Error: if no shot ID was entered.
    """
    store = _open_store(shot_id)
    specs = [PLOT_TEMPLATES[label] for label in templates or [] if label in PLOT_TEMPLATES]

    datasets = {}  # group name -> opened dataset, so a shared group is opened only once
    updates = []
    for spec in specs[:MAX_PLOT]:
        if spec.group not in datasets:
            datasets[spec.group] = xr.open_zarr(store, group=spec.group)
        fig = _make_figure(datasets[spec.group][spec.variable], spec)
        updates.append(gr.update(value=fig, visible=True))

    # Hide the placeholders that are not used this time.
    updates.extend(
        gr.update(value=None, visible=False) for _ in range(MAX_PLOT - len(updates))
    )
    return updates


with gr.Blocks(title="FAIR-MAST Signal Viewer") as demo:
    # All Gradio components must be created inside this context.
    # store_state holds the FSStore opened by "Load Groups", so that browsing
    # groups and signals reuses it instead of reopening the S3 store.
    store_state = gr.State(None)

    gr.Markdown("# FAIR-MAST Signal Viewer")
    with gr.Row():
        shot_box = gr.Textbox(label="Shot ID", value="23447", placeholder="23447")
        load_btn = gr.Button("Load Groups (Avoid if you just want to plot predefined signals.)")

    with gr.Accordion("Signal Selection (WIP)", open=False):
        with gr.Row():
            group_dropdown = gr.Dropdown(label="Group (folder)")
            var_dropdown = gr.Dropdown(label="Signal / Variable", multiselect=True)
        # TODO: not wired yet - needs a callback that plots var_dropdown's
        # selection from group_dropdown into plot_output.
        signal_plot_btn = gr.Button("Plot")
        plot_output = gr.Plot()

    with gr.Accordion("Plot Templates", open=True):
        plot_templates = gr.CheckboxGroup(
            label="Choose one or more templates:",
            choices=list(PLOT_TEMPLATES),
            value=["Plasma current (ip)"],  # default selection
            interactive=True,
        )
        plot_btn = gr.Button("Plot Selected")

    # MAX_PLOT pre-created placeholders; plot_predefined shows/hides them.
    with gr.Accordion("Plots", open=True):
        plots = [gr.Plot(visible=False) for _ in range(MAX_PLOT)]

    # ---- Logic wiring ----
    load_btn.click(list_groups, inputs=shot_box, outputs=[store_state, group_dropdown])
    group_dropdown.change(
        list_signals, inputs=[store_state, group_dropdown], outputs=var_dropdown
    )
    plot_btn.click(plot_predefined, inputs=[shot_box, plot_templates], outputs=plots)


if __name__ == "__main__":
    demo.launch(debug=True, show_error=True)