# demo/app.py
# Author: Prakhar Sharma
# Change: "multi plot setup" (4044777)
import gradio as gr
import matplotlib.pyplot as plt
import s3fs
import xarray as xr
import zarr
import fsspec
# Public S3 endpoint that serves the FAIR-MAST Zarr data.
ENDPOINT = "https://s3.echo.stfc.ac.uk"

# Anonymous S3 access, wrapped in fsspec's "simplecache" layer so files that
# were fetched once are re-read from a local copy. No cache_storage is passed,
# so fsspec picks the cache location itself (cache_storage="./cache_zarr" and
# same_names=True are optional knobs for a persistent, readable cache).
# A bare s3fs.S3FileSystem(endpoint_url=ENDPOINT, anon=True) also works
# (e.g. for the level 1 dataset) but without the cache.
_S3_OPTIONS = {"anon": True, "endpoint_url": ENDPOINT}
FS = fsspec.filesystem("simplecache", target_protocol="s3", target_options=_S3_OPTIONS)

# Number of template-plot placeholders available in the UI.
MAX_PLOT = 5
# ---- Callbacks -------------------------------------------------------------
# Plain module-level functions; the event wiring inside the Blocks context
# below refers to them.

# Predefined plots: checkbox label -> (zarr group, variable, plot title, isel selection).
# The keys (in this order) are also the choices of the "Plot Templates" CheckboxGroup.
TEMPLATE_SPECS = {
    "Plasma current (ip)": ("summary", "ip", "Plasma current (ip)", {}),
    # NOTE(review): the label says "spectrogram" but this draws the raw voltage
    # trace of the first channel; confirm which one is intended.
    "Mirnov coil spectrogram (OMV)": (
        "magnetics",
        "b_field_pol_probe_omv_voltage",
        "OMV Mirnov coil voltage",
        {"channel": 0},
    ),
    "Core density (ne_core)": ("thomson_scattering", "n_e_core", "Core electron density", {}),
    "Core temperature (te_core)": ("thomson_scattering", "t_e_core", "Core electron temperature", {}),
}


def _shot_store(shot_id):
    """Return an FSStore for the level-2 Zarr of ``shot_id``.

    Surrounding whitespace is stripped. Raises ``gr.Error`` (shown in the UI)
    when the shot ID is blank, instead of building a bogus S3 URL.
    """
    shot_id = str(shot_id or "").strip()
    if not shot_id:
        raise gr.Error("No shot ID entered.")
    return zarr.storage.FSStore(fs=FS, url=f"s3://mast/level2/shots/{shot_id}.zarr")


def list_groups(shot_id):
    """Open the shot's Zarr store and list its top-level groups.

    Returns:
        A ``(store, update)`` pair: the FSStore, which the wiring below keeps in
        ``store_state`` so ``list_signals`` can reuse it, and a dropdown update
        whose choices are the sorted group names (selection cleared).

    Raises:
        gr.Error: if the shot ID is blank.
    """
    store = _shot_store(shot_id)
    root = zarr.open_group(store, mode="r")
    # group_keys() already yields unique names; sorting gives a stable order.
    return store, gr.update(choices=sorted(root.group_keys()), value=None, interactive=True)


def list_signals(store, group_name):
    """Return a dropdown update listing the data variables of ``group_name``.

    ``group_dropdown.change`` also fires when ``list_groups`` resets the group to
    None, and a store may not have been loaded yet. In those cases the variable
    list is cleared instead of opening the Zarr root (or ``None``).
    """
    if store is None or not group_name:
        return gr.update(choices=[], value=None)
    group = xr.open_zarr(store, group=group_name)
    return gr.update(choices=list(group.data_vars), value=None, interactive=True)


def plot_predefined(shot_id, templates):
    """Render the ticked template plots for ``shot_id``.

    Args:
        shot_id: shot number from the Shot ID box.
        templates: labels ticked in "Plot Templates" (keys of TEMPLATE_SPECS);
            None is treated as "nothing selected".

    Returns:
        Exactly MAX_PLOT ``gr.update`` objects, one per plot placeholder. The
        leading ones carry a figure and are made visible; the rest are hidden.
        Selections beyond MAX_PLOT are dropped because there is no slot for them.

    Raises:
        gr.Error: if the shot ID is blank.
    """
    store = _shot_store(shot_id)
    opened = {}  # group name -> Dataset, so a group shared by templates is opened once
    updates = []
    for label in list(templates or [])[:MAX_PLOT]:
        spec = TEMPLATE_SPECS.get(label)
        if spec is None:
            continue  # unknown label: nothing to draw
        group_name, var_name, title, selection = spec
        if group_name not in opened:
            opened[group_name] = xr.open_zarr(store, group=group_name)
        signal = opened[group_name][var_name]
        if selection:
            signal = signal.isel(selection)
        # Explicit Figure/Axes instead of pyplot's global "current figure", which
        # callbacks running in Gradio worker threads would otherwise share.
        fig, ax = plt.subplots(figsize=(8, 4))
        signal.plot(ax=ax)
        ax.set_title(title)
        ax.grid(True)
        fig.tight_layout()
        plt.close(fig)  # release pyplot's reference; the Figure itself is still rendered by gr.Plot
        updates.append(gr.update(value=fig, visible=True))
    # Hide the unused placeholders so plots from a previous click disappear.
    updates.extend(gr.update(value=None, visible=False) for _ in range(MAX_PLOT - len(updates)))
    return updates


# ---- UI --------------------------------------------------------------------
with gr.Blocks(title="FAIR-MAST Signal Viewer") as demo:  # All Gradio components and event wiring must live inside this context.
    # Holds the FSStore opened by "Load Groups", so choosing a group reuses it
    # instead of building a new store (and re-hitting S3 if caching fails).
    store_state = gr.State(None)
    gr.Markdown("# FAIR-MAST Signal Viewer")
    with gr.Row():
        shot_box = gr.Textbox(label="Shot ID", value="23447", placeholder="23447")
        load_btn = gr.Button("Load Groups (Avoid if you just want to plot predefined signals.)")

    with gr.Accordion("Signal Selection (WIP)", open=False):
        with gr.Row():
            group_dropdown = gr.Dropdown(label="Group (folder)")
            var_dropdown = gr.Dropdown(label="Signal / Variable", multiselect=True)
        # Own name so it no longer shadows the "Plot Selected" button below.
        # TODO: not wired to a callback yet (free-signal plotting is WIP).
        signal_plot_btn = gr.Button("Plot")
        plot_output = gr.Plot()

    with gr.Accordion("Plot Templates", open=True):
        plot_templates = gr.CheckboxGroup(
            label="Choose one or more templates:",
            choices=list(TEMPLATE_SPECS),
            value=["Plasma current (ip)"],  # optional default
            interactive=True,
        )
        plot_btn = gr.Button("Plot Selected")

    # ---- Logic wiring ----
    load_btn.click(list_groups, inputs=shot_box, outputs=[store_state, group_dropdown])
    group_dropdown.change(list_signals, inputs=[store_state, group_dropdown], outputs=var_dropdown)

    # ---- Template plots: MAX_PLOT pre-created, initially hidden placeholders ----
    with gr.Accordion("Plots", open=True):
        plots = [gr.Plot(visible=False) for _ in range(MAX_PLOT)]
    plot_btn.click(plot_predefined, inputs=[shot_box, plot_templates], outputs=plots)
# TODO: wire the "Plot" button of the "Signal Selection (WIP)" accordion to a
# callback that plots the variables picked in the Signal / Variable dropdown,
# reusing the FSStore kept in the session state.

# Start the server only when this file is executed as a script, not when the
# module is merely imported.
if __name__ == "__main__":
    demo.launch(debug=True, show_error=True)