# File: quantum/em/handlers.py
# Last commit e507df4 ("Set Optimization Level = 1") by harishaseebat92
"""State change handlers for the EM module.
All @state.change decorated functions that respond to UI state changes.
"""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from .state import state, ctrl, is_statevector_estimator_selected
from .geometry import compute_hole_edges as _compute_hole_edges, update_geometry_hole_preview
from .excitation import update_initial_state_preview, update_excitation_info_message
from .qpu import (
update_qpu_sample_slot as _update_qpu_sample_slot,
refresh_all_qpu_sample_slots as _refresh_all_qpu_sample_slots,
hide_qpu_plots as _hide_qpu_plots,
)
from .simulation import update_sim_monitor_points as _update_sim_monitor_points
if TYPE_CHECKING:
pass
# ---------------------------------------------------------------------------
# Workflow Highlights (visual feedback for UI cards)
# ---------------------------------------------------------------------------
def _determine_workflow_step() -> int:
    """Return the first unfinished configuration step, numbered 1-6.

    Steps are checked in order: problem selection, geometry, excitation type,
    grid size (``nx``), backend. The number of the first one still unset is
    returned; if every one is set the user is on step 6. Any error raised
    while reading state is treated as "start over" and yields 1.
    """
    # Probes run lazily and in order, so later fields are read only once all
    # earlier steps are complete.
    probes = (
        lambda: state.problem_selection,
        lambda: state.geometry_selection,
        lambda: state.dist_type,
        lambda: state.nx is not None,  # nx only has to be present; 0 counts as set
        lambda: state.backend_type,
    )
    try:
        for step_number, probe in enumerate(probes, start=1):
            if not probe():
                return step_number
    except Exception:
        return 1
    return 6
def _apply_workflow_highlights(step: int):
    """Highlight the card for ``step`` (1-6) and reset the other five cards.

    Writes ``state.<card>_card_style`` for every card. Failures are ignored,
    because highlighting is purely cosmetic.
    """
    plain = "font-size: 0.8rem;"
    emphasized = "font-size: 0.8rem; border: 2px solid #5F259F; box-shadow: 0 0 8px rgba(95,37,159,0.25);"
    # Card order matches the step numbering returned by _determine_workflow_step.
    card_names = ("overview", "geometry", "excitation", "meshing", "backend", "output")
    try:
        for position, card in enumerate(card_names, start=1):
            setattr(state, f"{card}_card_style", emphasized if position == step else plain)
    except Exception:
        pass
# ---------------------------------------------------------------------------
# Qubit Plot (meshing slider)
# ---------------------------------------------------------------------------
def build_qubit_plot(grid_size: int):
    """Build a Plotly figure showing qubit requirements vs grid size.

    The reference line plots ``nq = 2 * ceil(log2(nx)) + 4`` at the preset
    grid sizes 16, 32, ..., 512. A marker shows the current ``grid_size``.
    Both axis ranges are widened when necessary so that the marker stays
    visible, even for a grid size outside the preset span.

    Args:
        grid_size: Selected grid size ``nx``. It must be ``int()``-convertible.
            Values below 1 are treated as 1 when computing the qubit count.

    Returns:
        The configured ``plotly.graph_objects.Figure``.

    Raises:
        TypeError, ValueError: if ``grid_size`` cannot be converted with ``int()``.
    """
    import numpy as np
    import plotly.graph_objects as go

    x_sizes = np.array([16, 32, 64, 128, 256, 512])
    y_qubits = 2 * np.ceil(np.log2(x_sizes)).astype(int) + 4
    selected_nx = int(grid_size)
    # Clamp to 1 so that log2 stays finite for zero or negative input.
    current_nq = int(2 * np.ceil(np.log2(max(1, selected_nx))) + 4)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x_sizes, y=y_qubits, mode='lines', name='Total Qubits', line=dict(color='#7A3DB5', width=3)))
    fig.add_trace(go.Scatter(x=[grid_size], y=[current_nq], mode='markers', marker=dict(size=10, color='#5F259F'), name='Current Selection'))

    # Every bound includes the current selection, not just y_max, so the
    # marker is never clipped (e.g. nx < 16 gives nq below the reference line).
    x_min = min(int(x_sizes.min()), selected_nx)
    x_max = max(int(x_sizes.max()), selected_nx)
    y_min = min(int(y_qubits.min()), current_nq)
    y_max = max(int(y_qubits.max()), current_nq)

    fig.update_xaxes(
        range=[x_min - 8, x_max + 8],
        tickmode='array',
        tickvals=x_sizes,
        ticktext=[str(v) for v in x_sizes],
        title_text="Grid Size (nx)",
        gridcolor='rgba(95,37,159,0.1)',
        zerolinecolor='rgba(95,37,159,0.3)',
    )
    fig.update_yaxes(
        range=[y_min - 1, y_max + 1],
        dtick=1,
        title_text="Total Qubits (nq)",
        gridcolor='rgba(95,37,159,0.1)',
        zerolinecolor='rgba(95,37,159,0.3)',
    )
    fig.update_layout(
        margin=dict(l=30, r=10, t=10, b=30),
        autosize=True,
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        font=dict(color='#1A1A1A'),
        paper_bgcolor='#FFFFFF',
        plot_bgcolor='#FFFFFF',
        colorway=['#5F259F', '#7A3DB5', '#AE8BD8', '#5F259F'],
    )
    return fig
# ---------------------------------------------------------------------------
# State Change Handlers - Registered after server binding
# ---------------------------------------------------------------------------
# Matches a bracketed numeric pair such as "(0.5, 0.25)". Shared by the
# hole-center, sigma, peak and mu handlers below.
_PAIR_RE = re.compile(r"\(\s*([-+]?[0-9]*\.?[0-9]+)\s*,\s*([-+]?[0-9]*\.?[0-9]+)\s*\)")


def _parse_pair(text) -> tuple[float, float]:
    """Parse a bracketed numeric pair such as ``"(0.5, 0.25)"`` into two floats.

    Only the start of ``str(text)`` has to match (``re.match`` semantics), so
    any trailing text after the closing bracket is ignored.

    Raises:
        ValueError: if ``str(text)`` does not begin with an ``(x, y)`` pair.
    """
    match = _PAIR_RE.match(str(text))
    if not match:
        raise ValueError("Invalid format")
    return float(match.group(1)), float(match.group(2))


def register_handlers():
    """Register all ``@state.change`` handlers. Call after the server is bound.

    Does nothing when ``state.bound`` is falsy.
    """
    if not state.bound:
        return

    # -- Meshing / qubit plot ----------------------------------------------

    @state.change("nx")
    def update_qubit_plot_handler(nx, **kwargs):
        """Redraw the qubit-requirement plot for the selected grid size."""
        try:
            ctrl.qubit_plot_update(build_qubit_plot(int(nx)))
        except Exception:
            # Best-effort redraw. For example, nx may be None, so int() fails.
            # The previous figure is simply kept.
            pass

    # -- Geometry: hole validation -------------------------------------------

    @state.change("hole_size_edge", "hole_center_x", "hole_center_y", "geometry_selection", "hole_snap")
    def validate_hole_inputs(**kwargs):
        """Validate the hole size and center, then refresh the geometry preview.

        Only applies to "Square Metallic Body". For any other geometry the
        error message is cleared.
        """
        if state.geometry_selection != "Square Metallic Body":
            state.hole_error_message = ""
            return
        try:
            s = float(state.hole_size_edge)
            cx = float(state.hole_center_x)
            cy = float(state.hole_center_y)
        except (TypeError, ValueError, OverflowError):
            state.hole_error_message = "Hole size and center must be numeric."
            return
        # Use the selected nx, and fall back to a safe default.
        try:
            nx = int(state.nx or 32)
        except (TypeError, ValueError, OverflowError):
            nx = 32
        # Written as ``not s > 0`` so that NaN is rejected too.
        if not s > 0.0:
            state.hole_error_message = "Hole edge length must be > 0."
            return
        if s > 1.0:
            state.hole_error_message = "Hole edge length must be <= 1."
            return
        if not (0.0 < cx < 1.0) or not (0.0 < cy < 1.0):
            state.hole_error_message = "Hole center must be strictly within (0, 1) for both X and Y."
            return
        # Alignment check: strict mode requires grid-aligned edges, while
        # snap mode lets the helper adjust them.
        mode_snap = bool(state.hole_snap)
        edges = _compute_hole_edges(nx, cx, cy, s, snap=mode_snap)
        if not mode_snap and edges is None:
            state.hole_error_message = "Hole edges must align with grid lines; enable Snap to auto-align."
            return
        # Inputs are valid: clear the error and refresh the preview.
        state.hole_error_message = ""
        update_geometry_hole_preview()

    @state.change("hole_center_pair")
    def sync_hole_center_pair(hole_center_pair, **kwargs):
        """Copy an ``(x, y)`` dropdown value into hole_center_x / hole_center_y."""
        try:
            x, y = _parse_pair(hole_center_pair)
            state.hole_center_x = x
            state.hole_center_y = y
            state.hole_error_message = ""
        except Exception:
            state.hole_error_message = "Invalid hole center. Use format (x, y)."

    # -- Excitation ----------------------------------------------------------

    @state.change("sigma_pair")
    def sync_sigma_pair(sigma_pair, **kwargs):
        """Parse an ``(x, y)`` Sigma pair, clamp each part to [0, 1], and store it."""
        try:
            raw_x, raw_y = _parse_pair(sigma_pair)
            state.sigma_x = max(0.0, min(1.0, raw_x))
            state.sigma_y = max(0.0, min(1.0, raw_y))
            state.excitation_error_message = ""
        except Exception:
            state.excitation_error_message = "Invalid Sigma. Use format (x, y) in [0,1]."

    @state.change("dist_type")
    def normalize_dist_type(dist_type, **kwargs):
        """Normalize the 'unselected' values to None, then refresh the dependent UI."""
        # Allow unselecting via 'None'.
        if dist_type in (None, "", "None"):
            state.dist_type = None
            update_initial_state_preview()
            _apply_workflow_highlights(_determine_workflow_step())
            return
        update_excitation_info_message()
        _apply_workflow_highlights(_determine_workflow_step())

    # -- QPU monitor sample slots ------------------------------------------

    @state.change("qpu_monitor_samples")
    def on_qpu_monitor_samples_change(**kwargs):
        _update_qpu_sample_slot(1)

    @state.change("qpu_monitor_samples_2")
    def on_qpu_monitor_samples_2_change(**kwargs):
        _update_qpu_sample_slot(2)

    @state.change("qpu_monitor_samples_3")
    def on_qpu_monitor_samples_3_change(**kwargs):
        _update_qpu_sample_slot(3)

    @state.change("qpu_monitor_samples_4")
    def on_qpu_monitor_samples_4_change(**kwargs):
        _update_qpu_sample_slot(4)

    @state.change("qpu_monitor_samples_5")
    def on_qpu_monitor_samples_5_change(**kwargs):
        _update_qpu_sample_slot(5)

    @state.change("nx")
    def on_nx_change_refresh_qpu_samples(nx, **kwargs):
        """Re-derive monitor points and the workflow highlight after a grid change."""
        _refresh_all_qpu_sample_slots()
        _update_sim_monitor_points()
        _apply_workflow_highlights(_determine_workflow_step())

    # -- Temporal settings -------------------------------------------------

    @state.change("dt_user")
    def validate_dt_user(dt_user, **kwargs):
        """Validate snapshot Δt: it must be >= 0.1 (solver dt) and a multiple of 0.1."""
        import math

        solver_dt = 0.1
        tol = 1e-9
        try:
            dt_val = float(dt_user)
        except (TypeError, ValueError, OverflowError):
            state.temporal_warning = "Δt must be numeric. Frames are captured every Δt."
            return
        # float() accepts "nan" and "inf". round() would then raise below,
        # so reject non-finite values explicitly.
        if not math.isfinite(dt_val):
            state.temporal_warning = "Δt must be a finite number."
            return
        steps = dt_val / solver_dt
        if dt_val < solver_dt - tol:
            state.temporal_warning = "Δt < 0.1 is unsupported (solver dt = 0.1 s)."
        elif abs(steps - round(steps)) > tol:
            state.temporal_warning = "Δt must be a multiple of 0.1 s."
        else:
            state.temporal_warning = ""

    # -- Backend selection -------------------------------------------------

    @state.change("backend_type")
    def on_backend_change(backend_type, **kwargs):
        """Hide QPU plots when switching to a QPU or statevector-estimator backend."""
        if backend_type == "QPU" or (backend_type == "Simulator" and is_statevector_estimator_selected()):
            _hide_qpu_plots()
        _apply_workflow_highlights(_determine_workflow_step())

    @state.change("selected_qpu")
    def on_selected_qpu_change(selected_qpu, **kwargs):
        if state.backend_type == "QPU":
            _hide_qpu_plots()

    @state.change("selected_simulator")
    def on_selected_simulator_change(selected_simulator, **kwargs):
        if is_statevector_estimator_selected():
            _hide_qpu_plots()

    @state.change("qpu_plot_filter")
    def on_qpu_plot_filter_change(qpu_plot_filter, **kwargs):
        # No-op: updates are handled by the controller bound to the VSelect,
        # to avoid a double refresh.
        return

    # -- Problem / geometry selection --------------------------------------

    @state.change("problem_selection")
    def on_problem_change(problem_selection, **kwargs):
        """Update the geometry options and auto-select one based on the problem."""
        from .simulation import log_to_console

        if problem_selection == "Propagation in a given medium (no bodies)":
            # Only "Square Domain" applies to the propagation problem.
            state.geometry_options = ["Square Domain"]
            state.geometry_selection = "Square Domain"
            log_to_console("Auto-selected 'Square Domain' geometry for propagation problem.")
        elif problem_selection == "Scattering from a perfectly conducting body":
            # Only "Square Metallic Body" applies to the scattering problem.
            state.geometry_options = ["Square Metallic Body"]
            state.geometry_selection = "Square Metallic Body"
            log_to_console("Auto-selected 'Square Metallic Body' geometry for scattering problem.")
        else:
            # With no specific problem selected, offer both options.
            state.geometry_options = ["Square Domain", "Square Metallic Body"]
            state.geometry_selection = None
        _apply_workflow_highlights(_determine_workflow_step())

    # -- Excitation pairs ----------------------------------------------------

    @state.change("peak_pair")
    def sync_peak_pair(peak_pair, **kwargs):
        """Parse an ``(x, y)`` impulse peak, require both parts in [0, 1], and store it."""
        try:
            x, y = _parse_pair(peak_pair)
            if not (0.0 <= x <= 1.0) or not (0.0 <= y <= 1.0):
                state.excitation_error_message = "Peak must be in [0,1]."
                return
            state.impulse_x = x
            state.impulse_y = y
            state.excitation_error_message = ""
            update_initial_state_preview()
        except Exception:
            state.excitation_error_message = "Invalid peak. Use format (x, y)."

    @state.change("mu_pair")
    def sync_mu_pair(mu_pair, **kwargs):
        """Parse an ``(x, y)`` mu pair, require both parts in [0, 1], and store it."""
        try:
            x, y = _parse_pair(mu_pair)
            if not (0.0 <= x <= 1.0) or not (0.0 <= y <= 1.0):
                state.excitation_error_message = "Mu must be in [0,1]."
                return
            state.mu_x = x
            state.mu_y = y
            state.excitation_error_message = ""
            update_initial_state_preview()
        except Exception:
            state.excitation_error_message = "Invalid mu. Use format (x, y)."

    @state.change("sigma_x", "sigma_y")
    def on_sigma_change(sigma_x, sigma_y, **kwargs):
        update_initial_state_preview()

    # -- Simulation workflow -----------------------------------------------

    @state.change("nx_slider_index")
    def on_slider_index_change(nx_slider_index, **kwargs):
        """Map the grid-size slider index to ``state.nx`` via GRID_SIZES (None if invalid)."""
        from .globals import GRID_SIZES

        nx = None
        if nx_slider_index is not None:
            try:
                index = int(nx_slider_index)
                # Reject negative indices: Python would otherwise wrap them
                # around to the end of GRID_SIZES.
                if index >= 0:
                    nx = int(GRID_SIZES[index])
            except (TypeError, ValueError, OverflowError, IndexError):
                nx = None
        state.nx = nx
        update_excitation_info_message()

    @state.change(
        "nx", "T", "dist_type", "impulse_x", "impulse_y", "mu_x", "mu_y",
        "sigma_x", "sigma_y", "coeff_permittivity", "coeff_permeability",
    )
    def on_input_parameter_change(**kwargs):
        """React to solver-input edits.

        Refreshes the info message and hides stale QPU plots. It then either
        marks the run as needing a "Re-run!" or refreshes the initial-state
        preview.
        """
        if state.is_running:
            return
        update_excitation_info_message()
        if state.backend_type == "QPU" or is_statevector_estimator_selected():
            state.qpu_ts_ready = False
            state.qpu_plot_style = "display: none; width: 900px; height: 660px; margin: 0 auto;"
            state.qpu_ts_other_ready = False
            state.qpu_other_plot_style = "display: none; width: 900px; height: 660px; margin: 0 auto;"
        if state.simulation_has_run:
            state.run_button_text = "Re-run!"
            return
        preview_params = {"nx", "dist_type", "impulse_x", "impulse_y", "mu_x", "mu_y", "sigma_x", "sigma_y"}
        # trame calls change callbacks with the whole state as keyword
        # arguments (that is how named parameters such as `nx` are bound), so
        # kwargs.keys() cannot tell which input changed. Use trame's
        # modified_keys when it is available; otherwise refresh unconditionally.
        modified = getattr(state, "modified_keys", None)
        if not modified or preview_params.intersection(modified):
            update_initial_state_preview()

    @state.change("output_type", "timeseries_field", "timeseries_points")
    def on_output_config_change(**kwargs):
        """Recompute the monitor points and redraw output if a run already exists.

        This also covers ``timeseries_points``. A separate handler for that key
        would call _update_sim_monitor_points() a second time.
        """
        from .simulation import generate_plot

        _update_sim_monitor_points()
        if state.simulation_has_run:
            generate_plot()

    @state.change("surface_field")
    def on_surface_field_change(surface_field, **kwargs):
        """Redraw the surface plot when the displayed field changes."""
        from .simulation import redraw_surface_plot

        if state.simulation_has_run and state.output_type == "Surface Plot":
            redraw_surface_plot()

    @state.change("time_val")
    def on_time_change(time_val, **kwargs):
        """Redraw the surface plot at the newly selected time."""
        from .simulation import redraw_surface_plot

        if not state.simulation_has_run or state.output_type != "Surface Plot":
            return
        redraw_surface_plot()

    @state.change("geometry_selection")
    def handle_geometry_add(geometry_selection, **kwargs):
        """Normalize or handle the geometry choice, then refresh the preview and highlights.

        Every path ends by updating the workflow highlight, so no separate
        geometry_selection handler is needed for that.
        """
        if geometry_selection in (None, "", "None"):
            state.geometry_selection = None
            update_initial_state_preview()
            _apply_workflow_highlights(_determine_workflow_step())
            return
        if geometry_selection == "Add":
            # "Add" opens the upload dialog instead of being a real geometry.
            state.show_upload_dialog = True
            state.geometry_selection = None
            _apply_workflow_highlights(_determine_workflow_step())
            return
        update_initial_state_preview()
        _apply_workflow_highlights(_determine_workflow_step())

    @state.change("uploaded_file_info")
    def handle_file_upload(uploaded_file_info, **kwargs):
        """Handle file upload completion (placeholder: the file is not processed)."""
        if uploaded_file_info:
            file_name = uploaded_file_info.get("name", "unknown file")
            print(f"File selected (dummy upload): {file_name}")
            state.show_upload_dialog = False
            state.upload_status_message = f"File '{file_name}' uploaded."
            state.show_upload_status = True