Spaces:
Sleeping
Sleeping
| """Streamlit dashboard for 2D parameter sweeps with jbubble. | |
| This app lets you pick two parameters to sweep, control ranges/resolution, | |
| and view the resulting heatmap. | |
| """ | |
| from typing import Sequence, Union, Dict, Any, List, Tuple | |
| import itertools | |
| import jax | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import streamlit as st | |
| from jbubble import ( | |
| Units, | |
| SaveSpec, | |
| arrays_from_result, | |
| run_simulation, | |
| Bubble, | |
| Pulse, | |
| Sine, | |
| Sawtooth, | |
| Triangle, | |
| Quadratic, | |
| NegativeQuadratic, | |
| Asymmetrical, | |
| SlantedSine, | |
| Square, | |
| TimeDomainSquare, | |
| TimeDomainSawtooth, | |
| TimeDomainTriangle, | |
| ) | |
# Shared simulation settings: default unit system and an output sampling spec
# with num_samples=800.
UNITS = Units()
SAVE_SPEC = SaveSpec(num_samples=800)

# One instance of every pulse shape offered in the UI, in display order.
AVAILABLE_SHAPES = [
    shape_cls()
    for shape_cls in (
        Sine,
        Sawtooth,
        Triangle,
        Quadratic,
        NegativeQuadratic,
        Asymmetrical,
        SlantedSine,
        Square,
        TimeDomainSquare,
        TimeDomainSawtooth,
        TimeDomainTriangle,
    )
]
# Lookup from each shape's ``name`` attribute to its instance.
SHAPE_MAP = {candidate.name: candidate for candidate in AVAILABLE_SHAPES}
# Sweep bounds in UI units (kHz, kPa, micrometres).
MIN_FREQ_KHZ = 100.0
MAX_FREQ_KHZ = 1500.0
PRESSURE_LIMIT_KPA = 500.0
RADIUS_MIN_UM = 1.0
RADIUS_MAX_UM = 12.0
# NOTE(review): the next two constants are not referenced in the visible code.
TIME_MAX_US = 25.0
RADIUS_AXIS_MAX_UM = 20.0

# Initial value for every control; also seeds st.session_state.
DEFAULTS: Dict[str, Any] = dict(
    pulse_shape="sine",
    apply_hann=False,
    freq=750.0,
    pressure=200.0,
    radius=3.0,
    cycles=5,
    r_buckle_fraction=0.99,
    gamma=1.07,
    chi=0.38,
    mu_L=0.00089,
    kappa_s=2.4e-9,
    rho_L=1000.0,
    c_L=1498.0,
    p_amb=101300.0,
    sigma_L=0.072,
    vdw_divisor=5.61,
)

# Field names of each sweepable-parameter spec, in the order used below.
_SPEC_FIELDS = ("label", "min", "max", "step", "fmt")

# Sweepable parameters: display label, slider range, step and printf format.
# Integer entries ("cycles") keep int types on purpose.
PARAM_SPECS: Dict[str, Dict[str, Any]] = {
    name: dict(zip(_SPEC_FIELDS, fields))
    for name, *fields in (
        ("radius", "Equilibrium radius (μm)", RADIUS_MIN_UM, RADIUS_MAX_UM, 0.1, "%.2f"),
        ("freq", "Frequency (kHz)", MIN_FREQ_KHZ, MAX_FREQ_KHZ, 10.0, "%.1f"),
        ("pressure", "Pressure amplitude (kPa)", 0.0, PRESSURE_LIMIT_KPA, 10.0, "%.1f"),
        ("cycles", "Pulse cycles", 1, 12, 1, "%d"),
        ("r_buckle_fraction", "R_buckle fraction", 0.5, 1.1, 0.01, "%.3f"),
        ("gamma", "Polytropic index (gamma)", 1.0, 1.5, 0.01, "%.3f"),
        ("chi", "Shell elasticity (chi) [N/m]", 0.0, 1.0, 0.01, "%.3f"),
        ("mu_L", "Liquid viscosity (mu_L) [Pa.s]", 0.0001, 0.005, 0.0001, "%.5f"),
        ("kappa_s", "Shell viscosity (kappa_s) [kg/s]", 1e-10, 5e-8, 1e-10, "%.1e"),
        ("rho_L", "Liquid density (rho_L) [kg/m^3]", 900.0, 1100.0, 5.0, "%.1f"),
        ("c_L", "Speed of sound (c_L) [m/s]", 1400.0, 1600.0, 5.0, "%.1f"),
        ("p_amb", "Ambient pressure (P_amb) [Pa]", 80000.0, 140000.0, 500.0, "%.1f"),
        ("sigma_L", "Surface tension (sigma_L) [N/m]", 0.01, 0.1, 0.001, "%.4f"),
        ("vdw_divisor", "Van der Waals divisor", 3.0, 8.0, 0.1, "%.2f"),
    )
}

ArrayLike = Union[Sequence[float], np.ndarray]
def render_pulse_preview(shape_name: str, apply_hann: bool, cycles: int = 3) -> go.Figure:
    """Return a small, axis-free Plotly line chart previewing one pulse shape.

    The shape is sampled at 500 points over ``[0, cycles]`` in normalized time
    (frequency 1, zero phase, zero start time). When ``apply_hann`` is set,
    the samples are multiplied by ``sin(pi * t / cycles) ** 2`` across the
    whole window.
    """
    pulse_shape = SHAPE_MAP[shape_name]
    unit_freq, zero_phase, start_time = 1.0, 0.0, 0.0
    sample_times = np.linspace(0, cycles, 500)

    # The shape is evaluated one scalar time at a time through its __call__.
    samples = []
    for t_i in sample_times:
        samples.append(float(pulse_shape(t_i, unit_freq, zero_phase, start_time)))
    waveform = np.array(samples)

    if apply_hann:
        waveform = waveform * (np.sin(np.pi * sample_times / cycles) ** 2)

    figure = go.Figure()
    figure.add_trace(
        go.Scatter(
            x=sample_times,
            y=waveform,
            mode="lines",
            line={"color": "#45FFE9", "width": 2},
            showlegend=False,
        )
    )
    figure.update_layout(
        template="plotly_white",
        height=70,
        margin={"l": 5, "r": 5, "t": 5, "b": 5},
        xaxis={"showticklabels": False, "showgrid": False, "zeroline": False},
        yaxis={
            "showticklabels": False,
            "showgrid": False,
            "zeroline": True,
            "zerolinecolor": "#666",
            "zerolinewidth": 1,
        },
    )
    return figure
def init_session_state() -> None:
    """Seed ``st.session_state`` once: an empty sweep cache plus every default.

    Keys that already exist (e.g. values the user changed) are left alone.
    """
    state = st.session_state
    seeds = {"sweep_store": None, **DEFAULTS}
    for name, initial in seeds.items():
        if name not in state:
            state[name] = initial
def reset_bubble_defaults() -> None:
    """Restore every bubble parameter in ``st.session_state`` to ``DEFAULTS``.

    Used as the ``on_click`` callback of the "Reset to Defaults" button.
    """
    state = st.session_state
    for name in (
        "radius", "r_buckle_fraction", "gamma", "chi", "mu_L",
        "kappa_s", "rho_L", "c_L", "p_amb", "sigma_L", "vdw_divisor",
    ):
        state[name] = DEFAULTS[name]
@st.cache_resource(show_spinner=False)
def _jitted_simulator():
    """Return ``jax.jit(run_simulation)``, shared across calls and reruns.

    ``jax.jit`` returns a new wrapper each time it is called, and Streamlit
    re-executes this script on every interaction. Caching the wrapper with
    ``st.cache_resource`` keeps one instance alive, together with the compiled
    executables it holds, instead of rebuilding it for every simulation.
    """
    return jax.jit(run_simulation)
def run_single_sim(params: Dict[str, Any]):
    """Simulate a single bubble/pulse configuration via ``_jitted_simulator()``.

    Args:
        params: Parameters in UI units, keyed like ``DEFAULTS``: radius in μm,
            freq in kHz, pressure in kPa, plus pulse_shape, apply_hann, cycles
            and the bubble material constants.

    Returns:
        ``(result, arrays, expansion_ratio)``. ``arrays`` comes from
        ``arrays_from_result(result)``. ``expansion_ratio`` is
        ``max(arrays.radius_um) / params["radius"]`` as a Python float.
    """
    eq_radius_m = params["radius"] * 1e-6  # μm -> m
    bubble = Bubble(
        R0=eq_radius_m,
        R_buckle=params["r_buckle_fraction"] * eq_radius_m,
        gamma=params["gamma"],
        chi=params["chi"],
        mu_L=params["mu_L"],
        kappa_s=params["kappa_s"],
        rho_L=params["rho_L"],
        c_L=params["c_L"],
        P_amb=params["p_amb"],
        sigma_L=params["sigma_L"],
        vdw_divisor=params["vdw_divisor"],
    )
    pulse = Pulse(
        shape=SHAPE_MAP[params["pulse_shape"]],
        freq=params["freq"] * 1e3,  # kHz -> Hz
        pressure=params["pressure"] * 1e3,  # kPa -> Pa
        cycle_num=int(params["cycles"]),
        initial_time=1e-6,
        apply_hann=bool(params["apply_hann"]),
    )

    simulate = _jitted_simulator()
    result = simulate(bubble=bubble, pulse=pulse, units=UNITS, save_spec=SAVE_SPEC)
    arrays = arrays_from_result(result)
    peak_radius_um = np.max(arrays.radius_um)
    return result, arrays, float(peak_radius_um / params["radius"])
def render_axis_controls(axis_key: str, key_prefix: str) -> Tuple[float, float]:
    """Draw a range slider spanning the full PARAM_SPECS range of ``axis_key``.

    The slider starts fully open at ``(min, max)`` and uses the widget key
    ``f"{key_prefix}_{axis_key}"``. Returns the selected ``(low, high)`` pair.
    """
    spec = PARAM_SPECS[axis_key]
    bounds = (float(spec["min"]), float(spec["max"]))
    return st.slider(
        f"{spec['label']}",
        min_value=bounds[0],
        max_value=bounds[1],
        value=bounds,
        step=float(spec["step"]),
        key=f"{key_prefix}_{axis_key}",
    )
def _build_batched_inputs(x_axis: str, y_axis: str, x_values: np.ndarray, y_values: np.ndarray, base_params: Dict[str, Any]):
    """Flatten the (x, y) sweep grid into one JAX array per simulation parameter.

    Returns:
        ``(param_arrays, grid_shape)``. ``param_arrays`` maps each batched
        parameter name to a 1-D array with one entry per grid point, flattened
        row-major from ``np.meshgrid(x_values, y_values)``. Swept axes take
        their grid values. Every other parameter is ``base_params[name]``
        repeated to that length. ``grid_shape`` is the meshgrid shape
        ``(len(y_values), len(x_values))``. If ``x_axis == y_axis``, the x
        values win.
    """
    import jax.numpy as jnp

    grid_x, grid_y = np.meshgrid(x_values, y_values)
    # y is inserted first so that x takes precedence when both name one key.
    swept_columns = {y_axis: grid_y.flatten(), x_axis: grid_x.flatten()}
    point_count = grid_x.size
    batched_names = (
        "radius", "r_buckle_fraction", "gamma", "chi", "mu_L",
        "kappa_s", "rho_L", "c_L", "p_amb", "sigma_L", "vdw_divisor",
        "freq", "pressure", "cycles",
    )

    param_arrays = {
        name: (
            jnp.array(swept_columns[name])
            if name in swept_columns
            else jnp.full(point_count, base_params[name])
        )
        for name in batched_names
    }
    return param_arrays, grid_x.shape
@st.cache_resource(show_spinner=False)
def _batched_simulator(shape_name: str, apply_hann: bool, cycles_value: int):
    """Build a jitted, vmapped simulator for one static pulse configuration.

    The pulse shape, Hann flag and cycle count are baked into the closure
    (the cycle count cannot be vmapped over), so together they form the cache
    key. Caching lets repeated sweeps with the same static configuration reuse
    the traced and compiled function instead of jitting a fresh closure on
    every Run.
    """
    shape = SHAPE_MAP[shape_name]

    def single_sim(R0, r_buckle_frac, gamma, chi, mu_L, kappa_s, rho_L, c_L, p_amb, sigma_L, vdw_divisor, freq, pressure):
        bubble = Bubble(
            R0=R0,
            R_buckle=r_buckle_frac * R0,
            gamma=gamma,
            chi=chi,
            mu_L=mu_L,
            kappa_s=kappa_s,
            rho_L=rho_L,
            c_L=c_L,
            P_amb=p_amb,
            sigma_L=sigma_L,
            vdw_divisor=vdw_divisor,
        )
        pulse = Pulse(
            shape=shape,
            freq=freq,
            pressure=pressure,
            cycle_num=cycles_value,  # static value from the closure
            initial_time=1e-6,
            apply_hann=apply_hann,
        )
        # Return the raw result - arrays_from_result is not called inside vmap.
        return run_simulation(
            bubble=bubble,
            pulse=pulse,
            units=UNITS,
            save_spec=SAVE_SPEC,
        )

    return jax.jit(jax.vmap(single_sim))


def _run_batched_sims(param_arrays: Dict[str, Any], base_params: Dict[str, Any], cycles_value: int, chunk_size: int = 0, progress_callback=None):
    """Run simulations in a vectorized manner using vmap.

    Args:
        param_arrays: Per-point parameter arrays in UI units (radius in μm,
            freq in kHz, pressure in kPa), e.g. from ``_build_batched_inputs``.
        base_params: Parameter dictionary supplying the static
            ``pulse_shape`` and ``apply_hann`` values.
        cycles_value: Static cycle count shared by every point.
        chunk_size: If > 0 and smaller than the number of points, process in
            chunks so progress can be reported. Otherwise run all at once.
        progress_callback: Optional ``callback(completed, total)``. Called
            after each chunk, or once after a single-shot run.

    Returns:
        NumPy array of ``max(radius) / R0`` ratios, one per point, in input order.
    """
    batched_sim = _batched_simulator(
        base_params["pulse_shape"],
        bool(base_params["apply_hann"]),
        int(cycles_value),
    )

    # Positional order must match single_sim's signature. Unit conversions:
    # μm -> m for the radius, kHz -> Hz and kPa -> Pa for the pulse.
    batched_args = (
        param_arrays["radius"] * 1e-6,
        param_arrays["r_buckle_fraction"],
        param_arrays["gamma"],
        param_arrays["chi"],
        param_arrays["mu_L"],
        param_arrays["kappa_s"],
        param_arrays["rho_L"],
        param_arrays["c_L"],
        param_arrays["p_amb"],
        param_arrays["sigma_L"],
        param_arrays["vdw_divisor"],
        param_arrays["freq"] * 1e3,
        param_arrays["pressure"] * 1e3,
    )
    n_sims = len(batched_args[0])

    def expansion_ratios(args):
        """Simulate one batch and return its max-radius / R0 ratios."""
        results = batched_sim(*args)
        max_radii = results.radius.max(axis=-1)
        # np.array blocks until the device work is done (accurate progress).
        return np.array(max_radii / results.bubble.R0)

    # No chunking requested, or one chunk would cover everything.
    if chunk_size <= 0 or chunk_size >= n_sims:
        ratios = expansion_ratios(batched_args)
        if progress_callback:
            progress_callback(n_sims, n_sims)
        return ratios

    chunks = []
    for start in range(0, n_sims, chunk_size):
        stop = min(start + chunk_size, n_sims)
        chunks.append(expansion_ratios(tuple(arr[start:stop] for arr in batched_args)))
        if progress_callback:
            progress_callback(stop, n_sims)
    return np.concatenate(chunks)
def cycle_index_groups(cycles: Union[Sequence[float], np.ndarray]) -> List[Tuple[int, np.ndarray]]:
    """Group flat sweep-point indices by integer pulse-cycle count.

    Each value is rounded to the nearest integer and clamped to at least 1,
    the minimum of the "cycles" spec.

    Returns:
        ``(cycle_count, indices)`` pairs in ascending ``cycle_count`` order,
        where ``indices`` are positions into ``cycles``.
    """
    counts = np.maximum(np.rint(np.asarray(cycles, dtype=float)).astype(int), 1)
    return [(int(count), np.flatnonzero(counts == count)) for count in np.unique(counts)]


def sweep_grid(x_axis: str, y_axis: str, x_values: np.ndarray, y_values: np.ndarray, base_params: Dict[str, Any]):
    """Simulate every (x, y) combination and return the Rmax/R0 grid.

    Shows a Streamlit progress bar and status line while running and clears
    both when done.

    The pulse cycle count is a static argument of the batched simulator and
    cannot be vmapped. When "cycles" is a swept axis, points are grouped by
    rounded cycle count and each group is simulated with its own count.
    Otherwise every point uses ``base_params["cycles"]``.

    Returns:
        Array of shape ``(len(y_values), len(x_values))`` (meshgrid order).
    """
    progress = st.progress(0.0)
    status = st.empty()

    param_arrays, grid_shape = _build_batched_inputs(x_axis, y_axis, x_values, y_values, base_params)
    n_total = len(x_values) * len(y_values)

    if "cycles" in (x_axis, y_axis):
        groups = cycle_index_groups(np.asarray(param_arrays["cycles"]))
    else:
        groups = [(int(base_params["cycles"]), np.arange(n_total))]

    # One grid row per chunk gives a fine-grained progress bar while keeping
    # vmap efficiency.
    chunk_size = len(x_values)
    expansion_ratios = np.empty(n_total, dtype=float)
    n_done = 0

    status.write("Compiling simulation (first run may be slow)...")
    for cycles_value, indices in groups:
        if indices.size == n_total:
            group_arrays = param_arrays
        else:
            group_arrays = {key: arr[indices] for key, arr in param_arrays.items()}

        # ``offset`` is bound now so each group reports overall progress.
        def update_progress(completed, _group_total, offset=n_done):
            done = offset + completed
            progress.progress(done / n_total)
            status.write(f"Running simulations... {done}/{n_total}")

        expansion_ratios[indices] = _run_batched_sims(
            group_arrays, base_params, cycles_value,
            chunk_size=chunk_size,
            progress_callback=update_progress,
        )
        n_done += indices.size
        progress.progress(n_done / n_total)

    grid = expansion_ratios.reshape(grid_shape)
    status.empty()
    progress.empty()
    return grid
# Configure the page before anything else, then seed session state. The
# axis choices follow PARAM_SPECS insertion order.
st.set_page_config(page_title="jbubble sweep", layout="wide")
init_session_state()
axis_options = [*PARAM_SPECS]
with st.sidebar:
    st.image("sweep.svg", width='stretch')
    sweep_button = st.button("Run", type="primary", width='stretch')

    # Axis selection: X defaults to radius, Y to frequency.
    axis_cols = st.columns(2)
    with axis_cols[0]:
        x_axis = st.selectbox("X axis", axis_options, index=axis_options.index("radius"))
    with axis_cols[1]:
        y_axis = st.selectbox("Y axis", axis_options, index=axis_options.index("freq"))

    same_axis = x_axis == y_axis
    if same_axis:
        st.error("Choose two different axes to sweep.")
    x_range = render_axis_controls(x_axis, "x_range")
    # With identical axes no sweep runs, so y_range only mirrors x_range.
    y_range = x_range if same_axis else render_axis_controls(y_axis, "y_range")

    resolution_cols = st.columns(2)
    with resolution_cols[0]:
        x_points = st.slider("X resolution", 5, 200, 50, step=1)
    with resolution_cols[1]:
        y_points = st.slider("Y resolution", 5, 200, 50, step=1)

    st.markdown("---")
    param_inputs: Dict[str, Any] = {}
    swept = (x_axis, y_axis)

    # Pulse parameters
    with st.expander("Pulse", expanded=False):
        shape_names = list(SHAPE_MAP.keys())
        pulse_shape = st.selectbox("Pulse shape", shape_names, index=shape_names.index(DEFAULTS["pulse_shape"]))
        apply_hann = st.checkbox("Hann window", value=DEFAULTS["apply_hann"])
        param_inputs["pulse_shape"] = pulse_shape
        param_inputs["apply_hann"] = apply_hann

        # Pulse shape preview: 3 cycles while "cycles" is swept, otherwise
        # the current cycles value from session state.
        assert pulse_shape is not None
        preview_cycles = 3 if "cycles" in swept else int(st.session_state.get("cycles", DEFAULTS["cycles"]))
        st.plotly_chart(render_pulse_preview(pulse_shape, apply_hann, preview_cycles), width='stretch', config={"displayModeBar": False})

        # Sliders only for pulse parameters that are not being swept; int
        # defaults get an int slider, everything else a float slider.
        for name in ("freq", "pressure", "cycles"):
            if name in swept:
                continue
            spec = PARAM_SPECS[name]
            as_num = int if isinstance(DEFAULTS[name], int) else float
            param_inputs[name] = st.slider(
                spec["label"],
                min_value=as_num(spec["min"]),
                max_value=as_num(spec["max"]),
                step=as_num(spec["step"]),
                key=name,
            )

    # Bubble parameters (inputs for swept parameters are shown but disabled).
    with st.expander("Bubble", expanded=False):
        st.button("Reset to Defaults", on_click=reset_bubble_defaults)
        bubble_fields = (
            ("radius", "Equilibrium radius (μm)", "%.2f", 0.1),
            ("r_buckle_fraction", "R_buckle fraction", "%.4f", 0.01),
            ("gamma", "Polytropic index (gamma)", "%.4f", 0.01),
            ("chi", "Shell elasticity (chi) [N/m]", "%.4f", 0.01),
            ("mu_L", "Liquid viscosity (mu_L) [Pa.s]", "%.6f", 0.00001),
            ("kappa_s", "Shell viscosity (kappa_s) [kg/s]", "%.3e", 1e-10),
            ("rho_L", "Liquid density (rho_L) [kg/m^3]", "%.1f", 10.0),
            ("c_L", "Speed of sound (c_L) [m/s]", "%.1f", 10.0),
            ("p_amb", "Ambient pressure (P_amb) [Pa]", "%.1f", 100.0),
            ("sigma_L", "Surface tension (sigma_L) [N/m]", "%.4f", 0.001),
            ("vdw_divisor", "Van der Waals divisor", "%.2f", 0.1),
        )
        for name, label, fmt, step in bubble_fields:
            param_inputs[name] = st.number_input(label, format=fmt, step=step, key=name, disabled=name in swept)

    st.markdown("---")
    flip_cols = st.columns(2)
    with flip_cols[0]:
        invert_x_axis = st.checkbox("Invert X", value=False)
    with flip_cols[1]:
        invert_y_axis = st.checkbox("Invert Y", value=True)
placeholder_heatmap = st.empty()


def _sweep_axis_values(axis_key: str, value_range: Tuple[float, float], num_points: int) -> np.ndarray:
    """Return evenly spaced sample points for one sweep axis.

    Parameters whose PARAM_SPECS ``step`` is an ``int`` (currently "cycles")
    are rounded to whole numbers and de-duplicated. The heatmap axis then
    carries only usable values, so such an axis may have fewer than
    ``num_points`` samples.
    """
    values = np.linspace(value_range[0], value_range[1], num=num_points)
    if isinstance(PARAM_SPECS[axis_key]["step"], int):
        values = np.unique(np.rint(values))
    return values


if sweep_button and x_axis != y_axis:
    x_values = _sweep_axis_values(x_axis, x_range, x_points)
    y_values = _sweep_axis_values(y_axis, y_range, y_points)
    # Swept parameters are overridden per grid point. Their base value is set
    # to the default so this snapshot can be compared with one rebuilt the
    # same way from the current controls.
    base_params = dict(param_inputs)
    base_params[x_axis] = DEFAULTS.get(x_axis, x_range[0])
    base_params[y_axis] = DEFAULTS.get(y_axis, y_range[0])
    grid = sweep_grid(x_axis, y_axis, x_values, y_values, base_params)
    st.session_state["sweep_store"] = {
        "x_axis": x_axis,
        "y_axis": y_axis,
        "x_values": x_values,
        "y_values": y_values,
        "x_range": x_range,
        "y_range": y_range,
        "x_points": x_points,
        "y_points": y_points,
        "grid": grid,
        "base_params": base_params,
    }
store = st.session_state.get("sweep_store")
if store:
    # Rebuild the parameter snapshot exactly as the Run branch does, so any
    # change to the controls since the last sweep is detected.
    current_params = dict(param_inputs)
    current_params[x_axis] = DEFAULTS.get(x_axis, x_range[0])
    current_params[y_axis] = DEFAULTS.get(y_axis, y_range[0])
    is_stale = (
        store["x_axis"] != x_axis
        or store["y_axis"] != y_axis
        or store["x_range"] != x_range
        or store["y_range"] != y_range
        or store["x_points"] != x_points
        or store["y_points"] != y_points
        or store["base_params"] != current_params
    )
    # Use a grayscale colorscale when the stored result is stale.
    colorscale = "Greys" if is_stale else "Viridis"
    fig = go.Figure(
        data=go.Heatmap(
            x=store["x_values"],
            y=store["y_values"],
            z=store["grid"],
            colorscale=colorscale,
            colorbar=dict(title="Max expansion Rmax/R0"),
        )
    )
    fig.update_layout(
        template="plotly_white",
        height=800,
        width=450,
        margin=dict(l=60, r=10, t=30, b=40),
        xaxis_title=PARAM_SPECS[store["x_axis"]]["label"],
        yaxis_title=PARAM_SPECS[store["y_axis"]]["label"],
    )
    if invert_x_axis:
        fig.update_xaxes(autorange="reversed")
    if invert_y_axis:
        fig.update_yaxes(autorange="reversed")
    placeholder_heatmap.plotly_chart(fig, width='stretch')
    if is_stale:
        st.warning("⚠️ Parameters changed — press **Run** to update")
else:
    placeholder_heatmap.info("Configure axes and press Run to compute the heatmap.")