jbubble-sweep / sweep_dashboard.py
callumtilbury's picture
pulse preview to match pulse cycles, rather than fixed at 3
a53f322
"""Streamlit dashboard for 2D parameter sweeps with jbubble.
This app lets you pick two parameters to sweep, control ranges/resolution,
and view the resulting heatmap.
"""
from typing import Sequence, Union, Dict, Any, List, Tuple
import itertools
import jax
import numpy as np
import plotly.graph_objects as go
import streamlit as st
from jbubble import (
Units,
SaveSpec,
arrays_from_result,
run_simulation,
Bubble,
Pulse,
Sine,
Sawtooth,
Triangle,
Quadratic,
NegativeQuadratic,
Asymmetrical,
SlantedSine,
Square,
TimeDomainSquare,
TimeDomainSawtooth,
TimeDomainTriangle,
)
# Unit system handed to every run_simulation call in this file.
UNITS = Units()
# Output sampling passed to every run_simulation call.
# NOTE(review): presumably the number of saved time samples per run — confirm in jbubble.SaveSpec.
SAVE_SPEC = SaveSpec(num_samples=800)
# Pulse shapes offered in the "Pulse shape" selectbox.
AVAILABLE_SHAPES = [
    Sine(),
    Sawtooth(),
    Triangle(),
    Quadratic(),
    NegativeQuadratic(),
    Asymmetrical(),
    SlantedSine(),
    Square(),
    TimeDomainSquare(),
    TimeDomainSawtooth(),
    TimeDomainTriangle(),
]
# Lookup from each shape's ``name`` attribute to the shape instance; the UI stores the name.
SHAPE_MAP = {shape.name: shape for shape in AVAILABLE_SHAPES}
# Slider bounds in UI units (kHz, kPa, μm), used by PARAM_SPECS below.
MIN_FREQ_KHZ = 100.0
MAX_FREQ_KHZ = 1500.0
PRESSURE_LIMIT_KPA = 500.0
RADIUS_MIN_UM = 1.0
RADIUS_MAX_UM = 12.0
# NOTE(review): TIME_MAX_US and RADIUS_AXIS_MAX_UM are not referenced in this file —
# possibly leftovers from a time-trace view; confirm before removing.
TIME_MAX_US = 25.0
RADIUS_AXIS_MAX_UM = 20.0
# Initial widget values (seeded into st.session_state by init_session_state).
# UI units: radius in μm, freq in kHz, pressure in kPa — these three are scaled to
# SI (×1e-6, ×1e3, ×1e3) when the Bubble/Pulse objects are built; the remaining
# physical constants are passed through unchanged.
# "pulse_shape" must be a key of SHAPE_MAP (the selectbox looks up its index).
DEFAULTS: Dict[str, Any] = {
    "pulse_shape": "sine",
    "apply_hann": False,
    "freq": 750.0,
    "pressure": 200.0,
    "radius": 3.0,
    "cycles": 5,
    "r_buckle_fraction": 0.99,
    "gamma": 1.07,
    "chi": 0.38,
    "mu_L": 0.00089,
    "kappa_s": 2.4e-9,
    "rho_L": 1000.0,
    "c_L": 1498.0,
    "p_amb": 101300.0,
    "sigma_L": 0.072,
    "vdw_divisor": 5.61,
}
# Metadata for every sweepable parameter: display label and slider min/max/step.
# The keys of this dict are also the choices offered for the X and Y axes.
# NOTE(review): the "fmt" entries are not read anywhere in this file.
PARAM_SPECS: Dict[str, Dict[str, Any]] = {
    "radius": {"label": "Equilibrium radius (μm)", "min": RADIUS_MIN_UM, "max": RADIUS_MAX_UM, "step": 0.1, "fmt": "%.2f"},
    "freq": {"label": "Frequency (kHz)", "min": MIN_FREQ_KHZ, "max": MAX_FREQ_KHZ, "step": 10.0, "fmt": "%.1f"},
    "pressure": {"label": "Pressure amplitude (kPa)", "min": 0.0, "max": PRESSURE_LIMIT_KPA, "step": 10.0, "fmt": "%.1f"},
    "cycles": {"label": "Pulse cycles", "min": 1, "max": 12, "step": 1, "fmt": "%d"},
    "r_buckle_fraction": {"label": "R_buckle fraction", "min": 0.5, "max": 1.1, "step": 0.01, "fmt": "%.3f"},
    "gamma": {"label": "Polytropic index (gamma)", "min": 1.0, "max": 1.5, "step": 0.01, "fmt": "%.3f"},
    "chi": {"label": "Shell elasticity (chi) [N/m]", "min": 0.0, "max": 1.0, "step": 0.01, "fmt": "%.3f"},
    "mu_L": {"label": "Liquid viscosity (mu_L) [Pa.s]", "min": 0.0001, "max": 0.005, "step": 0.0001, "fmt": "%.5f"},
    "kappa_s": {"label": "Shell viscosity (kappa_s) [kg/s]", "min": 1e-10, "max": 5e-8, "step": 1e-10, "fmt": "%.1e"},
    "rho_L": {"label": "Liquid density (rho_L) [kg/m^3]", "min": 900.0, "max": 1100.0, "step": 5.0, "fmt": "%.1f"},
    "c_L": {"label": "Speed of sound (c_L) [m/s]", "min": 1400.0, "max": 1600.0, "step": 5.0, "fmt": "%.1f"},
    "p_amb": {"label": "Ambient pressure (P_amb) [Pa]", "min": 80000.0, "max": 140000.0, "step": 500.0, "fmt": "%.1f"},
    "sigma_L": {"label": "Surface tension (sigma_L) [N/m]", "min": 0.01, "max": 0.1, "step": 0.001, "fmt": "%.4f"},
    "vdw_divisor": {"label": "Van der Waals divisor", "min": 3.0, "max": 8.0, "step": 0.1, "fmt": "%.2f"},
}
# Type alias for numeric sequences or arrays. NOTE(review): not referenced in this file.
ArrayLike = Union[Sequence[float], np.ndarray]
def render_pulse_preview(shape_name: str, apply_hann: bool, cycles: int = 3) -> go.Figure:
    """Return a compact, axis-free Plotly line plot previewing one pulse shape.

    The shape is evaluated with frequency 1, zero phase and zero initial time,
    so the x-axis runs from 0 to ``cycles`` in normalised time units.
    NOTE(review): this equals "number of drive periods" only if the shape's
    period is 1/freq — confirm in jbubble.

    Args:
        shape_name: Key into SHAPE_MAP selecting the waveform.
        apply_hann: If True, multiply the waveform by a sin^2 (Hann) envelope
            that spans the whole preview window.
        cycles: Length of the preview window in normalised time units.

    Returns:
        A 70 px tall figure with hidden tick labels and a visible zero line.
    """
    waveform_fn = SHAPE_MAP[shape_name]
    times = np.linspace(0, cycles, 500)

    # Sample the shape one point at a time through its
    # __call__(t, freq, phase, initial_time) interface.
    values = np.empty_like(times)
    for idx, t_k in enumerate(times):
        values[idx] = float(waveform_fn(t_k, 1.0, 0.0, 0.0))

    if apply_hann:
        values *= np.sin(np.pi * times / cycles) ** 2

    trace = go.Scatter(
        x=times,
        y=values,
        mode="lines",
        line=dict(color="#45FFE9", width=2),
        showlegend=False,
    )
    preview = go.Figure(data=[trace])
    preview.update_layout(
        template="plotly_white",
        height=70,
        margin=dict(l=5, r=5, t=5, b=5),
        xaxis=dict(showticklabels=False, showgrid=False, zeroline=False),
        yaxis=dict(showticklabels=False, showgrid=False, zeroline=True, zerolinecolor="#666", zerolinewidth=1),
    )
    return preview
def init_session_state() -> None:
    """Seed ``st.session_state`` with first-run values, leaving existing keys alone.

    Adds ``"sweep_store"`` (None until a sweep has been run) followed by every
    entry of DEFAULTS; the keyed widgets then pick these up as initial values.
    """
    seed_values = {"sweep_store": None, **DEFAULTS}
    for name, initial in seed_values.items():
        if name not in st.session_state:
            st.session_state[name] = initial
def reset_bubble_defaults() -> None:
    """Restore the bubble inputs (radius and material constants) to DEFAULTS.

    Registered as the ``on_click`` callback of the Bubble panel's
    "Reset to Defaults" button. Pulse settings are not touched.
    """
    bubble_fields = (
        "radius",
        "r_buckle_fraction",
        "gamma",
        "chi",
        "mu_L",
        "kappa_s",
        "rho_L",
        "c_L",
        "p_amb",
        "sigma_L",
        "vdw_divisor",
    )
    for field in bubble_fields:
        st.session_state[field] = DEFAULTS[field]
@st.cache_resource(show_spinner=False)
def _jitted_simulator():
    """Return ``run_simulation`` wrapped in ``jax.jit``.

    ``st.cache_resource`` keeps a single wrapper alive across Streamlit reruns,
    so its compiled code is reused rather than rebuilt on every interaction.
    """
    compiled_sim = jax.jit(run_simulation)
    return compiled_sim
def run_single_sim(params: Dict[str, Any]):
    """Simulate a single bubble/pulse configuration.

    Args:
        params: Parameter dict in UI units, keyed like DEFAULTS. ``radius`` is
            in μm, ``freq`` in kHz and ``pressure`` in kPa. These are scaled to
            m, Hz and Pa here; all other values are passed through unchanged.

    Returns:
        ``(result, arrays, expansion_ratio)``: the raw ``run_simulation``
        result, its ``arrays_from_result`` view, and the peak radius divided
        by the equilibrium radius (both in μm).
    """
    r0_m = params["radius"] * 1e-6

    bubble_kwargs = dict(
        R0=r0_m,
        R_buckle=params["r_buckle_fraction"] * r0_m,
        gamma=params["gamma"],
        chi=params["chi"],
        mu_L=params["mu_L"],
        kappa_s=params["kappa_s"],
        rho_L=params["rho_L"],
        c_L=params["c_L"],
        P_amb=params["p_amb"],
        sigma_L=params["sigma_L"],
        vdw_divisor=params["vdw_divisor"],
    )
    pulse_kwargs = dict(
        shape=SHAPE_MAP[params["pulse_shape"]],
        freq=params["freq"] * 1e3,
        pressure=params["pressure"] * 1e3,
        cycle_num=int(params["cycles"]),
        initial_time=1e-6,
        apply_hann=bool(params["apply_hann"]),
    )

    simulate = _jitted_simulator()
    result = simulate(
        bubble=Bubble(**bubble_kwargs),
        pulse=Pulse(**pulse_kwargs),
        units=UNITS,
        save_spec=SAVE_SPEC,
    )

    arrays = arrays_from_result(result)
    peak_radius_um = np.max(arrays.radius_um)
    return result, arrays, float(peak_radius_um / params["radius"])
def render_axis_controls(axis_key: str, key_prefix: str) -> Tuple[float, float]:
    """Draw a range slider for one sweep axis and return the chosen ``(low, high)``.

    The slider covers, and initially selects, the parameter's full PARAM_SPECS
    range. Its widget key is ``"<key_prefix>_<axis_key>"``, so each
    prefix/parameter pair has its own slider state.
    """
    spec = PARAM_SPECS[axis_key]
    lower = float(spec["min"])
    upper = float(spec["max"])
    return st.slider(
        f"{spec['label']}",
        min_value=lower,
        max_value=upper,
        value=(lower, upper),
        step=float(spec["step"]),
        key=f"{key_prefix}_{axis_key}",
    )
def _build_batched_inputs(x_axis: str, y_axis: str, x_values: np.ndarray, y_values: np.ndarray, base_params: Dict[str, Any]):
"""Build batched JAX arrays for all parameter combinations."""
import jax.numpy as jnp
# Create meshgrid of x and y values
x_mesh, y_mesh = np.meshgrid(x_values, y_values)
x_flat = x_mesh.flatten()
y_flat = y_mesh.flatten()
n_sims = len(x_flat)
# Build arrays for each parameter, broadcasting base values or using sweep values
param_arrays = {}
for key in ["radius", "r_buckle_fraction", "gamma", "chi", "mu_L",
"kappa_s", "rho_L", "c_L", "p_amb", "sigma_L", "vdw_divisor",
"freq", "pressure", "cycles"]:
if key == x_axis:
param_arrays[key] = jnp.array(x_flat)
elif key == y_axis:
param_arrays[key] = jnp.array(y_flat)
else:
param_arrays[key] = jnp.full(n_sims, base_params[key])
return param_arrays, x_mesh.shape
@st.cache_resource(show_spinner=False)
def _batched_simulator(shape_name: str, apply_hann: bool, cycles_value: int):
    """Build, once per static configuration, a jit-compiled vmapped simulator.

    The pulse shape, Hann flag and cycle count are closed over as Python
    constants, so each combination needs its own traced function. Caching the
    ``jax.jit`` wrapper with ``st.cache_resource`` (as ``_jitted_simulator``
    does) lets repeated sweeps with the same static settings reuse compiled
    code instead of re-tracing on every Run.

    Args:
        shape_name: Key into SHAPE_MAP.
        apply_hann: Whether the pulse is Hann-windowed.
        cycles_value: Static pulse cycle count.

    Returns:
        A function taking 13 equal-length 1D arrays, in this order: R0 [m],
        r_buckle_frac, gamma, chi, mu_L, kappa_s, rho_L, c_L, p_amb, sigma_L,
        vdw_divisor, freq [Hz], pressure [Pa]. It returns the batched
        ``run_simulation`` result.
    """
    shape = SHAPE_MAP[shape_name]

    def single_sim(R0, r_buckle_frac, gamma, chi, mu_L, kappa_s, rho_L, c_L, p_amb, sigma_L, vdw_divisor, freq, pressure):
        bubble = Bubble(
            R0=R0,
            R_buckle=r_buckle_frac * R0,
            gamma=gamma,
            chi=chi,
            mu_L=mu_L,
            kappa_s=kappa_s,
            rho_L=rho_L,
            c_L=c_L,
            P_amb=p_amb,
            sigma_L=sigma_L,
            vdw_divisor=vdw_divisor,
        )
        pulse = Pulse(
            shape=shape,
            freq=freq,
            pressure=pressure,
            cycle_num=cycles_value,  # static value from the closure; cannot be vmapped
            initial_time=1e-6,
            apply_hann=apply_hann,
        )
        # Return the raw result; arrays_from_result is not called inside vmap.
        return run_simulation(
            bubble=bubble,
            pulse=pulse,
            units=UNITS,
            save_spec=SAVE_SPEC,
        )

    return jax.jit(jax.vmap(single_sim))


def _run_batched_sims(param_arrays: Dict[str, Any], base_params: Dict[str, Any], cycles_value: int, chunk_size: int = 0, progress_callback=None):
    """Run a batch of simulations with vmap and return Rmax/R0 for each one.

    Args:
        param_arrays: Equal-length 1D parameter arrays in UI units (radius in
            μm, freq in kHz, pressure in kPa), as built by _build_batched_inputs.
            The "cycles" entry, if present, is ignored.
        base_params: Base parameter dict; only "pulse_shape" and "apply_hann"
            are read.
        cycles_value: Static cycle count shared by every simulation in the batch.
        chunk_size: If > 0, run in chunks of this many simulations, with one
            progress update per chunk. If <= 0, run everything as a single batch.
        progress_callback: Optional ``callback(completed, total)``, called after
            every chunk (including the single chunk when ``chunk_size <= 0``).

    Returns:
        1D NumPy array of max-radius / R0 ratios, one per simulation, in input
        order. It is empty when there are no simulations.
    """
    n_sims = len(param_arrays["radius"])
    if n_sims == 0:
        return np.empty(0)

    batched_sim = _batched_simulator(
        base_params["pulse_shape"], bool(base_params["apply_hann"]), int(cycles_value)
    )

    # Positional inputs for the batched simulator; radius, frequency and
    # pressure are converted from μm / kHz / kPa to m / Hz / Pa.
    sim_inputs = (
        param_arrays["radius"] * 1e-6,
        param_arrays["r_buckle_fraction"],
        param_arrays["gamma"],
        param_arrays["chi"],
        param_arrays["mu_L"],
        param_arrays["kappa_s"],
        param_arrays["rho_L"],
        param_arrays["c_L"],
        param_arrays["p_amb"],
        param_arrays["sigma_L"],
        param_arrays["vdw_divisor"],
        param_arrays["freq"] * 1e3,
        param_arrays["pressure"] * 1e3,
    )

    step = n_sims if chunk_size <= 0 else min(chunk_size, n_sims)
    chunk_ratios = []
    for start in range(0, n_sims, step):
        stop = min(start + step, n_sims)
        results = batched_sim(*(arr[start:stop] for arr in sim_inputs))
        # Converting to NumPy blocks until this chunk has finished on the
        # device, so the progress report reflects completed work.
        chunk_ratios.append(np.asarray(results.radius.max(axis=-1) / results.bubble.R0))
        if progress_callback:
            progress_callback(stop, n_sims)
    return np.concatenate(chunk_ratios)
def sweep_grid(x_axis: str, y_axis: str, x_values: np.ndarray, y_values: np.ndarray, base_params: Dict[str, Any]):
total = len(x_values) * len(y_values)
progress = st.progress(0.0)
status = st.empty()
# Build batched inputs
param_arrays, grid_shape = _build_batched_inputs(x_axis, y_axis, x_values, y_values, base_params)
# Get cycles as a static int value (can't be vmapped over)
cycles_value = int(base_params["cycles"])
# Use chunks for progress bar - each chunk is a row of the grid
# This gives good granularity while keeping vmap efficiency
chunk_size = len(x_values) # One row at a time
def update_progress(completed, total_sims):
progress.progress(completed / total_sims)
status.write(f"Running simulations... {completed}/{total_sims}")
status.write(f"Compiling simulation (first run may be slow)...")
# Run all simulations in parallel with vmap, with progress updates
expansion_ratios = _run_batched_sims(
param_arrays, base_params, cycles_value,
chunk_size=chunk_size,
progress_callback=update_progress
)
# Reshape back to grid
grid = expansion_ratios.reshape(grid_shape)
status.empty()
progress.empty()
return grid
# Seed session-state defaults before any keyed widget is created.
init_session_state()
st.set_page_config(page_title="jbubble sweep", layout="wide")
# Every parameter in PARAM_SPECS may be chosen as a sweep axis.
axis_options = list(PARAM_SPECS.keys())
# Sidebar: run button, axis selection and ranges, fixed pulse/bubble parameters
# and plot orientation options.
with st.sidebar:
    st.image("sweep.svg", width='stretch')
    sweep_button = st.button("Run", type="primary", width='stretch')
    col_xy_A = st.columns(2)
    with col_xy_A[0]:
        x_axis = st.selectbox("X axis", axis_options, index=axis_options.index("radius"))
    with col_xy_A[1]:
        y_axis = st.selectbox("Y axis", axis_options, index=axis_options.index("freq"))
    if x_axis == y_axis:
        # Identical axes: show an error; the Run branch is skipped (it requires x_axis != y_axis).
        st.error("Choose two different axes to sweep.")
        x_range = render_axis_controls(x_axis, "x_range")
        y_range = x_range  # Placeholder, won't be used
    else:
        x_range = render_axis_controls(x_axis, "x_range")
        y_range = render_axis_controls(y_axis, "y_range")
    # Number of grid points along each axis.
    col_xy_B = st.columns(2)
    with col_xy_B[0]:
        x_points = st.slider("X resolution", 5, 200, 50, step=1)
    with col_xy_B[1]:
        y_points = st.slider("Y resolution", 5, 200, 50, step=1)
    st.markdown("---")
    # Values for the non-swept parameters, collected from the widgets below.
    param_inputs: Dict[str, Any] = {}
    # Pulse parameters
    pulse_params = ["freq", "pressure", "cycles"]
    with st.expander("Pulse", expanded=False):
        pulse_shape = st.selectbox("Pulse shape", list(SHAPE_MAP.keys()), index=list(SHAPE_MAP.keys()).index(DEFAULTS["pulse_shape"]))
        apply_hann = st.checkbox("Hann window", value=DEFAULTS["apply_hann"])
        param_inputs["pulse_shape"] = pulse_shape
        param_inputs["apply_hann"] = apply_hann
        # Pulse shape preview
        assert pulse_shape is not None
        # While cycles is being swept there is no single value, so preview a fixed 3 cycles;
        # otherwise read the current cycles value from session state.
        if "cycles" in (x_axis, y_axis):
            preview_cycles = 3
        else:
            preview_cycles = int(st.session_state.get("cycles", DEFAULTS["cycles"]))
        st.plotly_chart(render_pulse_preview(pulse_shape, apply_hann, preview_cycles), width='stretch', config={"displayModeBar": False})
        for key in pulse_params:
            # A swept pulse parameter gets no slider; its base value is taken from
            # DEFAULTS when the sweep runs.
            if key in (x_axis, y_axis):
                continue
            spec = PARAM_SPECS[key]
            default_val = DEFAULTS[key]
            min_v = spec["min"]
            max_v = spec["max"]
            step = spec["step"]
            # Integer defaults (e.g. cycles) get an integer slider; the value itself
            # comes from session state via the widget key.
            if isinstance(default_val, int):
                val = st.slider(spec["label"], min_value=int(min_v), max_value=int(max_v), step=int(step), key=key)
            else:
                val = st.slider(spec["label"], min_value=float(min_v), max_value=float(max_v), step=float(step), key=key)
            param_inputs[key] = val
    # Bubble parameters
    # Each input is bound to session state by its key; swept parameters are shown
    # but disabled, and their values are replaced by DEFAULTS when the sweep runs.
    with st.expander("Bubble", expanded=False):
        st.button("Reset to Defaults", on_click=reset_bubble_defaults)
        param_inputs["radius"] = st.number_input("Equilibrium radius (μm)", format="%.2f", step=0.1, key="radius", disabled="radius" in (x_axis, y_axis))
        param_inputs["r_buckle_fraction"] = st.number_input("R_buckle fraction", format="%.4f", step=0.01, key="r_buckle_fraction", disabled="r_buckle_fraction" in (x_axis, y_axis))
        param_inputs["gamma"] = st.number_input("Polytropic index (gamma)", format="%.4f", step=0.01, key="gamma", disabled="gamma" in (x_axis, y_axis))
        param_inputs["chi"] = st.number_input("Shell elasticity (chi) [N/m]", format="%.4f", step=0.01, key="chi", disabled="chi" in (x_axis, y_axis))
        param_inputs["mu_L"] = st.number_input("Liquid viscosity (mu_L) [Pa.s]", format="%.6f", step=0.00001, key="mu_L", disabled="mu_L" in (x_axis, y_axis))
        param_inputs["kappa_s"] = st.number_input("Shell viscosity (kappa_s) [kg/s]", format="%.3e", step=1e-10, key="kappa_s", disabled="kappa_s" in (x_axis, y_axis))
        param_inputs["rho_L"] = st.number_input("Liquid density (rho_L) [kg/m^3]", format="%.1f", step=10.0, key="rho_L", disabled="rho_L" in (x_axis, y_axis))
        param_inputs["c_L"] = st.number_input("Speed of sound (c_L) [m/s]", format="%.1f", step=10.0, key="c_L", disabled="c_L" in (x_axis, y_axis))
        param_inputs["p_amb"] = st.number_input("Ambient pressure (P_amb) [Pa]", format="%.1f", step=100.0, key="p_amb", disabled="p_amb" in (x_axis, y_axis))
        param_inputs["sigma_L"] = st.number_input("Surface tension (sigma_L) [N/m]", format="%.4f", step=0.001, key="sigma_L", disabled="sigma_L" in (x_axis, y_axis))
        param_inputs["vdw_divisor"] = st.number_input("Van der Waals divisor", format="%.2f", step=0.1, key="vdw_divisor", disabled="vdw_divisor" in (x_axis, y_axis))
    st.markdown("---")
    # Axis orientation of the heatmap; Y is inverted by default.
    invert_cols = st.columns(2)
    with invert_cols[0]:
        invert_x_axis = st.checkbox("Invert X", value=False)
    with invert_cols[1]:
        invert_y_axis = st.checkbox("Invert Y", value=True)
# Reserve a slot in the main area for the heatmap; it is filled in once a
# stored sweep result is available.
placeholder_heatmap = st.empty()
if x_axis != y_axis and sweep_button:
    x_values = np.linspace(*x_range, num=x_points)
    y_values = np.linspace(*y_range, num=y_points)
    # The swept parameters' base values are set to their DEFAULTS entries;
    # the actual per-simulation values come from the grid.
    base_params = {
        **param_inputs,
        x_axis: DEFAULTS.get(x_axis, x_range[0]),
        y_axis: DEFAULTS.get(y_axis, y_range[0]),
    }
    grid = sweep_grid(x_axis, y_axis, x_values, y_values, base_params)
    # Persist the result together with the settings that produced it, so
    # later reruns can redraw it and detect stale settings.
    st.session_state["sweep_store"] = dict(
        x_axis=x_axis,
        y_axis=y_axis,
        x_values=x_values,
        y_values=y_values,
        x_range=x_range,
        y_range=y_range,
        x_points=x_points,
        y_points=y_points,
        grid=grid,
        base_params=base_params,
    )
store = st.session_state.get("sweep_store")
if not store:
    placeholder_heatmap.info("Configure axes and press Run to compute the heatmap.")
else:
    # Rebuild the base-parameter dict the same way the Run branch does, so it
    # can be compared with the one that produced the stored grid.
    current_params = {
        **param_inputs,
        x_axis: DEFAULTS.get(x_axis, x_range[0]),
        y_axis: DEFAULTS.get(y_axis, y_range[0]),
    }
    setting_changes = [
        store["x_axis"] != x_axis,
        store["y_axis"] != y_axis,
        store["x_range"] != x_range,
        store["y_range"] != y_range,
        store["x_points"] != x_points,
        store["y_points"] != y_points,
        store["base_params"] != current_params,
    ]
    is_stale = any(setting_changes)

    # The colorscale stays Viridis even when stale; staleness is reported by
    # the warning shown under the chart instead.
    heatmap = go.Heatmap(
        x=store["x_values"],
        y=store["y_values"],
        z=store["grid"],
        colorscale="Viridis",
        colorbar=dict(title="Max expansion Rmax/R0"),
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        template="plotly_white",
        height=800,
        width=450,
        margin=dict(l=60, r=10, t=30, b=40),
        xaxis_title=PARAM_SPECS[store["x_axis"]]["label"],
        yaxis_title=PARAM_SPECS[store["y_axis"]]["label"],
    )
    if invert_x_axis:
        fig.update_xaxes(autorange="reversed")
    if invert_y_axis:
        fig.update_yaxes(autorange="reversed")
    placeholder_heatmap.plotly_chart(fig, width='stretch')
    if is_stale:
        st.warning("⚠️ Parameters changed — press **Run** to update")