callumtilbury commited on
Commit
26074fb
·
1 Parent(s): 9f0b6fb

init commit of sweep code

Browse files
Files changed (5) hide show
  1. Dockerfile +2 -2
  2. README.md +1 -1
  3. streamlit_app.py +0 -237
  4. jbubble.svg → sweep.svg +6 -6
  5. sweep_dashboard.py +375 -0
Dockerfile CHANGED
@@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y \
13
  RUN mkdir -p -m 0700 ~/.ssh && ssh-keyscan github.com >> ~/.ssh/known_hosts
14
 
15
  COPY jbubble.svg ./
16
- COPY streamlit_app.py ./streamlit_app.py
17
 
18
  # Install jbubble from private repo using SSH secret (for Hugging Face)
19
  RUN --mount=type=secret,id=SSH_KEY,mode=0600,required=true \
@@ -26,4 +26,4 @@ EXPOSE 8501
26
 
27
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
28
 
29
- ENTRYPOINT ["streamlit", "run", "streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
13
  RUN mkdir -p -m 0700 ~/.ssh && ssh-keyscan github.com >> ~/.ssh/known_hosts
14
 
15
  COPY jbubble.svg ./
16
+ COPY sweep_dashboard.py ./sweep_dashboard.py
17
 
18
  # Install jbubble from private repo using SSH secret (for Hugging Face)
19
  RUN --mount=type=secret,id=SSH_KEY,mode=0600,required=true \
 
26
 
27
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
28
 
29
+ ENTRYPOINT ["streamlit", "run", "sweep_dashboard.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: jbubble
3
  emoji: 🫧
4
  colorFrom: blue
5
  colorTo: red
 
1
  ---
2
+ title: jbubble-sweep
3
  emoji: 🫧
4
  colorFrom: blue
5
  colorTo: red
streamlit_app.py DELETED
@@ -1,237 +0,0 @@
1
- """Simple Streamlit interface for jbubble simulations."""
2
-
3
- from typing import Sequence, Union
4
- import jax
5
- import numpy as np
6
- from plotly.subplots import make_subplots
7
- import plotly.graph_objects as go
8
- import streamlit as st
9
-
10
- from jbubble import (
11
- Units,
12
- SaveSpec,
13
- arrays_from_result,
14
- run_simulation,
15
- Bubble,
16
- Pulse,
17
- Sine,
18
- Sawtooth,
19
- Triangle,
20
- Quadratic,
21
- NegativeQuadratic,
22
- Asymmetrical,
23
- SlantedSine,
24
- Square,
25
- TimeDomainSquare,
26
- TimeDomainSawtooth,
27
- TimeDomainTriangle,
28
- )
29
-
30
-
31
- UNITS = Units()
32
- SAVE_SPEC = SaveSpec(num_samples=1000)
33
-
34
- AVAILABLE_SHAPES = [
35
- Sine(),
36
- Sawtooth(),
37
- Triangle(),
38
- Quadratic(),
39
- NegativeQuadratic(),
40
- Asymmetrical(),
41
- SlantedSine(),
42
- Square(),
43
- TimeDomainSquare(),
44
- TimeDomainSawtooth(),
45
- TimeDomainTriangle(),
46
- ]
47
- SHAPE_MAP = {shape.name: shape for shape in AVAILABLE_SHAPES}
48
-
49
- MIN_FREQ_KHZ = 250.0
50
- MAX_FREQ_KHZ = 1500.0
51
- PRESSURE_LIMIT_KPA = 500.0
52
- RADIUS_AXIS_MAX_UM = 15.0
53
- TIME_MAX_US = 15.0
54
- DEFAULTS = {
55
- "pulse_shape": "square",
56
- "freq": (MAX_FREQ_KHZ + MIN_FREQ_KHZ) / 2,
57
- "pressure": PRESSURE_LIMIT_KPA / 2,
58
- "radius": 3.0,
59
- "cycles": 3,
60
- "r_buckle_fraction": 0.99,
61
- "gamma": 1.07,
62
- "chi": 0.38,
63
- "mu_L": 0.00089,
64
- "kappa_s": 2.4e-9,
65
- "rho_L": 1000.0,
66
- "c_L": 1498.0,
67
- "p_amb": 101300.0,
68
- "sigma_L": 0.072,
69
- "vdw_divisor": 5.61,
70
- }
71
-
72
- # Initialize session state with defaults if not present
73
- for key, value in DEFAULTS.items():
74
- if key not in st.session_state:
75
- st.session_state[key] = value
76
-
77
- def reset_defaults():
78
- advanced_keys = [
79
- "r_buckle_fraction",
80
- "gamma",
81
- "chi",
82
- "mu_L",
83
- "kappa_s",
84
- "rho_L",
85
- "c_L",
86
- "p_amb",
87
- "sigma_L",
88
- "vdw_divisor",
89
- ]
90
- for key in advanced_keys:
91
- st.session_state[key] = DEFAULTS[key]
92
-
93
- @st.cache_resource(show_spinner=False)
94
- def _jitted_simulator():
95
- return jax.jit(run_simulation)
96
-
97
-
98
- def simulate(
99
- pulse_shape: str,
100
- apply_hann: bool,
101
- freq_khz: float,
102
- pressure_kpa: float,
103
- radius_um: float,
104
- cycles: int,
105
- r_buckle_fraction: float,
106
- gamma: float,
107
- chi: float,
108
- mu_L: float,
109
- kappa_s: float,
110
- rho_L: float,
111
- c_L: float,
112
- p_amb: float,
113
- sigma_L: float,
114
- vdw_divisor: float,
115
- ):
116
- R0 = radius_um * 1e-6
117
- bubble = Bubble(
118
- R0=R0,
119
- R_buckle=r_buckle_fraction * R0,
120
- gamma=gamma,
121
- chi=chi,
122
- mu_L=mu_L,
123
- kappa_s=kappa_s,
124
- rho_L=rho_L,
125
- c_L=c_L,
126
- P_amb=p_amb,
127
- sigma_L=sigma_L,
128
- vdw_divisor=vdw_divisor,
129
- )
130
- pulse = Pulse(
131
- shape=SHAPE_MAP[pulse_shape],
132
- freq=freq_khz * 1e3,
133
- pressure=pressure_kpa * 1e3,
134
- cycle_num=cycles,
135
- initial_time=1e-6,
136
- apply_hann=apply_hann,
137
- )
138
- result = _jitted_simulator()(
139
- bubble=bubble,
140
- pulse=pulse,
141
- units=UNITS,
142
- save_spec=SAVE_SPEC,
143
- )
144
- return result, arrays_from_result(result)
145
-
146
-
147
- ArrayLike = Union[Sequence[float], np.ndarray]
148
-
149
- def line_trace(x: ArrayLike, y: ArrayLike, *, name: str, color: str | None = None) -> go.Scatter:
150
- return go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=color) if color else None)
151
-
152
-
153
- def _stacked_figure(arrays, marker_idx: int | None):
154
- fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.08)
155
- fig.add_trace(line_trace(arrays.time_us, arrays.pressure_kpa, name="Driving Pressure", color="#45FFE9"), row=1, col=1)
156
- fig.add_trace(line_trace(arrays.time_us, arrays.radius_um, name="Bubble Radius", color="#FFCC33"), row=2, col=1)
157
- fig.update_xaxes(range=(0.0, TIME_MAX_US), title_text="Time (μs)", row=2, col=1)
158
- fig.update_xaxes(range=(0.0, TIME_MAX_US), row=1, col=1, showticklabels=False)
159
- fig.update_yaxes(range=(-PRESSURE_LIMIT_KPA, PRESSURE_LIMIT_KPA), title_text="Pressure (kPa)", row=1, col=1)
160
- fig.update_yaxes(range=(0.0, RADIUS_AXIS_MAX_UM), title_text="Radius (μm)", row=2, col=1)
161
- fig.update_layout(
162
- template="plotly_white",
163
- height=500,
164
- margin=dict(l=40, r=20, t=30, b=60),
165
- legend=dict(
166
- orientation="h",
167
- yanchor="bottom",
168
- y=1.02,
169
- xanchor="right",
170
- x=1.0,
171
- bgcolor="rgba(0,0,0,0)",
172
- ),
173
- )
174
- return fig
175
-
176
-
177
- st.set_page_config(page_title="jbubble", layout="wide")
178
-
179
- with st.sidebar:
180
- st.image("jbubble.svg", width='stretch')
181
- pulse_shape = st.selectbox(
182
- "Pulse shape",
183
- list(SHAPE_MAP.keys()),
184
- key="pulse_shape",
185
- )
186
- apply_hann = st.checkbox("Hann window", value=False, key="apply_hann")
187
- freq = st.slider("Frequency (kHz)", min_value=MIN_FREQ_KHZ, max_value=MAX_FREQ_KHZ, step=10.0, key="freq")
188
- pressure = st.slider("Pressure amplitude (kPa)", min_value=0.0, max_value=PRESSURE_LIMIT_KPA, step=10.0, key="pressure")
189
- radius = st.slider("Equilibrium radius (μm)", min_value=1.0, max_value=5.0, step=0.1, key="radius")
190
- cycles = st.slider("Pulse cycles", min_value=2, max_value=10, step=1, key="cycles")
191
-
192
- with st.expander("Advanced Bubble Parameters"):
193
- st.button("Reset to Defaults", on_click=reset_defaults)
194
- r_buckle_fraction = st.number_input("R_buckle fraction", format="%.4f", step=0.01, key="r_buckle_fraction")
195
- gamma = st.number_input("Polytropic index (gamma)", format="%.4f", step=0.01, key="gamma")
196
- chi = st.number_input("Shell elasticity (chi) [N/m]", format="%.4f", step=0.01, key="chi")
197
- mu_L = st.number_input("Liquid viscosity (mu_L) [Pa.s]", format="%.6f", step=0.00001, key="mu_L")
198
- kappa_s = st.number_input("Shell viscosity (kappa_s) [kg/s]", format="%.3e", step=1e-10, key="kappa_s")
199
- rho_L = st.number_input("Liquid density (rho_L) [kg/m^3]", format="%.1f", step=10.0, key="rho_L")
200
- c_L = st.number_input("Speed of sound (c_L) [m/s]", format="%.1f", step=10.0, key="c_L")
201
- p_amb = st.number_input("Ambient pressure (P_amb) [Pa]", format="%.1f", step=100.0, key="p_amb")
202
- sigma_L = st.number_input("Liquid surface tension (sigma_L) [N/m]", format="%.4f", step=0.001, key="sigma_L")
203
- vdw_divisor = st.number_input("Van der Waals divisor", format="%.2f", step=0.1, key="vdw_divisor")
204
-
205
- result, arrays = simulate(
206
- pulse_shape,
207
- apply_hann,
208
- freq,
209
- pressure,
210
- radius,
211
- cycles,
212
- r_buckle_fraction,
213
- gamma,
214
- chi,
215
- mu_L,
216
- kappa_s,
217
- rho_L,
218
- c_L,
219
- p_amb,
220
- sigma_L,
221
- vdw_divisor,
222
- )
223
-
224
- with st.sidebar:
225
- converged = bool(result.converged)
226
- st.write("")
227
- with st.status("Solver status", state="complete" if converged else "error"):
228
- st.write("✅ Converged" if converged else "🔴 Max steps!")
229
- st.caption("\nBubble dynamics simulation powered by JAX")
230
-
231
- st.plotly_chart(_stacked_figure(arrays, None), width='stretch')
232
- col1, col2, col3, col4 = st.columns(4)
233
- st.write("\n")
234
- col1.metric("Max Radius (μm)", f"{arrays.radius_um.max():.2f}")
235
- col2.metric("Min Radius (μm)", f"{arrays.radius_um.min():.2f}")
236
- col3.metric("Max Expansion Ratio", f"{(arrays.radius_um.max() / (radius)):.2f}")
237
- col4.metric("Collapse Ratio", f"{(arrays.radius_um.min() / (radius)):.2f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
jbubble.svg → sweep.svg RENAMED
File without changes
sweep_dashboard.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Streamlit dashboard for 2D parameter sweeps with jbubble.
2
+
3
+ This app lets you pick two parameters to sweep, control ranges/resolution,
4
+ and view the resulting heatmap.
5
+ """
6
+
7
+ from typing import Sequence, Union, Dict, Any, List, Tuple
8
+ import itertools
9
+
10
+ import jax
11
+ import numpy as np
12
+ import plotly.graph_objects as go
13
+ import streamlit as st
14
+
15
+ from jbubble import (
16
+ Units,
17
+ SaveSpec,
18
+ arrays_from_result,
19
+ run_simulation,
20
+ Bubble,
21
+ Pulse,
22
+ Sine,
23
+ Sawtooth,
24
+ Triangle,
25
+ Quadratic,
26
+ NegativeQuadratic,
27
+ Asymmetrical,
28
+ SlantedSine,
29
+ Square,
30
+ TimeDomainSquare,
31
+ TimeDomainSawtooth,
32
+ TimeDomainTriangle,
33
+ )
34
+
35
+ UNITS = Units()
36
+ SAVE_SPEC = SaveSpec(num_samples=800)
37
+
38
+ AVAILABLE_SHAPES = [
39
+ Sine(),
40
+ Sawtooth(),
41
+ Triangle(),
42
+ Quadratic(),
43
+ NegativeQuadratic(),
44
+ Asymmetrical(),
45
+ SlantedSine(),
46
+ Square(),
47
+ TimeDomainSquare(),
48
+ TimeDomainSawtooth(),
49
+ TimeDomainTriangle(),
50
+ ]
51
+ SHAPE_MAP = {shape.name: shape for shape in AVAILABLE_SHAPES}
52
+
53
+ MIN_FREQ_KHZ = 100.0
54
+ MAX_FREQ_KHZ = 1500.0
55
+ PRESSURE_LIMIT_KPA = 500.0
56
+ RADIUS_MIN_UM = 1.0
57
+ RADIUS_MAX_UM = 12.0
58
+ TIME_MAX_US = 25.0
59
+ RADIUS_AXIS_MAX_UM = 20.0
60
+
61
+ DEFAULTS: Dict[str, Any] = {
62
+ "pulse_shape": "sine",
63
+ "apply_hann": False,
64
+ "freq": 750.0,
65
+ "pressure": 200.0,
66
+ "radius": 3.0,
67
+ "cycles": 5,
68
+ "r_buckle_fraction": 0.99,
69
+ "gamma": 1.07,
70
+ "chi": 0.38,
71
+ "mu_L": 0.00089,
72
+ "kappa_s": 2.4e-9,
73
+ "rho_L": 1000.0,
74
+ "c_L": 1498.0,
75
+ "p_amb": 101300.0,
76
+ "sigma_L": 0.072,
77
+ "vdw_divisor": 5.61,
78
+ }
79
+
80
+ PARAM_SPECS: Dict[str, Dict[str, Any]] = {
81
+ "radius": {"label": "Equilibrium radius (μm)", "min": RADIUS_MIN_UM, "max": RADIUS_MAX_UM, "step": 0.1, "fmt": "%.2f"},
82
+ "freq": {"label": "Frequency (kHz)", "min": MIN_FREQ_KHZ, "max": MAX_FREQ_KHZ, "step": 10.0, "fmt": "%.1f"},
83
+ "pressure": {"label": "Pressure amplitude (kPa)", "min": 0.0, "max": PRESSURE_LIMIT_KPA, "step": 10.0, "fmt": "%.1f"},
84
+ "cycles": {"label": "Pulse cycles", "min": 1, "max": 12, "step": 1, "fmt": "%d"},
85
+ "r_buckle_fraction": {"label": "R_buckle fraction", "min": 0.5, "max": 1.1, "step": 0.01, "fmt": "%.3f"},
86
+ "gamma": {"label": "Polytropic index (gamma)", "min": 1.0, "max": 1.5, "step": 0.01, "fmt": "%.3f"},
87
+ "chi": {"label": "Shell elasticity (chi) [N/m]", "min": 0.0, "max": 1.0, "step": 0.01, "fmt": "%.3f"},
88
+ "mu_L": {"label": "Liquid viscosity (mu_L) [Pa.s]", "min": 0.0001, "max": 0.005, "step": 0.0001, "fmt": "%.5f"},
89
+ "kappa_s": {"label": "Shell viscosity (kappa_s) [kg/s]", "min": 1e-10, "max": 5e-8, "step": 1e-10, "fmt": "%.1e"},
90
+ "rho_L": {"label": "Liquid density (rho_L) [kg/m^3]", "min": 900.0, "max": 1100.0, "step": 5.0, "fmt": "%.1f"},
91
+ "c_L": {"label": "Speed of sound (c_L) [m/s]", "min": 1400.0, "max": 1600.0, "step": 5.0, "fmt": "%.1f"},
92
+ "p_amb": {"label": "Ambient pressure (P_amb) [Pa]", "min": 80000.0, "max": 140000.0, "step": 500.0, "fmt": "%.1f"},
93
+ "sigma_L": {"label": "Surface tension (sigma_L) [N/m]", "min": 0.01, "max": 0.1, "step": 0.001, "fmt": "%.4f"},
94
+ "vdw_divisor": {"label": "Van der Waals divisor", "min": 3.0, "max": 8.0, "step": 0.1, "fmt": "%.2f"},
95
+ }
96
+
97
+ ArrayLike = Union[Sequence[float], np.ndarray]
98
+
99
+
100
+ def render_pulse_preview(shape_name: str, apply_hann: bool, cycles: int = 3) -> go.Figure:
101
+ """Generate a small preview of the pulse shape."""
102
+ shape = SHAPE_MAP[shape_name]
103
+ # Use normalized time with freq=1 for visualization
104
+ freq = 1.0
105
+ initial_time = 0.0
106
+ phase = 0.0
107
+ t = np.linspace(0, cycles, 500)
108
+ # Generate pulse waveform using the shape's __call__ method
109
+ waveform = np.array([float(shape(ti, freq, phase, initial_time)) for ti in t])
110
+
111
+ # Apply Hann window if selected
112
+ if apply_hann:
113
+ hann = np.sin(np.pi * t / cycles) ** 2
114
+ waveform = waveform * hann
115
+
116
+ fig = go.Figure()
117
+ fig.add_trace(go.Scatter(
118
+ x=t,
119
+ y=waveform,
120
+ mode="lines",
121
+ line=dict(color="#45FFE9", width=2),
122
+ showlegend=False,
123
+ ))
124
+ fig.update_layout(
125
+ template="plotly_white",
126
+ height=70,
127
+ margin=dict(l=5, r=5, t=5, b=5),
128
+ xaxis=dict(showticklabels=False, showgrid=False, zeroline=False),
129
+ yaxis=dict(showticklabels=False, showgrid=False, zeroline=True, zerolinecolor="#666", zerolinewidth=1),
130
+ )
131
+ return fig
132
+
133
+
134
+ def init_session_state() -> None:
135
+ if "sweep_store" not in st.session_state:
136
+ st.session_state["sweep_store"] = None
137
+ for k, v in DEFAULTS.items():
138
+ if k not in st.session_state:
139
+ st.session_state[k] = v
140
+
141
+
142
+ def reset_bubble_defaults() -> None:
143
+ bubble_keys = [
144
+ "radius", "r_buckle_fraction", "gamma", "chi", "mu_L",
145
+ "kappa_s", "rho_L", "c_L", "p_amb", "sigma_L", "vdw_divisor",
146
+ ]
147
+ for key in bubble_keys:
148
+ st.session_state[key] = DEFAULTS[key]
149
+
150
+
151
+ @st.cache_resource(show_spinner=False)
152
+ def _jitted_simulator():
153
+ return jax.jit(run_simulation)
154
+
155
+
156
+ def run_single_sim(params: Dict[str, Any]):
157
+ R0 = params["radius"] * 1e-6
158
+ bubble = Bubble(
159
+ R0=R0,
160
+ R_buckle=params["r_buckle_fraction"] * R0,
161
+ gamma=params["gamma"],
162
+ chi=params["chi"],
163
+ mu_L=params["mu_L"],
164
+ kappa_s=params["kappa_s"],
165
+ rho_L=params["rho_L"],
166
+ c_L=params["c_L"],
167
+ P_amb=params["p_amb"],
168
+ sigma_L=params["sigma_L"],
169
+ vdw_divisor=params["vdw_divisor"],
170
+ )
171
+ pulse = Pulse(
172
+ shape=SHAPE_MAP[params["pulse_shape"]],
173
+ freq=params["freq"] * 1e3,
174
+ pressure=params["pressure"] * 1e3,
175
+ cycle_num=int(params["cycles"]),
176
+ initial_time=1e-6,
177
+ apply_hann=bool(params["apply_hann"]),
178
+ )
179
+ result = _jitted_simulator()(
180
+ bubble=bubble,
181
+ pulse=pulse,
182
+ units=UNITS,
183
+ save_spec=SAVE_SPEC,
184
+ )
185
+ arrays = arrays_from_result(result)
186
+ expansion_ratio = float(np.max(arrays.radius_um) / params["radius"])
187
+ return result, arrays, expansion_ratio
188
+
189
+
190
+ def render_axis_controls(axis_key: str, key_prefix: str) -> Tuple[float, float]:
191
+ spec = PARAM_SPECS[axis_key]
192
+ min_v = spec["min"]
193
+ max_v = spec["max"]
194
+ step = spec["step"]
195
+ label = spec["label"]
196
+ range_default = (float(min_v), float(max_v))
197
+ return st.slider(f"{label}", min_value=float(min_v), max_value=float(max_v), value=range_default, step=float(step), key=f"{key_prefix}_{axis_key}")
198
+
199
+
200
+ def sweep_grid(x_axis: str, y_axis: str, x_values: np.ndarray, y_values: np.ndarray, base_params: Dict[str, Any]):
201
+ total = len(x_values) * len(y_values)
202
+ grid = np.zeros((len(y_values), len(x_values)))
203
+ progress = st.progress(0.0)
204
+ status = st.empty()
205
+ count = 0
206
+ for j, y_val in enumerate(y_values):
207
+ for i, x_val in enumerate(x_values):
208
+ params = dict(base_params)
209
+ params[x_axis] = float(x_val)
210
+ params[y_axis] = float(y_val)
211
+ _, _, expansion_ratio = run_single_sim(params)
212
+ grid[j, i] = expansion_ratio
213
+ count += 1
214
+ progress.progress(count / total)
215
+ status.write(f"Solving {count}/{total} sims...")
216
+ status.empty()
217
+ progress.empty()
218
+ return grid
219
+
220
+
221
+ init_session_state()
222
+ st.set_page_config(page_title="jbubble sweep", layout="wide")
223
+
224
+ axis_options = list(PARAM_SPECS.keys())
225
+
226
+
227
+ with st.sidebar:
228
+ st.image("sweep.svg", width='stretch')
229
+
230
+ sweep_button = st.button("Run", type="primary", width='stretch')
231
+
232
+ col_xy_A = st.columns(2)
233
+ with col_xy_A[0]:
234
+ x_axis = st.selectbox("X axis", axis_options, index=axis_options.index("radius"))
235
+ with col_xy_A[1]:
236
+ y_axis = st.selectbox("Y axis", axis_options, index=axis_options.index("freq"))
237
+
238
+ if x_axis == y_axis:
239
+ st.error("Choose two different axes to sweep.")
240
+ x_range = render_axis_controls(x_axis, "x_range")
241
+ y_range = x_range # Placeholder, won't be used
242
+ else:
243
+ x_range = render_axis_controls(x_axis, "x_range")
244
+ y_range = render_axis_controls(y_axis, "y_range")
245
+
246
+ col_xy_B = st.columns(2)
247
+ with col_xy_B[0]:
248
+ x_points = st.slider("X resolution", 5, 200, 50, step=1)
249
+ with col_xy_B[1]:
250
+ y_points = st.slider("Y resolution", 5, 200, 50, step=1)
251
+
252
+ st.markdown("---")
253
+ param_inputs: Dict[str, Any] = {}
254
+
255
+ # Pulse parameters
256
+ pulse_params = ["freq", "pressure", "cycles"]
257
+ with st.expander("Pulse", expanded=False):
258
+ pulse_shape = st.selectbox("Pulse shape", list(SHAPE_MAP.keys()), index=list(SHAPE_MAP.keys()).index(DEFAULTS["pulse_shape"]))
259
+ apply_hann = st.checkbox("Hann window", value=DEFAULTS["apply_hann"])
260
+ param_inputs["pulse_shape"] = pulse_shape
261
+ param_inputs["apply_hann"] = apply_hann
262
+
263
+ # Pulse shape preview
264
+ assert pulse_shape is not None
265
+ st.plotly_chart(render_pulse_preview(pulse_shape, apply_hann, 3), width='stretch', config={"displayModeBar": False})
266
+
267
+ for key in pulse_params:
268
+ if key in (x_axis, y_axis):
269
+ continue
270
+ spec = PARAM_SPECS[key]
271
+ default_val = DEFAULTS[key]
272
+ min_v = spec["min"]
273
+ max_v = spec["max"]
274
+ step = spec["step"]
275
+ if isinstance(default_val, int):
276
+ val = st.slider(spec["label"], int(min_v), int(max_v), int(default_val), step=int(step))
277
+ else:
278
+ val = st.slider(spec["label"], float(min_v), float(max_v), float(default_val), step=float(step))
279
+ param_inputs[key] = val
280
+
281
+ # Bubble parameters
282
+ with st.expander("Bubble", expanded=False):
283
+ st.button("Reset to Defaults", on_click=reset_bubble_defaults)
284
+ param_inputs["radius"] = st.number_input("Equilibrium radius (μm)", format="%.2f", step=0.1, key="radius", disabled="radius" in (x_axis, y_axis))
285
+ param_inputs["r_buckle_fraction"] = st.number_input("R_buckle fraction", format="%.4f", step=0.01, key="r_buckle_fraction", disabled="r_buckle_fraction" in (x_axis, y_axis))
286
+ param_inputs["gamma"] = st.number_input("Polytropic index (gamma)", format="%.4f", step=0.01, key="gamma", disabled="gamma" in (x_axis, y_axis))
287
+ param_inputs["chi"] = st.number_input("Shell elasticity (chi) [N/m]", format="%.4f", step=0.01, key="chi", disabled="chi" in (x_axis, y_axis))
288
+ param_inputs["mu_L"] = st.number_input("Liquid viscosity (mu_L) [Pa.s]", format="%.6f", step=0.00001, key="mu_L", disabled="mu_L" in (x_axis, y_axis))
289
+ param_inputs["kappa_s"] = st.number_input("Shell viscosity (kappa_s) [kg/s]", format="%.3e", step=1e-10, key="kappa_s", disabled="kappa_s" in (x_axis, y_axis))
290
+ param_inputs["rho_L"] = st.number_input("Liquid density (rho_L) [kg/m^3]", format="%.1f", step=10.0, key="rho_L", disabled="rho_L" in (x_axis, y_axis))
291
+ param_inputs["c_L"] = st.number_input("Speed of sound (c_L) [m/s]", format="%.1f", step=10.0, key="c_L", disabled="c_L" in (x_axis, y_axis))
292
+ param_inputs["p_amb"] = st.number_input("Ambient pressure (P_amb) [Pa]", format="%.1f", step=100.0, key="p_amb", disabled="p_amb" in (x_axis, y_axis))
293
+ param_inputs["sigma_L"] = st.number_input("Surface tension (sigma_L) [N/m]", format="%.4f", step=0.001, key="sigma_L", disabled="sigma_L" in (x_axis, y_axis))
294
+ param_inputs["vdw_divisor"] = st.number_input("Van der Waals divisor", format="%.2f", step=0.1, key="vdw_divisor", disabled="vdw_divisor" in (x_axis, y_axis))
295
+
296
+ st.markdown("---")
297
+ invert_cols = st.columns(2)
298
+ with invert_cols[0]:
299
+ invert_x_axis = st.checkbox("Invert X", value=False)
300
+ with invert_cols[1]:
301
+ invert_y_axis = st.checkbox("Invert Y", value=True)
302
+
303
+
304
+ # with plot_col:
305
+ placeholder_heatmap = st.empty()
306
+
307
+ if sweep_button and x_axis != y_axis:
308
+ x_values = np.linspace(x_range[0], x_range[1], num=x_points)
309
+ y_values = np.linspace(y_range[0], y_range[1], num=y_points)
310
+ base_params = dict(param_inputs)
311
+ base_params[x_axis] = DEFAULTS.get(x_axis, x_range[0])
312
+ base_params[y_axis] = DEFAULTS.get(y_axis, y_range[0])
313
+ grid = sweep_grid(x_axis, y_axis, x_values, y_values, base_params)
314
+ st.session_state["sweep_store"] = {
315
+ "x_axis": x_axis,
316
+ "y_axis": y_axis,
317
+ "x_values": x_values,
318
+ "y_values": y_values,
319
+ "x_range": x_range,
320
+ "y_range": y_range,
321
+ "x_points": x_points,
322
+ "y_points": y_points,
323
+ "grid": grid,
324
+ "base_params": base_params,
325
+ }
326
+
327
+ store = st.session_state.get("sweep_store")
328
+ if store:
329
+ # Check if current settings match stored settings
330
+ current_params = dict(param_inputs)
331
+ current_params[x_axis] = DEFAULTS.get(x_axis, x_range[0])
332
+ current_params[y_axis] = DEFAULTS.get(y_axis, y_range[0])
333
+
334
+ is_stale = (
335
+ store["x_axis"] != x_axis or
336
+ store["y_axis"] != y_axis or
337
+ store["x_range"] != x_range or
338
+ store["y_range"] != y_range or
339
+ store["x_points"] != x_points or
340
+ store["y_points"] != y_points or
341
+ store["base_params"] != current_params
342
+ )
343
+
344
+ # Use grayscale colorscale when stale
345
+ colorscale = "Viridis"
346
+
347
+ fig = go.Figure(
348
+ data=go.Heatmap(
349
+ x=store["x_values"],
350
+ y=store["y_values"],
351
+ z=store["grid"],
352
+ colorscale=colorscale,
353
+ colorbar=dict(title="Max expansion Rmax/R0"),
354
+ )
355
+ )
356
+ fig.update_layout(
357
+ template="plotly_white",
358
+ height=700,
359
+ width=600,
360
+ margin=dict(l=60, r=10, t=30, b=40),
361
+ xaxis_title=PARAM_SPECS[store["x_axis"]]["label"],
362
+ yaxis_title=PARAM_SPECS[store["y_axis"]]["label"],
363
+ )
364
+ if invert_x_axis:
365
+ fig.update_xaxes(autorange="reversed")
366
+ if invert_y_axis:
367
+ fig.update_yaxes(autorange="reversed")
368
+
369
+ placeholder_heatmap.plotly_chart(fig, width='stretch')
370
+
371
+ if is_stale:
372
+ st.warning("⚠️ Parameters changed — press **Run** to update")
373
+
374
+ elif not store:
375
+ placeholder_heatmap.info("Configure axes and press Run to compute the heatmap.")