Spaces:
Paused
Paused
| """State change handlers for the EM module. | |
| All @state.change decorated functions that respond to UI state changes. | |
| """ | |
| from __future__ import annotations | |
| import re | |
| from typing import TYPE_CHECKING | |
| from .state import state, ctrl, is_statevector_estimator_selected | |
| from .geometry import compute_hole_edges as _compute_hole_edges, update_geometry_hole_preview | |
| from .excitation import update_initial_state_preview, update_excitation_info_message | |
| from .qpu import ( | |
| update_qpu_sample_slot as _update_qpu_sample_slot, | |
| refresh_all_qpu_sample_slots as _refresh_all_qpu_sample_slots, | |
| hide_qpu_plots as _hide_qpu_plots, | |
| ) | |
| from .simulation import update_sim_monitor_points as _update_sim_monitor_points | |
| if TYPE_CHECKING: | |
| pass | |
| # --------------------------------------------------------------------------- | |
| # Workflow Highlights (visual feedback for UI cards) | |
| # --------------------------------------------------------------------------- | |
def _determine_workflow_step() -> int:
    """Return the first configuration step (1-6) that is not yet complete.

    Steps are checked in order: problem, geometry, distribution type, grid
    size (``nx``; only ``None`` counts as missing), backend. When all five are
    set the user is on step 6. Any error while reading state yields step 1.
    """
    # Each predicate is evaluated lazily, in order, so later state is only
    # read once every earlier step is complete.
    completion_checks = (
        lambda: bool(state.problem_selection),
        lambda: bool(state.geometry_selection),
        lambda: bool(state.dist_type),
        lambda: state.nx is not None,
        lambda: bool(state.backend_type),
    )
    try:
        for step_number, is_complete in enumerate(completion_checks, start=1):
            if not is_complete():
                return step_number
    except Exception:
        return 1
    return 6
def _apply_workflow_highlights(step: int):
    """Highlight the UI card belonging to ``step`` and reset all others.

    Cards correspond to steps 1-6 in this order: overview, geometry,
    excitation, meshing, backend, output. Each card's style is written to
    ``state.<card>_card_style``. Errors while writing state are ignored.
    """
    plain_style = "font-size: 0.8rem;"
    emphasized_style = "font-size: 0.8rem; border: 2px solid #5F259F; box-shadow: 0 0 8px rgba(95,37,159,0.25);"
    card_names = ("overview", "geometry", "excitation", "meshing", "backend", "output")
    try:
        for card_step, card in enumerate(card_names, start=1):
            chosen = emphasized_style if card_step == step else plain_style
            setattr(state, f"{card}_card_style", chosen)
    except Exception:
        # Best effort: highlighting is purely cosmetic.
        pass
| # --------------------------------------------------------------------------- | |
| # Qubit Plot (meshing slider) | |
| # --------------------------------------------------------------------------- | |
def build_qubit_plot(grid_size: int):
    """Build a Plotly figure showing qubit requirements vs grid size.

    The reference curve covers nx in {16, 32, 64, 128, 256, 512} using
    ``nq = 2 * ceil(log2(nx)) + 4``. A marker shows the qubit count for
    ``grid_size``. Both axis ranges are widened when needed, so the marker
    stays visible even for sizes outside the reference curve.

    Args:
        grid_size: Currently selected grid size (nx); converted with ``int()``.

    Returns:
        A ``plotly.graph_objects.Figure`` with the reference curve and the
        current-selection marker.
    """
    import numpy as np
    import plotly.graph_objects as go

    grid = int(grid_size)
    x_sizes = np.array([16, 32, 64, 128, 256, 512])
    y_qubits = 2 * np.ceil(np.log2(x_sizes)).astype(int) + 4
    # max(1, ...) keeps log2 defined for non-positive sizes.
    current_nq = int(2 * np.ceil(np.log2(max(1, grid))) + 4)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x_sizes, y=y_qubits, mode='lines', name='Total Qubits', line=dict(color='#7A3DB5', width=3)))
    fig.add_trace(go.Scatter(x=[grid], y=[current_nq], mode='markers', marker=dict(size=10, color='#5F259F'), name='Current Selection'))

    # Include the current selection in every bound so its marker is never
    # clipped (e.g. nx=8 needs 10 qubits, below the curve's minimum of 12).
    x_min = int(min(x_sizes.min(), grid))
    x_max = int(max(x_sizes.max(), grid))
    y_min = int(min(y_qubits.min(), current_nq))
    y_max = int(max(y_qubits.max(), current_nq))

    fig.update_xaxes(
        range=[x_min - 8, x_max + 8],
        tickmode='array',
        tickvals=x_sizes,
        ticktext=[str(v) for v in x_sizes],
        title_text="Grid Size (nx)",
        gridcolor='rgba(95,37,159,0.1)',
        zerolinecolor='rgba(95,37,159,0.3)'
    )
    fig.update_yaxes(
        range=[y_min - 1, y_max + 1],
        dtick=1,
        title_text="Total Qubits (nq)",
        gridcolor='rgba(95,37,159,0.1)',
        zerolinecolor='rgba(95,37,159,0.3)'
    )
    fig.update_layout(
        margin=dict(l=30, r=10, t=10, b=30),
        autosize=True,
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        font=dict(color='#1A1A1A'),
        paper_bgcolor='#FFFFFF',
        plot_bgcolor='#FFFFFF',
        colorway=['#5F259F', '#7A3DB5', '#AE8BD8', '#5F259F'],
    )
    return fig
| # --------------------------------------------------------------------------- | |
| # State Change Handlers - Registered after server binding | |
| # --------------------------------------------------------------------------- | |
| def register_handlers(): | |
| """Register all @state.change handlers. Call after server is bound.""" | |
| if not state.bound: | |
| return | |
def update_qubit_plot_handler(nx, **kwargs):
    """Push a freshly built qubit-requirements figure for grid size ``nx``.

    Best effort: any failure, for example ``int(None)`` while ``nx`` is unset,
    is silently ignored.
    """
    try:
        figure = build_qubit_plot(int(nx))
        ctrl.qubit_plot_update(figure)
    except Exception:
        pass
def validate_hole_inputs(**kwargs):
    """Validate the square-hole inputs and, if valid, refresh the hole preview.

    Only active for the "Square Metallic Body" geometry; for any other
    geometry the hole error message is simply cleared. On the first failed
    check, ``state.hole_error_message`` is set and the function returns
    without refreshing the preview. The checks are:

    * size and center are numeric;
    * hole edge length is in (0, 1];
    * center lies strictly inside (0, 1) on both axes;
    * when Snap is off, the hole edges align with the grid.
    """
    # Only validate when Square Metallic Body is selected
    if state.geometry_selection != "Square Metallic Body":
        state.hole_error_message = ""
        return
    try:
        s = float(state.hole_size_edge)
        cx = float(state.hole_center_x)
        cy = float(state.hole_center_y)
    except Exception:
        state.hole_error_message = "Hole size and center must be numeric."
        return
    # Use selected nx, fall back to a safe default
    try:
        nx = int(state.nx or 32)
    except Exception:
        nx = 32
    # Negated chained comparison: rejects zero, negative and NaN sizes as
    # well as sizes above 1 (a bare ``s > 1.0`` test lets the former through).
    if not (0.0 < s <= 1.0):
        state.hole_error_message = "Hole edge length must be in (0, 1]."
        return
    if not (0.0 < cx < 1.0) or not (0.0 < cy < 1.0):
        state.hole_error_message = "Hole center must be strictly within (0, 1) for both X and Y."
        return
    # Alignment check (strict vs snap): without snapping, _compute_hole_edges
    # returning None means the edges do not fall on grid lines.
    mode_snap = bool(state.hole_snap)
    edges = _compute_hole_edges(nx, cx, cy, s, snap=mode_snap)
    if not mode_snap and edges is None:
        state.hole_error_message = "Hole edges must align with grid lines; enable Snap to auto-align."
        return
    # Inputs valid; clear error and refresh preview
    state.hole_error_message = ""
    update_geometry_hole_preview()
def sync_hole_center_pair(hole_center_pair, **kwargs):
    """Copy an "(x, y)" dropdown value into ``hole_center_x`` / ``hole_center_y``.

    On success the hole error message is cleared. If the value does not
    match the pair pattern, or anything else fails, the error message is set
    instead.
    """
    pair_pattern = r"\(\s*([-+]?[0-9]*\.?[0-9]+)\s*,\s*([-+]?[0-9]*\.?[0-9]+)\s*\)"
    try:
        found = re.match(pair_pattern, str(hole_center_pair))
        if found is None:
            raise ValueError("Invalid format")
        x_text, y_text = found.groups()
        state.hole_center_x = float(x_text)
        state.hole_center_y = float(y_text)
        state.hole_error_message = ""
    except Exception:
        state.hole_error_message = "Invalid hole center. Use format (x, y)."
def sync_sigma_pair(sigma_pair, **kwargs):
    """Copy an "(x, y)" Sigma string into ``sigma_x`` / ``sigma_y``.

    Each component is clamped into [0, 1]. A malformed string sets
    ``state.excitation_error_message`` and leaves the sigma fields untouched.
    """
    pair_pattern = r"\(\s*([-+]?[0-9]*\.?[0-9]+)\s*,\s*([-+]?[0-9]*\.?[0-9]+)\s*\)"

    def clamp_unit(value):
        return max(0.0, min(1.0, value))

    try:
        found = re.match(pair_pattern, str(sigma_pair))
        if found is None:
            raise ValueError("Invalid format")
        raw_x, raw_y = found.groups()
        clamped_x = clamp_unit(float(raw_x))
        clamped_y = clamp_unit(float(raw_y))
        state.sigma_x = clamped_x
        state.sigma_y = clamped_y
        state.excitation_error_message = ""
    except Exception:
        state.excitation_error_message = "Invalid Sigma. Use format (x, y) in [0,1]."
def normalize_dist_type(dist_type, **kwargs):
    """Normalize the distribution-type selection and refresh dependent UI.

    ``None``, ``""`` and the literal string ``"None"`` all mean "unselected":
    ``state.dist_type`` is reset to ``None`` and the initial-state preview is
    refreshed. Any other value refreshes the excitation info message. Workflow
    highlights are refreshed in both cases.
    """
    if dist_type in (None, "", "None"):
        # Allow unselecting via 'None'
        state.dist_type = None
        update_initial_state_preview()
    else:
        update_excitation_info_message()
    _apply_workflow_highlights(_determine_workflow_step())
# One handler per QPU monitor-sample selector; each refreshes only its slot.
def on_qpu_monitor_samples_change(**kwargs):
    """Refresh QPU monitor sample slot 1."""
    slot = 1
    _update_qpu_sample_slot(slot)


def on_qpu_monitor_samples_2_change(**kwargs):
    """Refresh QPU monitor sample slot 2."""
    slot = 2
    _update_qpu_sample_slot(slot)


def on_qpu_monitor_samples_3_change(**kwargs):
    """Refresh QPU monitor sample slot 3."""
    slot = 3
    _update_qpu_sample_slot(slot)


def on_qpu_monitor_samples_4_change(**kwargs):
    """Refresh QPU monitor sample slot 4."""
    slot = 4
    _update_qpu_sample_slot(slot)


def on_qpu_monitor_samples_5_change(**kwargs):
    """Refresh QPU monitor sample slot 5."""
    slot = 5
    _update_qpu_sample_slot(slot)
def on_nx_change_refresh_qpu_samples(nx, **kwargs):
    """After a grid-size change, refresh dependent views.

    Refreshes, in order: all QPU sample slots, the simulator monitor points,
    and the workflow card highlights.
    """
    _refresh_all_qpu_sample_slots()
    _update_sim_monitor_points()
    current_step = _determine_workflow_step()
    _apply_workflow_highlights(current_step)
def validate_dt_user(dt_user, **kwargs):
    """Validate the snapshot interval Δt entered by the user.

    Δt must be a finite number, at least the solver step (0.1 s), and an
    integer multiple of that step. ``state.temporal_warning`` is set to the
    message for the first failed rule, or to ``""`` when Δt is valid.
    """
    import math

    solver_dt = 0.1
    tol = 1e-9
    try:
        dt_val = float(dt_user)
    except Exception:
        dt_val = math.nan
    # float() accepts "nan"/"inf"; those would make round() below raise
    # (ValueError / OverflowError), so they are reported as non-numeric.
    if not math.isfinite(dt_val):
        state.temporal_warning = "Δt must be numeric. Frames are captured every Δt."
        return
    steps = dt_val / solver_dt
    if dt_val < solver_dt - tol:
        state.temporal_warning = "Δt < 0.1 is unsupported (solver dt = 0.1 s)."
    elif abs(steps - round(steps)) > tol:
        state.temporal_warning = "Δt must be a multiple of 0.1 s."
    else:
        state.temporal_warning = ""
def on_backend_change(backend_type, **kwargs):
    """Hide the QPU plots for QPU/statevector-estimator backends; refresh highlights.

    The estimator check is only consulted when the backend is "Simulator".
    """
    simulator_with_estimator = backend_type == "Simulator" and is_statevector_estimator_selected()
    if backend_type == "QPU" or simulator_with_estimator:
        _hide_qpu_plots()
    _apply_workflow_highlights(_determine_workflow_step())
def on_selected_qpu_change(selected_qpu, **kwargs):
    """Hide the QPU plots when a different QPU is picked while on the QPU backend."""
    if state.backend_type != "QPU":
        return
    _hide_qpu_plots()


def on_selected_simulator_change(selected_simulator, **kwargs):
    """Hide the QPU plots when the statevector estimator is the selected simulator."""
    if not is_statevector_estimator_selected():
        return
    _hide_qpu_plots()
def on_qpu_plot_filter_change(qpu_plot_filter, **kwargs):
    """Intentionally do nothing when the QPU plot filter changes.

    The refresh is performed by the controller bound to the VSelect; reacting
    here as well would refresh the plots twice.
    """
def on_problem_change(problem_selection, **kwargs):
    """Update geometry options and auto-select based on problem selection.

    A recognized problem narrows ``state.geometry_options`` to its single
    geometry, selects it, and logs the choice. Any other value offers both
    geometries and clears the selection. Workflow highlights are refreshed
    afterwards.
    """
    from .simulation import log_to_console

    # (problem, only valid geometry, console message)
    auto_geometry = (
        (
            "Propagation in a given medium (no bodies)",
            "Square Domain",
            "Auto-selected 'Square Domain' geometry for propagation problem.",
        ),
        (
            "Scattering from a perfectly conducting body",
            "Square Metallic Body",
            "Auto-selected 'Square Metallic Body' geometry for scattering problem.",
        ),
    )
    for problem, geometry, message in auto_geometry:
        if problem_selection == problem:
            state.geometry_options = [geometry]
            state.geometry_selection = geometry
            log_to_console(message)
            break
    else:
        # No specific problem selected: offer every geometry, select none.
        state.geometry_options = ["Square Domain", "Square Metallic Body"]
        state.geometry_selection = None
    _apply_workflow_highlights(_determine_workflow_step())
def on_geometry_change(geometry_selection, **kwargs):
    """Refresh the workflow card highlights after the geometry changes."""
    step = _determine_workflow_step()
    _apply_workflow_highlights(step)
def sync_peak_pair(peak_pair, **kwargs):
    """Parse an "(x, y)" peak position into ``impulse_x`` / ``impulse_y``.

    Both coordinates must lie in [0, 1]; out-of-range values only set an
    error message. A valid peak clears the error and refreshes the
    initial-state preview. Any parse failure, or any other exception raised
    inside, sets the "Invalid peak" message.
    """
    pair_pattern = r"\(\s*([-+]?[0-9]*\.?[0-9]+)\s*,\s*([-+]?[0-9]*\.?[0-9]+)\s*\)"
    try:
        found = re.match(pair_pattern, str(peak_pair))
        if found is None:
            raise ValueError("Invalid format")
        first, second = found.groups()
        peak_x = float(first)
        peak_y = float(second)
        inside_unit_square = 0.0 <= peak_x <= 1.0 and 0.0 <= peak_y <= 1.0
        if not inside_unit_square:
            state.excitation_error_message = "Peak must be in [0,1]."
            return
        state.impulse_x = peak_x
        state.impulse_y = peak_y
        state.excitation_error_message = ""
        update_initial_state_preview()
    except Exception:
        state.excitation_error_message = "Invalid peak. Use format (x, y)."
def sync_mu_pair(mu_pair, **kwargs):
    """Parse an "(x, y)" mu string into ``mu_x`` / ``mu_y``.

    Both components must lie in [0, 1]; otherwise only an error message is
    set. A valid pair clears the error and refreshes the initial-state
    preview. Any parse failure, or any other exception raised inside, sets
    the "Invalid mu" message.
    """
    pair_pattern = r"\(\s*([-+]?[0-9]*\.?[0-9]+)\s*,\s*([-+]?[0-9]*\.?[0-9]+)\s*\)"
    try:
        found = re.match(pair_pattern, str(mu_pair))
        if not found:
            raise ValueError("Invalid format")
        mu_x, mu_y = map(float, found.groups())
        if 0.0 <= mu_x <= 1.0 and 0.0 <= mu_y <= 1.0:
            state.mu_x = mu_x
            state.mu_y = mu_y
            state.excitation_error_message = ""
            update_initial_state_preview()
        else:
            state.excitation_error_message = "Mu must be in [0,1]."
    except Exception:
        state.excitation_error_message = "Invalid mu. Use format (x, y)."
def on_sigma_change(sigma_x, sigma_y, **kwargs):
    """Re-render the initial-state preview whenever either sigma component changes."""
    return update_initial_state_preview() and None
| # ----------------------------------------------------------------------- | |
| # Additional handlers for simulation workflow | |
| # ----------------------------------------------------------------------- | |
def on_slider_index_change(nx_slider_index, **kwargs):
    """Map the grid-size slider position onto ``state.nx``.

    ``nx_slider_index`` indexes ``GRID_SIZES``. A ``None`` index, or one that
    cannot be converted or looked up, sets ``state.nx`` to ``None``. The
    excitation info message is refreshed in every case.
    """
    from .globals import GRID_SIZES

    if nx_slider_index is None:
        state.nx = None
        update_excitation_info_message()
        return
    try:
        state.nx = int(GRID_SIZES[int(nx_slider_index)])
    except Exception:
        state.nx = None
    update_excitation_info_message()
def on_input_parameter_change(**kwargs):
    """Handle changes to input parameters.

    Does nothing while a simulation is running. Otherwise:

    * refreshes the excitation info message;
    * when the QPU backend or the statevector estimator is active, marks both
      QPU time-series as not ready and hides their plots;
    * if a simulation has already run, relabels the run button "Re-run!";
      otherwise refreshes the initial-state preview when a preview-relevant
      parameter appears in ``kwargs``.
    """
    if state.is_running:
        return
    update_excitation_info_message()
    if state.backend_type == "QPU" or is_statevector_estimator_selected():
        hidden_plot_style = "display: none; width: 900px; height: 660px; margin: 0 auto;"
        state.qpu_ts_ready = False
        state.qpu_plot_style = hidden_plot_style
        state.qpu_ts_other_ready = False
        state.qpu_other_plot_style = hidden_plot_style
    if state.simulation_has_run:
        state.run_button_text = "Re-run!"
        return
    # NOTE(review): assumes kwargs carries the keys that changed; confirm how
    # this handler is registered, since some frameworks pass the whole state.
    preview_params = {"nx", "dist_type", "impulse_x", "impulse_y", "mu_x", "mu_y", "sigma_x", "sigma_y"}
    if not preview_params.isdisjoint(kwargs):
        update_initial_state_preview()
def on_output_config_change(**kwargs):
    """Refresh monitor points; regenerate the output plot once a simulation has run."""
    from .simulation import generate_plot

    _update_sim_monitor_points()
    if not state.simulation_has_run:
        return
    generate_plot()


def on_timeseries_points_text_change(**kwargs):
    """Re-derive the simulator monitor points after the time-series points text is edited."""
    return _update_sim_monitor_points() and None
def on_surface_field_change(surface_field, **kwargs):
    """Redraw the surface plot for a newly chosen field.

    Only redraws after a simulation has run and while the output type is
    "Surface Plot".
    """
    from .simulation import redraw_surface_plot

    if not state.simulation_has_run:
        return
    if state.output_type != "Surface Plot":
        return
    redraw_surface_plot()


def on_time_change(time_val, **kwargs):
    """Redraw the surface plot at a new time value.

    Only redraws after a simulation has run and while the output type is
    "Surface Plot".
    """
    from .simulation import redraw_surface_plot

    if state.simulation_has_run and state.output_type == "Surface Plot":
        redraw_surface_plot()
def handle_geometry_add(geometry_selection, **kwargs):
    """React to the geometry dropdown selection.

    * ``None`` / ``""`` / ``"None"``: clear the selection and refresh the
      initial-state preview.
    * ``"Add"``: open the upload dialog and clear the selection (no preview
      refresh).
    * anything else: refresh the initial-state preview.

    Workflow highlights are refreshed in every case.
    """
    if geometry_selection in (None, "", "None"):
        state.geometry_selection = None
        update_initial_state_preview()
    elif geometry_selection == "Add":
        state.show_upload_dialog = True
        state.geometry_selection = None
    else:
        update_initial_state_preview()
    _apply_workflow_highlights(_determine_workflow_step())
def handle_file_upload(uploaded_file_info, **kwargs):
    """Handle file upload completion.

    In the code shown here the upload is a placeholder ("dummy upload"). When
    ``uploaded_file_info`` is truthy, the file name is printed, the upload
    dialog is closed, and a status message naming the file is shown. A falsy
    value does nothing.

    Args:
        uploaded_file_info: Upload payload. Only ``.get("name", ...)`` is
            used, so it presumably is a dict-like object with a "name" key;
            TODO confirm against the upload widget.
    """
    if uploaded_file_info:
        file_name = uploaded_file_info.get("name", "unknown file")
        # NOTE(review): plain print goes to the server console only.
        print(f"File selected (dummy upload): {file_name}")
        state.show_upload_dialog = False
        state.upload_status_message = f"File '{file_name}' uploaded."
        state.show_upload_status = True